Commit ca88e8b4 authored by Chen Chen, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 370538762
parent 4c2ba498
@@ -13,6 +13,8 @@
 # limitations under the License.
 """Progressive distillation for MobileBERT student model."""
+from typing import List, Optional
+
 from absl import logging
 import dataclasses
 import orbit
@@ -46,6 +48,14 @@ class LayerWiseDistillConfig(base_config.Config):
   attention_distill_factor: float = 1.0
   if_freeze_previous_layers: bool = False
+  # The ids of teacher layers that will be mapped to the student model.
+  # For example, if you want to compress a 24-layer teacher to a 6-layer
+  # student, you can set it to [3, 7, 11, 15, 19, 23] (the index starts
+  # from 0).
+  # If `None`, we assume teacher and student have the same number of layers,
+  # and each layer of the teacher model will be mapped to the student's
+  # corresponding layer.
+  transfer_teacher_layers: Optional[List[int]] = None
 
 
 @dataclasses.dataclass
 class PretrainDistillConfig(base_config.Config):
@@ -120,6 +130,23 @@ class BertDistillationTask(policies.ProgressivePolicy, base_task.Task):
     self._the_only_train_dataset = None
     self._the_only_eval_dataset = None
 
+    layer_wise_config = self._progressive_config.layer_wise_distill_config
+    transfer_teacher_layers = layer_wise_config.transfer_teacher_layers
+    num_teacher_layers = (
+        self._task_config.teacher_model.encoder.mobilebert.num_blocks)
+    num_student_layers = (
+        self._task_config.student_model.encoder.mobilebert.num_blocks)
+    if transfer_teacher_layers and len(
+        transfer_teacher_layers) != num_student_layers:
+      raise ValueError('The number of `transfer_teacher_layers` %s does not '
+                       'match the number of student layers %d.' %
+                       (transfer_teacher_layers, num_student_layers))
+    if not transfer_teacher_layers and (num_teacher_layers !=
+                                        num_student_layers):
+      raise ValueError('`transfer_teacher_layers` is not specified, and the '
+                       'number of teacher layers does not match '
+                       'the number of student layers.')
+
     ratio = progressive.pretrain_distill_config.distill_ground_truth_ratio
     if ratio < 0 or ratio > 1:
       raise ValueError('distill_ground_truth_ratio has to be within [0, 1].')
@@ -169,7 +196,7 @@ class BertDistillationTask(policies.ProgressivePolicy, base_task.Task):
   # override policies.ProgressivePolicy
   def num_stages(self):
     # One stage for each layer, plus an additional stage for pre-training
-    return self._task_config.teacher_model.encoder.mobilebert.num_blocks + 1
+    return self._task_config.student_model.encoder.mobilebert.num_blocks + 1
 
   # override policies.ProgressivePolicy
   def num_steps(self, stage_id) -> int:
@@ -247,9 +274,16 @@ class BertDistillationTask(policies.ProgressivePolicy, base_task.Task):
         encoder=student_encoder, target_layer_id=stage_id)
     student_output_feature, student_attention_score = student_sub_encoder(
         inputs)
+
+    if layer_wise_config.transfer_teacher_layers:
+      teacher_layer_id = layer_wise_config.transfer_teacher_layers[stage_id]
+    else:
+      teacher_layer_id = stage_id
+
     teacher_sub_encoder = build_sub_encoder(
         encoder=self._teacher_pretrainer.encoder_network,
-        target_layer_id=stage_id)
+        target_layer_id=teacher_layer_id)
     teacher_output_feature, teacher_attention_score = teacher_sub_encoder(
         inputs)
...
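
The new `transfer_teacher_layers` field is the knob behind this change: it lists, for each student block, which teacher block to distill from, and the added __init__ checks enforce that the list length matches the student depth. A minimal usage sketch, assuming the config class is imported from `official.nlp.projects.mobilebert.distillation` as in the test below:

from official.nlp.projects.mobilebert import distillation

# Compress a 24-block teacher into a 6-block student by distilling every
# fourth teacher block, matching the example in the config comment above.
layer_wise_config = distillation.LayerWiseDistillConfig(
    transfer_teacher_layers=[3, 7, 11, 15, 19, 23])

# Leaving the field as None requires the teacher and student to have the same
# number of blocks; otherwise BertDistillationTask.__init__ raises ValueError.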
@@ -16,7 +16,9 @@
 import os
 
 from absl import logging
+from absl.testing import parameterized
 import tensorflow as tf
 
 from official.core import config_definitions as cfg
 from official.modeling import optimization
 from official.modeling import tf_utils
@@ -29,18 +31,17 @@ from official.nlp.modeling import models
 from official.nlp.projects.mobilebert import distillation
 
 
-class DistillationTest(tf.test.TestCase):
+class DistillationTest(tf.test.TestCase, parameterized.TestCase):
 
-  def setUp(self):
-    super(DistillationTest, self).setUp()
+  def prepare_config(self, teacher_block_num, student_block_num,
+                     transfer_teacher_layers):
     # using small model for testing
-    self.model_block_num = 2
-    self.task_config = distillation.BertDistillationTaskConfig(
+    task_config = distillation.BertDistillationTaskConfig(
         teacher_model=bert.PretrainerConfig(
             encoder=encoders.EncoderConfig(
                 type='mobilebert',
                 mobilebert=encoders.MobileBertEncoderConfig(
-                    num_blocks=self.model_block_num)),
+                    num_blocks=teacher_block_num)),
             cls_heads=[
                 bert.ClsHeadConfig(
                     inner_dim=256,
@@ -53,7 +54,7 @@ class DistillationTest(tf.test.TestCase):
             encoder=encoders.EncoderConfig(
                 type='mobilebert',
                 mobilebert=encoders.MobileBertEncoderConfig(
-                    num_blocks=self.model_block_num)),
+                    num_blocks=student_block_num)),
             cls_heads=[
                 bert.ClsHeadConfig(
                     inner_dim=256,
@@ -75,6 +76,8 @@ class DistillationTest(tf.test.TestCase):
     # set only 1 step for each stage
     progressive_config = distillation.BertDistillationProgressiveConfig()
+    progressive_config.layer_wise_distill_config.transfer_teacher_layers = (
+        transfer_teacher_layers)
     progressive_config.layer_wise_distill_config.num_steps = 1
     progressive_config.pretrain_distill_config.num_steps = 1
@@ -96,16 +99,15 @@ class DistillationTest(tf.test.TestCase):
             type='linear',
             linear=optimization.LinearWarmupConfig(warmup_learning_rate=0)))
 
-    self.exp_config = cfg.ExperimentConfig(
-        task=self.task_config,
+    exp_config = cfg.ExperimentConfig(
+        task=task_config,
         trainer=prog_trainer_lib.ProgressiveTrainerConfig(
             progressive=progressive_config,
             optimizer_config=optimization_config))
 
     # Create a teacher model checkpoint.
-    teacher_encoder = encoders.build_encoder(
-        self.task_config.teacher_model.encoder)
-    pretrainer_config = self.task_config.teacher_model
+    teacher_encoder = encoders.build_encoder(task_config.teacher_model.encoder)
+    pretrainer_config = task_config.teacher_model
     if pretrainer_config.cls_heads:
       teacher_cls_heads = [
           layers.ClassificationHead(**cfg.as_dict())
@@ -131,14 +133,20 @@ class DistillationTest(tf.test.TestCase):
         **teacher_pretrainer.checkpoint_items)
     teacher_ckpt_path = os.path.join(self.get_temp_dir(), 'teacher_model.ckpt')
     teacher_pretrainer_ckpt.save(teacher_ckpt_path)
-    self.task_config.teacher_model_init_checkpoint = self.get_temp_dir()
+    exp_config.task.teacher_model_init_checkpoint = self.get_temp_dir()
+    return exp_config
 
-  def test_task(self):
+  @parameterized.parameters((2, 2, None), (4, 2, [1, 3]))
+  def test_task(self, teacher_block_num, student_block_num,
+                transfer_teacher_layers):
+    exp_config = self.prepare_config(teacher_block_num, student_block_num,
+                                     transfer_teacher_layers)
     bert_distillation_task = distillation.BertDistillationTask(
         strategy=tf.distribute.get_strategy(),
-        progressive=self.exp_config.trainer.progressive,
-        optimizer_config=self.exp_config.trainer.optimizer_config,
-        task_config=self.task_config)
+        progressive=exp_config.trainer.progressive,
+        optimizer_config=exp_config.trainer.optimizer_config,
+        task_config=exp_config.task)
     metrics = bert_distillation_task.build_metrics()
     train_dataset = bert_distillation_task.get_train_dataset(stage_id=0)
     train_iterator = iter(train_dataset)
@@ -148,7 +156,7 @@ class DistillationTest(tf.test.TestCase):
     optimizer = tf.keras.optimizers.SGD(lr=0.1)
 
     # test train/val step for all stages, including the last pretraining stage
-    for stage in range(self.model_block_num + 1):
+    for stage in range(student_block_num + 1):
       step = stage
       bert_distillation_task.update_pt_stage(step)
       model = bert_distillation_task.get_model(stage, None)
...
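
For reference, a standalone walk-through of how each layer-wise stage picks its teacher block in the second parameterized case above (teacher_block_num=4, student_block_num=2, transfer_teacher_layers=[1, 3]), mirroring the selection logic added to get_model in the first file; the values here are illustrative only:

# Hypothetical values from the (4, 2, [1, 3]) test case above.
transfer_teacher_layers = [1, 3]  # 4-block teacher -> 2-block student
for stage_id in range(2):  # one layer-wise distillation stage per student block
  if transfer_teacher_layers:
    teacher_layer_id = transfer_teacher_layers[stage_id]
  else:
    teacher_layer_id = stage_id
  print(f'stage {stage_id}: teacher block {teacher_layer_id} '
        f'-> student block {stage_id}')
# Prints: stage 0 distills teacher block 1, stage 1 distills teacher block 3.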