Commit 0c803498 authored by Chen Chen, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 370538762
parent 7e88ce3e
@@ -13,6 +13,8 @@
# limitations under the License.
"""Progressive distillation for MobileBERT student model."""
from typing import List, Optional
from absl import logging
import dataclasses
import orbit
@@ -46,6 +48,14 @@ class LayerWiseDistillConfig(base_config.Config):
attention_distill_factor: float = 1.0
if_freeze_previous_layers: bool = False
# The ids of teacher layers that will be mapped to the student model.
# For example, to compress a 24-layer teacher into a 6-layer student, set it
# to [3, 7, 11, 15, 19, 23] (indices start from 0).
# If `None`, the teacher and student are assumed to have the same number of
# layers, and each teacher layer is mapped to the corresponding student
# layer.
transfer_teacher_layers: Optional[List[int]] = None
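# Illustrative sketch (not part of this change): one way such a mapping can
# be derived for uniform compression. The helper name
# `uniform_transfer_layers` is hypothetical.
def uniform_transfer_layers(num_teacher_layers: int,
                            num_student_layers: int) -> List[int]:
  """Picks evenly spaced teacher layer ids (0-indexed) for the student."""
  stride = num_teacher_layers // num_student_layers
  # A 24-layer teacher and a 6-layer student yield [3, 7, 11, 15, 19, 23].
  return [stride * (i + 1) - 1 for i in range(num_student_layers)]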
@dataclasses.dataclass
class PretrainDistillConfig(base_config.Config):
@@ -120,6 +130,23 @@ class BertDistillationTask(policies.ProgressivePolicy, base_task.Task):
self._the_only_train_dataset = None
self._the_only_eval_dataset = None
layer_wise_config = self._progressive_config.layer_wise_distill_config
transfer_teacher_layers = layer_wise_config.transfer_teacher_layers
num_teacher_layers = (
self._task_config.teacher_model.encoder.mobilebert.num_blocks)
num_student_layers = (
self._task_config.student_model.encoder.mobilebert.num_blocks)
if transfer_teacher_layers and len(
transfer_teacher_layers) != num_student_layers:
raise ValueError('The number of `transfer_teacher_layers` %s does not '
'match the number of student layers (%d).' %
(transfer_teacher_layers, num_student_layers))
if not transfer_teacher_layers and (num_teacher_layers !=
num_student_layers):
raise ValueError('`transfer_teacher_layers` is not specified, and the '
'number of teacher layers does not match '
'the number of student layers.')
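# For instance, pairing a 24-layer teacher with a 6-layer student requires an
# explicit `transfer_teacher_layers`, e.g. [3, 7, 11, 15, 19, 23].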
ratio = progressive.pretrain_distill_config.distill_ground_truth_ratio
if ratio < 0 or ratio > 1:
raise ValueError('distill_ground_truth_ratio has to be within [0, 1].')
@@ -169,7 +196,7 @@ class BertDistillationTask(policies.ProgressivePolicy, base_task.Task):
# override policies.ProgressivePolicy
def num_stages(self):
# One stage for each layer, plus an additional stage for pre-training
return self._task_config.teacher_model.encoder.mobilebert.num_blocks + 1
return self._task_config.student_model.encoder.mobilebert.num_blocks + 1
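# For example, with a 4-block teacher and a 2-block student (the new test
# case below), this yields 3 stages: one layer-wise distillation stage per
# student layer, plus the final pre-training stage.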
# override policies.ProgressivePolicy
def num_steps(self, stage_id) -> int:
@@ -247,9 +274,16 @@ class BertDistillationTask(policies.ProgressivePolicy, base_task.Task):
encoder=student_encoder, target_layer_id=stage_id)
student_output_feature, student_attention_score = student_sub_encoder(
inputs)
if layer_wise_config.transfer_teacher_layers:
teacher_layer_id = layer_wise_config.transfer_teacher_layers[stage_id]
else:
teacher_layer_id = stage_id
teacher_sub_encoder = build_sub_encoder(
encoder=self._teacher_pretrainer.encoder_network,
target_layer_id=stage_id)
target_layer_id=teacher_layer_id)
teacher_output_feature, teacher_attention_score = teacher_sub_encoder(
inputs)
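# For example, with transfer_teacher_layers=[1, 3] (the new 4-block teacher /
# 2-block student test case), stage 0 distills teacher layer 1 into student
# layer 0 and stage 1 distills teacher layer 3 into student layer 1; with
# transfer_teacher_layers=None, stage i uses teacher layer i.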
@@ -16,7 +16,9 @@
import os
from absl import logging
from absl.testing import parameterized
import tensorflow as tf
from official.core import config_definitions as cfg
from official.modeling import optimization
from official.modeling import tf_utils
@@ -29,18 +31,17 @@ from official.nlp.modeling import models
from official.nlp.projects.mobilebert import distillation
class DistillationTest(tf.test.TestCase):
class DistillationTest(tf.test.TestCase, parameterized.TestCase):
def setUp(self):
super(DistillationTest, self).setUp()
def prepare_config(self, teacher_block_num, student_block_num,
transfer_teacher_layers):
# using small model for testing
self.model_block_num = 2
self.task_config = distillation.BertDistillationTaskConfig(
task_config = distillation.BertDistillationTaskConfig(
teacher_model=bert.PretrainerConfig(
encoder=encoders.EncoderConfig(
type='mobilebert',
mobilebert=encoders.MobileBertEncoderConfig(
num_blocks=self.model_block_num)),
num_blocks=teacher_block_num)),
cls_heads=[
bert.ClsHeadConfig(
inner_dim=256,
@@ -53,7 +54,7 @@ class DistillationTest(tf.test.TestCase):
encoder=encoders.EncoderConfig(
type='mobilebert',
mobilebert=encoders.MobileBertEncoderConfig(
num_blocks=self.model_block_num)),
num_blocks=student_block_num)),
cls_heads=[
bert.ClsHeadConfig(
inner_dim=256,
@@ -75,6 +76,8 @@ class DistillationTest(tf.test.TestCase):
# set only 1 step for each stage
progressive_config = distillation.BertDistillationProgressiveConfig()
progressive_config.layer_wise_distill_config.transfer_teacher_layers = (
transfer_teacher_layers)
progressive_config.layer_wise_distill_config.num_steps = 1
progressive_config.pretrain_distill_config.num_steps = 1
@@ -96,16 +99,15 @@ class DistillationTest(tf.test.TestCase):
type='linear',
linear=optimization.LinearWarmupConfig(warmup_learning_rate=0)))
self.exp_config = cfg.ExperimentConfig(
task=self.task_config,
exp_config = cfg.ExperimentConfig(
task=task_config,
trainer=prog_trainer_lib.ProgressiveTrainerConfig(
progressive=progressive_config,
optimizer_config=optimization_config))
# Create a teacher model checkpoint.
teacher_encoder = encoders.build_encoder(
self.task_config.teacher_model.encoder)
pretrainer_config = self.task_config.teacher_model
teacher_encoder = encoders.build_encoder(task_config.teacher_model.encoder)
pretrainer_config = task_config.teacher_model
if pretrainer_config.cls_heads:
teacher_cls_heads = [
layers.ClassificationHead(**cfg.as_dict())
@@ -131,14 +133,20 @@ class DistillationTest(tf.test.TestCase):
**teacher_pretrainer.checkpoint_items)
teacher_ckpt_path = os.path.join(self.get_temp_dir(), 'teacher_model.ckpt')
teacher_pretrainer_ckpt.save(teacher_ckpt_path)
self.task_config.teacher_model_init_checkpoint = self.get_temp_dir()
exp_config.task.teacher_model_init_checkpoint = self.get_temp_dir()
return exp_config
def test_task(self):
@parameterized.parameters((2, 2, None), (4, 2, [1, 3]))
def test_task(self, teacher_block_num, student_block_num,
transfer_teacher_layers):
exp_config = self.prepare_config(teacher_block_num, student_block_num,
transfer_teacher_layers)
bert_distillation_task = distillation.BertDistillationTask(
strategy=tf.distribute.get_strategy(),
progressive=self.exp_config.trainer.progressive,
optimizer_config=self.exp_config.trainer.optimizer_config,
task_config=self.task_config)
progressive=exp_config.trainer.progressive,
optimizer_config=exp_config.trainer.optimizer_config,
task_config=exp_config.task)
metrics = bert_distillation_task.build_metrics()
train_dataset = bert_distillation_task.get_train_dataset(stage_id=0)
train_iterator = iter(train_dataset)
@@ -148,7 +156,7 @@ class DistillationTest(tf.test.TestCase):
optimizer = tf.keras.optimizers.SGD(lr=0.1)
# test train/val step for all stages, including the last pretraining stage
for stage in range(self.model_block_num + 1):
for stage in range(student_block_num + 1):
step = stage
bert_distillation_task.update_pt_stage(step)
model = bert_distillation_task.get_model(stage, None)