"tools/python/vscode:/vscode.git/clone" did not exist on "a16515eac7593c0da5561d2d6eea7bafd1b44b3c"
Commit 682d18ef authored by Frederick Liu, committed by A. Unique TensorFlower

[nlp][progressive] Opensource progressive tasks.

PiperOrigin-RevId: 365242134
parent a283424a
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Masked language task with progressive training."""
from typing import List
# Import libraries
from absl import logging
import dataclasses
import orbit
import tensorflow as tf
from official.core import config_definitions as cfg
from official.core import task_factory
from official.modeling import optimization
from official.modeling.hyperparams import base_config
from official.modeling.progressive import policies
from official.nlp.tasks import masked_lm
@dataclasses.dataclass
class StackingStageConfig(base_config.Config):
num_layers: int = 0
num_steps: int = 0
warmup_steps: int = 10000
initial_learning_rate: float = 1e-4
end_learning_rate: float = 0.0
decay_steps: int = 1000000
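# Each StackingStageConfig describes one training stage: the number of
# transformer layers at that stage, how many steps to train it, and the
# parameters of the stage's polynomial learning-rate schedule (see
# get_optimizer below).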
@dataclasses.dataclass
class ProgMaskedLMConfig(masked_lm.MaskedLMConfig):
"""The progressive model config."""
optimizer_config: optimization.OptimizationConfig = (
optimization.OptimizationConfig(
optimizer=optimization.OptimizerConfig(type='adamw'),
learning_rate=optimization.LrConfig(type='polynomial'),
warmup=optimization.WarmupConfig(type='polynomial'),
)
)
stage_list: List[StackingStageConfig] = dataclasses.field(
default_factory=lambda: [ # pylint: disable=g-long-lambda
StackingStageConfig(num_layers=3,
num_steps=112500,
warmup_steps=10000,
initial_learning_rate=1e-4,
end_learning_rate=1e-4,
decay_steps=112500),
StackingStageConfig(num_layers=6,
num_steps=112500,
warmup_steps=10000,
initial_learning_rate=1e-4,
end_learning_rate=1e-4,
decay_steps=112500),
StackingStageConfig(num_layers=12,
num_steps=450000,
warmup_steps=10000,
initial_learning_rate=1e-4,
end_learning_rate=0.0,
decay_steps=450000)])
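# A minimal usage sketch (illustrative only; the names below are hypothetical,
# and the test file further down shows a fully specified config):
#
#   config = ProgMaskedLMConfig(stage_list=[
#       StackingStageConfig(num_layers=3, num_steps=1000),
#       StackingStageConfig(num_layers=6, num_steps=2000),
#   ])
#   task = ProgressiveMaskedLM(config)
#   small_model = task.get_model(stage_id=0)
#   larger_model = task.get_model(stage_id=1, old_model=small_model)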
@task_factory.register_task_cls(ProgMaskedLMConfig)
class ProgressiveMaskedLM(policies.ProgressivePolicy, masked_lm.MaskedLMTask):
"""Masked Language Model that supports progressive training.
Inherits from the MaskedLMTask class to build the model, datasets, etc.
"""
def __init__(self, params: cfg.TaskConfig, logging_dir: str = None):
masked_lm.MaskedLMTask.__init__(
self, params=params, logging_dir=logging_dir)
self._model_config = params.model
self._optimizer_config = params.optimizer_config
self._the_only_train_dataset = None
self._the_only_eval_dataset = None
policies.ProgressivePolicy.__init__(self)
# Override
def num_stages(self):
return len(self.task_config.stage_list)
# Override
def num_steps(self, stage_id):
return self.task_config.stage_list[stage_id].num_steps
# Override
def get_model(self, stage_id, old_model=None):
"""Build model for each stage."""
num_layers = self.task_config.stage_list[stage_id].num_layers
encoder_type = self._model_config.encoder.type
params = self._model_config.replace(
encoder={encoder_type: {
'num_layers': num_layers
}})
model = self.build_model(params)
# Run the model once, to make sure that all layers are built.
# Otherwise, not all weights will be copied.
_ = model(model.inputs)
if stage_id > 0 and old_model is not None:
logging.info('Stage %d copying weights.', stage_id)
self._copy_weights_to_new_model(old_model=old_model,
new_model=model)
return model
# Override
def get_optimizer(self, stage_id):
"""Build optimizer for each stage."""
params = self._optimizer_config.replace(
learning_rate={
'polynomial':
{'decay_steps':
self.task_config.stage_list[
stage_id].decay_steps,
'initial_learning_rate':
self.task_config.stage_list[
stage_id].initial_learning_rate,
'end_learning_rate':
self.task_config.stage_list[
stage_id].end_learning_rate,
'power': 1,
'cycle': False,
}
},
warmup={
'polynomial':
{'warmup_steps':
self.task_config.stage_list[stage_id].warmup_steps,
'power': 1,
}
}
)
opt_factory = optimization.OptimizerFactory(params)
optimizer = opt_factory.build_optimizer(opt_factory.build_learning_rate())
return optimizer
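# Each stage gets a fresh optimizer whose polynomial schedule is
# parameterized by that stage's StackingStageConfig: with power=1 the
# learning rate decays linearly from initial_learning_rate to
# end_learning_rate over decay_steps, after the warmup phase.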
# overrides policies.ProgressivePolicy
def get_train_dataset(self, stage_id):
del stage_id
if self._the_only_train_dataset is None:
strategy = tf.distribute.get_strategy()
self._the_only_train_dataset = orbit.utils.make_distributed_dataset(
strategy,
self.build_inputs,
self.task_config.train_data)
return self._the_only_train_dataset
# overrides policies.ProgressivePolicy
def get_eval_dataset(self, stage_id):
del stage_id
if self._the_only_eval_dataset is None:
strategy = tf.distribute.get_strategy()
self._the_only_eval_dataset = orbit.utils.make_distributed_dataset(
strategy,
self.build_inputs,
self.task_config.validation_data)
return self._the_only_eval_dataset
def _copy_weights_to_new_model(self, old_model, new_model):
"""Copy model weights from the previous stage to the next.
Args:
old_model: nlp.modeling.models.bert_pretrainer.BertPretrainerV2. Model of
the previous stage.
new_model: nlp.modeling.models.bert_pretrainer.BertPretrainerV2. Model of
the next stage.
"""
# Copy weights of the embedding layers.
# pylint: disable=protected-access
# When using `encoder_scaffold`, there may be `_embedding_network`.
if hasattr(new_model.encoder_network, '_embedding_network') and hasattr(
old_model.encoder_network, '_embedding_network') and (
new_model.encoder_network._embedding_network is not None):
new_model.encoder_network._embedding_network.set_weights(
old_model.encoder_network._embedding_network.get_weights())
else:
new_model.encoder_network._embedding_layer.set_weights(
old_model.encoder_network._embedding_layer.get_weights())
new_model.encoder_network._position_embedding_layer.set_weights(
old_model.encoder_network._position_embedding_layer.get_weights())
new_model.encoder_network._type_embedding_layer.set_weights(
old_model.encoder_network._type_embedding_layer.get_weights())
new_model.encoder_network._embedding_norm_layer.set_weights(
old_model.encoder_network._embedding_norm_layer.get_weights())
if hasattr(new_model.encoder_network, '_embedding_projection') and hasattr(
old_model.encoder_network, '_embedding_projection'):
if old_model.encoder_network._embedding_projection is not None:
new_model.encoder_network._embedding_projection.set_weights(
old_model.encoder_network._embedding_projection.get_weights())
# pylint: enable=protected-access
# Copy weights of the transformer layers.
# The model can be EncoderScaffold or TransformerEncoder.
if hasattr(old_model.encoder_network, 'hidden_layers'):
old_layer_group = old_model.encoder_network.hidden_layers
elif hasattr(old_model.encoder_network, 'transformer_layers'):
old_layer_group = old_model.encoder_network.transformer_layers
else:
raise ValueError('Unrecognized encoder network: {}'.format(
old_model.encoder_network))
if hasattr(new_model.encoder_network, 'hidden_layers'):
new_layer_group = new_model.encoder_network.hidden_layers
elif hasattr(new_model.encoder_network, 'transformer_layers'):
new_layer_group = new_model.encoder_network.transformer_layers
else:
raise ValueError('Unrecognized encoder network: {}'.format(
new_model.encoder_network))
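# Progressive stacking: new layer i is initialized from old layer
# i % len(old_layer_group), so e.g. a 6-layer model grown from a 3-layer
# model repeats the 3 trained layers twice.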
for new_layer_idx in range(len(new_layer_group)):
old_layer_idx = new_layer_idx % len(old_layer_group)
new_layer_group[new_layer_idx].set_weights(
old_layer_group[old_layer_idx].get_weights())
if old_layer_idx != new_layer_idx:
if hasattr(new_layer_group[new_layer_idx], 'reset_rezero'):
# Reset ReZero's alpha to 0.
new_layer_group[new_layer_idx].reset_rezero()
# Copy weights of the final layer norm (if needed).
# pylint: disable=protected-access
if hasattr(new_model.encoder_network, '_output_layer_norm') and hasattr(
old_model.encoder_network, '_output_layer_norm'):
new_model.encoder_network._output_layer_norm.set_weights(
old_model.encoder_network._output_layer_norm.get_weights())
# pylint: enable=protected-access
# Copy weights of the pooler layer.
new_model.encoder_network.pooler_layer.set_weights(
old_model.encoder_network.pooler_layer.get_weights())
# Copy weights of the classification head.
for idx in range(len(new_model.classification_heads)):
new_model.classification_heads[idx].set_weights(
old_model.classification_heads[idx].get_weights())
# Copy weights of the masked_lm layer.
new_model.masked_lm.set_weights(old_model.masked_lm.get_weights())
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for google.nlp.progressive_masked_lm."""
# Import libraries
from absl.testing import parameterized
import gin
import tensorflow as tf
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import strategy_combinations
from official.core import config_definitions as cfg
from official.modeling.progressive import trainer as prog_trainer_lib
from official.nlp.configs import bert
from official.nlp.configs import encoders
from official.nlp.data import pretrain_dataloader
from official.nlp.tasks import progressive_masked_lm
def all_strategy_combinations():
return combinations.combine(
distribution=[
strategy_combinations.default_strategy,
strategy_combinations.cloud_tpu_strategy,
strategy_combinations.one_device_strategy_gpu,
],)
class ProgressiveMaskedLMTest(tf.test.TestCase, parameterized.TestCase):
def setUp(self):
super(ProgressiveMaskedLMTest, self).setUp()
self.task_config = progressive_masked_lm.ProgMaskedLMConfig(
model=bert.PretrainerConfig(
encoder=encoders.EncoderConfig(
bert=encoders.BertEncoderConfig(vocab_size=30522,
num_layers=2)),
cls_heads=[
bert.ClsHeadConfig(
inner_dim=10, num_classes=2, name="next_sentence")
]),
train_data=pretrain_dataloader.BertPretrainDataConfig(
input_path="dummy",
max_predictions_per_seq=20,
seq_length=128,
global_batch_size=1),
validation_data=pretrain_dataloader.BertPretrainDataConfig(
input_path="dummy",
max_predictions_per_seq=20,
seq_length=128,
global_batch_size=1),
stage_list=[
progressive_masked_lm.StackingStageConfig(
num_layers=1, num_steps=4),
progressive_masked_lm.StackingStageConfig(
num_layers=2, num_steps=8),
],
)
self.exp_config = cfg.ExperimentConfig(
task=self.task_config,
trainer=prog_trainer_lib.ProgressiveTrainerConfig())
@combinations.generate(all_strategy_combinations())
def test_num_stages(self, distribution):
with distribution.scope():
prog_masked_lm = progressive_masked_lm.ProgressiveMaskedLM(
self.task_config)
self.assertEqual(prog_masked_lm.num_stages(), 2)
self.assertEqual(prog_masked_lm.num_steps(0), 4)
self.assertEqual(prog_masked_lm.num_steps(1), 8)
@combinations.generate(all_strategy_combinations())
def test_weight_copying(self, distribution):
with distribution.scope():
prog_masked_lm = progressive_masked_lm.ProgressiveMaskedLM(
self.task_config)
old_model = prog_masked_lm.get_model(stage_id=0)
for w in old_model.trainable_weights:
w.assign(tf.zeros_like(w) + 0.12345)
new_model = prog_masked_lm.get_model(stage_id=1, old_model=old_model)
for w in new_model.trainable_weights:
self.assertAllClose(w, tf.zeros_like(w) + 0.12345)
gin.parse_config_files_and_bindings(
None, "encoders.build_encoder.encoder_cls = @EncoderScaffold")
with distribution.scope():
prog_masked_lm = progressive_masked_lm.ProgressiveMaskedLM(
self.task_config)
old_model = prog_masked_lm.get_model(stage_id=0)
for w in old_model.trainable_weights:
w.assign(tf.zeros_like(w) + 0.12345)
new_model = prog_masked_lm.get_model(stage_id=1, old_model=old_model)
for w in new_model.trainable_weights:
self.assertAllClose(w, tf.zeros_like(w) + 0.12345)
if __name__ == "__main__":
tf.test.main()
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Translation task with progressive training."""
from typing import List
# Import libraries
from absl import logging
import dataclasses
import orbit
import tensorflow as tf
from official.core import config_definitions as cfg
from official.core import task_factory
from official.modeling import optimization
from official.modeling.hyperparams import base_config
from official.modeling.progressive import policies
from official.nlp.modeling import models
from official.nlp.tasks import translation
@dataclasses.dataclass
class StackingStageConfig(base_config.Config):
num_encoder_layers: int = 0
num_decoder_layers: int = 0
num_steps: int = 0
warmup_steps: int = 10000
initial_learning_rate: float = 0.0625
power: float = -0.5
@dataclasses.dataclass
class ProgTranslationConfig(translation.TranslationConfig):
"""The progressive model config."""
model: translation.ModelConfig = translation.ModelConfig(
encoder=translation.EncDecoder(
num_attention_heads=16, intermediate_size=4096),
decoder=translation.EncDecoder(
num_attention_heads=16, intermediate_size=4096),
embedding_width=1024,
padded_decode=True,
decode_max_length=100)
optimizer_config: optimization.OptimizationConfig = (
optimization.OptimizationConfig({
'optimizer': {
'type': 'adam',
'adam': {
'beta_2': 0.997,
'epsilon': 1e-9,
},
},
'learning_rate': {
'type': 'power',
'power': {
'initial_learning_rate': 0.0625,
'power': -0.5,
}
},
'warmup': {
'type': 'linear',
'linear': {
'warmup_steps': 16000,
'warmup_learning_rate': 0.0
}
}
}))
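# The 'power' learning-rate schedule above decays roughly as
# initial_learning_rate * step**power (inverse square root for power=-0.5),
# i.e. the standard Transformer schedule; the exact form is defined in
# official.modeling.optimization.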
stage_list: List[StackingStageConfig] = dataclasses.field(
default_factory=lambda: [ # pylint: disable=g-long-lambda
StackingStageConfig(num_encoder_layers=3,
num_decoder_layers=3,
num_steps=20000,
warmup_steps=5000,
initial_learning_rate=0.0625),
StackingStageConfig(num_encoder_layers=6,
num_decoder_layers=6,
num_steps=20000,
warmup_steps=5000,
initial_learning_rate=0.0625),
StackingStageConfig(num_encoder_layers=12,
num_decoder_layers=12,
num_steps=100000,
warmup_steps=5000,
initial_learning_rate=0.0625)])
@task_factory.register_task_cls(ProgTranslationConfig)
class ProgressiveTranslationTask(policies.ProgressivePolicy,
translation.TranslationTask):
"""Masked Language Model that supports progressive training.
Inherate from the TranslationTask class to build model datasets etc.
"""
def __init__(self, params: cfg.TaskConfig, logging_dir: str = None):
translation.TranslationTask.__init__(
self, params=params, logging_dir=logging_dir)
self._model_config = params.model
self._optimizer_config = params.optimizer_config
self._the_only_train_dataset = None
self._the_only_eval_dataset = None
policies.ProgressivePolicy.__init__(self)
# Override
def num_stages(self):
return len(self.task_config.stage_list)
# Override
def num_steps(self, stage_id):
return self.task_config.stage_list[stage_id].num_steps
# Override
def get_model(self, stage_id, old_model=None):
"""Build model for each stage."""
num_encoder_layers = (
self.task_config.stage_list[stage_id].num_encoder_layers)
num_decoder_layers = (
self.task_config.stage_list[stage_id].num_decoder_layers)
params = self._model_config.replace(
encoder={'num_layers': num_encoder_layers},
decoder={'num_layers': num_decoder_layers})
model = self.build_model(params)
# Run the model once, to make sure that all layers are built.
# Otherwise, not all weights will be copied.
inputs = next(tf.nest.map_structure(
iter, self.build_inputs(self.task_config.train_data)))
model(inputs, training=True)
if stage_id > 0 and old_model is not None:
logging.info('Stage %d copying weights.', stage_id)
self._copy_weights_to_new_model(old_model=old_model,
new_model=model)
return model
# Override
def build_model(self, params) -> tf.keras.Model:
"""Creates model architecture."""
model_cfg = params or self.task_config.model
encoder_kwargs = model_cfg.encoder.as_dict()
encoder_layer = models.TransformerEncoder(**encoder_kwargs)
decoder_kwargs = model_cfg.decoder.as_dict()
decoder_layer = models.TransformerDecoder(**decoder_kwargs)
return models.Seq2SeqTransformer(
vocab_size=self._vocab_size,
embedding_width=model_cfg.embedding_width,
dropout_rate=model_cfg.dropout_rate,
padded_decode=model_cfg.padded_decode,
decode_max_length=model_cfg.decode_max_length,
beam_size=model_cfg.beam_size,
alpha=model_cfg.alpha,
encoder_layer=encoder_layer,
decoder_layer=decoder_layer)
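# build_model is overridden so that the per-stage encoder/decoder configs
# (with the layer counts substituted in get_model above) are expanded into
# TransformerEncoder/TransformerDecoder keyword arguments.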
# Override
def get_optimizer(self, stage_id):
"""Build optimizer for each stage."""
params = self._optimizer_config.replace(
warmup={
'linear':
{'warmup_steps':
self.task_config.stage_list[stage_id].warmup_steps
},
},
learning_rate={
'power':
{'initial_learning_rate':
self.task_config.stage_list[stage_id].initial_learning_rate
},
},
)
opt_factory = optimization.OptimizerFactory(params)
optimizer = opt_factory.build_optimizer(opt_factory.build_learning_rate())
return optimizer
# overrides policies.ProgressivePolicy
def get_train_dataset(self, stage_id):
del stage_id
if self._the_only_train_dataset is None:
strategy = tf.distribute.get_strategy()
self._the_only_train_dataset = orbit.utils.make_distributed_dataset(
strategy,
self.build_inputs,
self.task_config.train_data)
return self._the_only_train_dataset
# overrides policies.ProgressivePolicy
def get_eval_dataset(self, stage_id):
del stage_id
if self._the_only_eval_dataset is None:
strategy = tf.distribute.get_strategy()
self._the_only_eval_dataset = orbit.utils.make_distributed_dataset(
strategy,
self.build_inputs,
self.task_config.validation_data)
return self._the_only_eval_dataset
def _copy_weights_to_new_model(self, old_model, new_model):
"""Copy model weights from the previous stage to the next.
Args:
old_model: nlp.modeling.models.Seq2SeqTransformer. Model of the
previous stage.
new_model: nlp.modeling.models.Seq2SeqTransformer. Model of the
next stage.
"""
new_model.embedding_lookup.set_weights(
old_model.embedding_lookup.get_weights())
new_model.position_embedding.set_weights(
old_model.position_embedding.get_weights())
new_model.encoder_layer.output_normalization.set_weights(
old_model.encoder_layer.output_normalization.get_weights())
new_model.decoder_layer.output_normalization.set_weights(
old_model.decoder_layer.output_normalization.get_weights())
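# Encoder and decoder blocks are copied with the same cyclic mapping as in
# the masked-LM task: new layer i takes the weights of old layer
# i % num_old_layers.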
old_layer_group = old_model.encoder_layer.encoder_layers
new_layer_group = new_model.encoder_layer.encoder_layers
for new_layer_idx in range(len(new_layer_group)):
old_layer_idx = new_layer_idx % len(old_layer_group)
new_layer_group[new_layer_idx].set_weights(
old_layer_group[old_layer_idx].get_weights())
old_layer_group = old_model.decoder_layer.decoder_layers
new_layer_group = new_model.decoder_layer.decoder_layers
for new_layer_idx in range(len(new_layer_group)):
old_layer_idx = new_layer_idx % len(old_layer_group)
new_layer_group[new_layer_idx].set_weights(
old_layer_group[old_layer_idx].get_weights())
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for google.nlp.progressive_translation."""
import os
from absl.testing import parameterized
import tensorflow as tf
from sentencepiece import SentencePieceTrainer
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import strategy_combinations
from official.core import config_definitions as cfg
from official.modeling.progressive import trainer as prog_trainer_lib
from official.nlp.data import wmt_dataloader
from official.nlp.tasks import progressive_translation
from official.nlp.tasks import translation
def all_strategy_combinations():
return combinations.combine(
distribution=[
strategy_combinations.default_strategy,
strategy_combinations.cloud_tpu_strategy,
strategy_combinations.one_device_strategy_gpu,
],)
def _generate_line_file(filepath, lines):
with tf.io.gfile.GFile(filepath, "w") as f:
for l in lines:
f.write("{}\n".format(l))
def _generate_record_file(filepath, src_lines, tgt_lines):
writer = tf.io.TFRecordWriter(filepath)
for src, tgt in zip(src_lines, tgt_lines):
example = tf.train.Example(
features=tf.train.Features(
feature={
"en": tf.train.Feature(
bytes_list=tf.train.BytesList(
value=[src.encode()])),
"reverse_en": tf.train.Feature(
bytes_list=tf.train.BytesList(
value=[tgt.encode()])),
}))
writer.write(example.SerializeToString())
writer.close()
def _train_sentencepiece(input_path, vocab_size, model_path, eos_id=1):
argstr = " ".join([
f"--input={input_path}", f"--vocab_size={vocab_size}",
"--character_coverage=0.995",
f"--model_prefix={model_path}", "--model_type=bpe",
"--bos_id=-1", "--pad_id=0", f"--eos_id={eos_id}", "--unk_id=2"
])
SentencePieceTrainer.Train(argstr)
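# Note: _train_sentencepiece writes "<model_path>.model" (and a companion
# ".vocab" file) via the SentencePiece trainer; the test below loads the
# ".model" file as sentencepiece_model_path.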
class ProgressiveTranslationTest(tf.test.TestCase, parameterized.TestCase):
def setUp(self):
super(ProgressiveTranslationTest, self).setUp()
self._temp_dir = self.get_temp_dir()
src_lines = ["abc ede fg", "bbcd ef a g", "de f a a g"]
tgt_lines = ["dd cc a ef g", "bcd ef a g", "gef cd ba"]
self._record_input_path = os.path.join(self._temp_dir, "train.record")
_generate_record_file(self._record_input_path, src_lines, tgt_lines)
self._sentencepiece_input_path = os.path.join(self._temp_dir, "inputs.txt")
_generate_line_file(self._sentencepiece_input_path, src_lines + tgt_lines)
sentencepiece_model_prefix = os.path.join(self._temp_dir, "sp")
_train_sentencepiece(self._sentencepiece_input_path, 11,
sentencepiece_model_prefix)
self._sentencepiece_model_path = "{}.model".format(
sentencepiece_model_prefix)
encdecoder = translation.EncDecoder(
num_attention_heads=2, intermediate_size=8)
self.task_config = progressive_translation.ProgTranslationConfig(
model=translation.ModelConfig(
encoder=encdecoder,
decoder=encdecoder,
embedding_width=8,
padded_decode=True,
decode_max_length=100),
train_data=wmt_dataloader.WMTDataConfig(
input_path=self._record_input_path,
is_training=True,
global_batch_size=24,
static_batch=True,
src_lang="en",
tgt_lang="reverse_en",
max_seq_length=12),
validation_data=wmt_dataloader.WMTDataConfig(
input_path=self._record_input_path,
is_training=False,
global_batch_size=2,
static_batch=True,
src_lang="en",
tgt_lang="reverse_en",
max_seq_length=12),
sentencepiece_model_path=self._sentencepiece_model_path,
stage_list=[
progressive_translation.StackingStageConfig(
num_encoder_layers=1, num_decoder_layers=1, num_steps=4),
progressive_translation.StackingStageConfig(
num_encoder_layers=2, num_decoder_layers=1, num_steps=8),
],
)
self.exp_config = cfg.ExperimentConfig(
task=self.task_config,
trainer=prog_trainer_lib.ProgressiveTrainerConfig())
@combinations.generate(all_strategy_combinations())
def test_num_stages(self, distribution):
with distribution.scope():
prog_translation = progressive_translation.ProgressiveTranslationTask(
self.task_config)
self.assertEqual(prog_translation.num_stages(), 2)
self.assertEqual(prog_translation.num_steps(0), 4)
self.assertEqual(prog_translation.num_steps(1), 8)
@combinations.generate(all_strategy_combinations())
def test_weight_copying(self, distribution):
with distribution.scope():
prog_translation = progressive_translation.ProgressiveTranslationTask(
self.task_config)
old_model = prog_translation.get_model(stage_id=0)
for w in old_model.trainable_weights:
w.assign(tf.zeros_like(w) + 0.12345)
new_model = prog_translation.get_model(stage_id=1, old_model=old_model)
for w in new_model.trainable_weights:
self.assertAllClose(w, tf.zeros_like(w) + 0.12345)
if __name__ == "__main__":
tf.test.main()