ModelZoo / ResNet50_tensorflow · Commits

Commit ca88e8b4
Authored Apr 26, 2021 by Chen Chen
Committed by A. Unique TensorFlower, Apr 26, 2021

Internal change

PiperOrigin-RevId: 370538762
parent 4c2ba498

Showing 2 changed files with 62 additions and 20 deletions:

  official/nlp/projects/mobilebert/distillation.py       +36  -2
  official/nlp/projects/mobilebert/distillation_test.py   +26  -18
official/nlp/projects/mobilebert/distillation.py
@@ -13,6 +13,8 @@
 # limitations under the License.
 """Progressive distillation for MobileBERT student model."""
+from typing import List, Optional
 from absl import logging
 import dataclasses
 import orbit
@@ -46,6 +48,14 @@ class LayerWiseDistillConfig(base_config.Config):
   attention_distill_factor: float = 1.0
   if_freeze_previous_layers: bool = False
+  # The ids of teacher layers that will be mapped to the student model.
+  # For example, if you want to compress a 24 layer teacher to a 6 layer
+  # student, you can set it to [3, 7, 11, 15, 19, 23] (the index starts from 0).
+  # If `None`, we assume teacher and student have the same number of layers,
+  # and each layer of teacher model will be mapped to student's corresponding
+  # layer.
+  transfer_teacher_layers: Optional[List[int]] = None
 
 
 @dataclasses.dataclass
 class PretrainDistillConfig(base_config.Config):
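The new `transfer_teacher_layers` field encodes the teacher-to-student layer mapping described in the comment above. As a rough stand-alone sketch of that mapping (plain Python with hypothetical layer counts, not the Model Garden config classes themselves):

# Hypothetical illustration of the mapping described above: compress a
# 24-layer teacher into a 6-layer student, indices starting from 0.
num_student_layers = 6
transfer_teacher_layers = [3, 7, 11, 15, 19, 23]
assert len(transfer_teacher_layers) == num_student_layers

# Student layer i distills from teacher layer transfer_teacher_layers[i];
# with transfer_teacher_layers = None the mapping is one-to-one instead.
for student_layer, teacher_layer in enumerate(transfer_teacher_layers):
  print('student layer %d <- teacher layer %d' % (student_layer, teacher_layer))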
@@ -120,6 +130,23 @@ class BertDistillationTask(policies.ProgressivePolicy, base_task.Task):
     self._the_only_train_dataset = None
     self._the_only_eval_dataset = None
+    layer_wise_config = self._progressive_config.layer_wise_distill_config
+    transfer_teacher_layers = layer_wise_config.transfer_teacher_layers
+    num_teacher_layers = (
+        self._task_config.teacher_model.encoder.mobilebert.num_blocks)
+    num_student_layers = (
+        self._task_config.student_model.encoder.mobilebert.num_blocks)
+    if transfer_teacher_layers and len(
+        transfer_teacher_layers) != num_student_layers:
+      raise ValueError('The number of `transfer_teacher_layers` %s does not '
+                       'match the number of student layers. %d' %
+                       (transfer_teacher_layers, num_student_layers))
+    if not transfer_teacher_layers and (num_teacher_layers !=
+                                        num_student_layers):
+      raise ValueError('`transfer_teacher_layers` is not specified, and the '
+                       'number of teacher layers does not match '
+                       'the number of student layers.')
+
     ratio = progressive.pretrain_distill_config.distill_ground_truth_ratio
     if ratio < 0 or ratio > 1:
       raise ValueError('distill_ground_truth_ratio has to be within [0, 1].')
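The check added to the task constructor can be read in isolation; here is a minimal, self-contained sketch of the same validation with hypothetical layer counts (it does not touch `BertDistillationTask` or its configs):

def check_layer_mapping(transfer_teacher_layers, num_teacher_layers,
                        num_student_layers):
  # Mirrors the two ValueError branches introduced in the diff above.
  if transfer_teacher_layers and len(
      transfer_teacher_layers) != num_student_layers:
    raise ValueError('The number of `transfer_teacher_layers` %s does not '
                     'match the number of student layers. %d' %
                     (transfer_teacher_layers, num_student_layers))
  if not transfer_teacher_layers and num_teacher_layers != num_student_layers:
    raise ValueError('`transfer_teacher_layers` is not specified, and the '
                     'number of teacher layers does not match '
                     'the number of student layers.')


check_layer_mapping([3, 7, 11, 15, 19, 23], 24, 6)  # valid explicit mapping
check_layer_mapping(None, 6, 6)                      # valid one-to-one mapping
# check_layer_mapping(None, 24, 6) would raise ValueError.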
@@ -169,7 +196,7 @@ class BertDistillationTask(policies.ProgressivePolicy, base_task.Task):
   # override policies.ProgressivePolicy
   def num_stages(self):
     # One stage for each layer, plus additional stage for pre-training
-    return self._task_config.teacher_model.encoder.mobilebert.num_blocks + 1
+    return self._task_config.student_model.encoder.mobilebert.num_blocks + 1
 
   # override policies.ProgressivePolicy
   def num_steps(self, stage_id) -> int:
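With this change, the number of progressive stages tracks the student rather than the teacher: one layer-wise distillation stage per student block, plus a final pre-training distillation stage. A tiny sketch with a hypothetical block count:

# Hypothetical student depth; mirrors the corrected return statement above.
student_num_blocks = 6
num_stages = student_num_blocks + 1
# Stages 0..5 distill one student layer each; stage 6 is pre-training distillation.
assert num_stages == 7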
@@ -247,9 +274,16 @@ class BertDistillationTask(policies.ProgressivePolicy, base_task.Task):
         encoder=student_encoder, target_layer_id=stage_id)
     student_output_feature, student_attention_score = student_sub_encoder(
         inputs)
+    if layer_wise_config.transfer_teacher_layers:
+      teacher_layer_id = layer_wise_config.transfer_teacher_layers[stage_id]
+    else:
+      teacher_layer_id = stage_id
     teacher_sub_encoder = build_sub_encoder(
         encoder=self._teacher_pretrainer.encoder_network,
-        target_layer_id=stage_id)
+        target_layer_id=teacher_layer_id)
     teacher_output_feature, teacher_attention_score = teacher_sub_encoder(
         inputs)
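The stage-to-teacher-layer lookup added here is small enough to read on its own; a stand-alone sketch with hypothetical values (independent of the actual task and encoders):

def teacher_layer_for_stage(stage_id, transfer_teacher_layers):
  # Mirrors the if/else introduced above: use the explicit mapping when given,
  # otherwise pair student layer i with teacher layer i.
  if transfer_teacher_layers:
    return transfer_teacher_layers[stage_id]
  return stage_id


# E.g. a 4-block teacher distilled into a 2-block student via layers 1 and 3.
assert teacher_layer_for_stage(0, [1, 3]) == 1
assert teacher_layer_for_stage(1, [1, 3]) == 3
assert teacher_layer_for_stage(1, None) == 1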
official/nlp/projects/mobilebert/distillation_test.py
@@ -16,7 +16,9 @@
 import os
 from absl import logging
+from absl.testing import parameterized
 import tensorflow as tf
 from official.core import config_definitions as cfg
 from official.modeling import optimization
 from official.modeling import tf_utils
@@ -29,18 +31,17 @@ from official.nlp.modeling import models
 from official.nlp.projects.mobilebert import distillation
 
 
-class DistillationTest(tf.test.TestCase):
+class DistillationTest(tf.test.TestCase, parameterized.TestCase):
 
-  def setUp(self):
-    super(DistillationTest, self).setUp()
+  def prepare_config(self, teacher_block_num, student_block_num,
+                     transfer_teacher_layers):
     # using small model for testing
-    self.model_block_num = 2
-    self.task_config = distillation.BertDistillationTaskConfig(
+    task_config = distillation.BertDistillationTaskConfig(
         teacher_model=bert.PretrainerConfig(
             encoder=encoders.EncoderConfig(
                 type='mobilebert',
                 mobilebert=encoders.MobileBertEncoderConfig(
-                    num_blocks=self.model_block_num)),
+                    num_blocks=teacher_block_num)),
             cls_heads=[
                 bert.ClsHeadConfig(
                     inner_dim=256,
@@ -53,7 +54,7 @@ class DistillationTest(tf.test.TestCase):
             encoder=encoders.EncoderConfig(
                 type='mobilebert',
                 mobilebert=encoders.MobileBertEncoderConfig(
-                    num_blocks=self.model_block_num)),
+                    num_blocks=student_block_num)),
             cls_heads=[
                 bert.ClsHeadConfig(
                     inner_dim=256,
@@ -75,6 +76,8 @@ class DistillationTest(tf.test.TestCase):
     # set only 1 step for each stage
     progressive_config = distillation.BertDistillationProgressiveConfig()
+    progressive_config.layer_wise_distill_config.transfer_teacher_layers = (
+        transfer_teacher_layers)
     progressive_config.layer_wise_distill_config.num_steps = 1
     progressive_config.pretrain_distill_config.num_steps = 1
@@ -96,16 +99,15 @@ class DistillationTest(tf.test.TestCase):
             type='linear',
             linear=optimization.LinearWarmupConfig(warmup_learning_rate=0)))
 
-    self.exp_config = cfg.ExperimentConfig(
-        task=self.task_config,
+    exp_config = cfg.ExperimentConfig(
+        task=task_config,
         trainer=prog_trainer_lib.ProgressiveTrainerConfig(
             progressive=progressive_config,
             optimizer_config=optimization_config))
 
     # Create a teacher model checkpoint.
     teacher_encoder = encoders.build_encoder(
-        self.task_config.teacher_model.encoder)
-    pretrainer_config = self.task_config.teacher_model
+        task_config.teacher_model.encoder)
+    pretrainer_config = task_config.teacher_model
     if pretrainer_config.cls_heads:
       teacher_cls_heads = [
           layers.ClassificationHead(**cfg.as_dict())
@@ -131,14 +133,20 @@ class DistillationTest(tf.test.TestCase):
         **teacher_pretrainer.checkpoint_items)
     teacher_ckpt_path = os.path.join(self.get_temp_dir(), 'teacher_model.ckpt')
     teacher_pretrainer_ckpt.save(teacher_ckpt_path)
-    self.task_config.teacher_model_init_checkpoint = self.get_temp_dir()
+    exp_config.task.teacher_model_init_checkpoint = self.get_temp_dir()
+    return exp_config
 
-  def test_task(self):
+  @parameterized.parameters((2, 2, None), (4, 2, [1, 3]))
+  def test_task(self, teacher_block_num, student_block_num,
+                transfer_teacher_layers):
+    exp_config = self.prepare_config(teacher_block_num, student_block_num,
+                                     transfer_teacher_layers)
     bert_distillation_task = distillation.BertDistillationTask(
         strategy=tf.distribute.get_strategy(),
-        progressive=self.exp_config.trainer.progressive,
-        optimizer_config=self.exp_config.trainer.optimizer_config,
-        task_config=self.task_config)
+        progressive=exp_config.trainer.progressive,
+        optimizer_config=exp_config.trainer.optimizer_config,
+        task_config=exp_config.task)
     metrics = bert_distillation_task.build_metrics()
     train_dataset = bert_distillation_task.get_train_dataset(stage_id=0)
     train_iterator = iter(train_dataset)
@@ -148,7 +156,7 @@ class DistillationTest(tf.test.TestCase):
     optimizer = tf.keras.optimizers.SGD(lr=0.1)
 
     # test train/val step for all stages, including the last pretraining stage
-    for stage in range(self.model_block_num + 1):
+    for stage in range(student_block_num + 1):
       step = stage
       bert_distillation_task.update_pt_stage(step)
       model = bert_distillation_task.get_model(stage, None)
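The two parameter tuples exercise both supported setups: `(2, 2, None)` keeps teacher and student at the same depth with the implicit one-to-one mapping, while `(4, 2, [1, 3])` compresses a 4-block teacher into a 2-block student using teacher blocks 1 and 3. A minimal sketch of how `absl.testing.parameterized` expands such tuples (the test body below is only a placeholder, not the real task test):

from absl.testing import parameterized
import unittest


class LayerMappingTest(parameterized.TestCase):

  @parameterized.parameters((2, 2, None), (4, 2, [1, 3]))
  def test_mapping(self, teacher_block_num, student_block_num,
                   transfer_teacher_layers):
    if transfer_teacher_layers is None:
      # No explicit mapping: teacher and student must have the same depth.
      self.assertEqual(teacher_block_num, student_block_num)
    else:
      # Explicit mapping: one teacher layer per student block.
      self.assertLen(transfer_teacher_layers, student_block_num)


if __name__ == '__main__':
  unittest.main()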