Unverified Commit b53bc55b authored by Vishal Burman's avatar Vishal Burman Committed by GitHub
Browse files

Fix for making student ProphetNet for Seq2Seq Distillation (#12130)

* make_student.py: fix to make student ProphetNet

* reformat
parent b76850a8
...@@ -118,12 +118,18 @@ def create_student_by_copying_alternating_layers( ...@@ -118,12 +118,18 @@ def create_student_by_copying_alternating_layers(
d = teacher_d d = teacher_d
init_kwargs.update({"encoder_layers": e, "decoder_layers": d}) init_kwargs.update({"encoder_layers": e, "decoder_layers": d})
except AttributeError: # T5 except AttributeError: # T5
teacher_e, teacher_d = teacher.config.num_layers, teacher.config.num_decoder_layers if hasattr(teacher.config, "num_encoder_layers"):
teacher_e, teacher_d = teacher.config.num_encoder_layers, teacher.config.num_decoder_layers
else:
teacher_e, teacher_d = teacher.config.num_layers, teacher.config.num_decoder_layers
if e is None: if e is None:
e = teacher_e e = teacher_e
if d is None: if d is None:
d = teacher_d d = teacher_d
init_kwargs.update({"num_layers": e, "num_decoder_layers": d}) if hasattr(teacher.config, "num_encoder_layers"):
init_kwargs.update({"num_encoder_layers": e, "num_decoder_layers": d})
else:
init_kwargs.update({"num_layers": e, "num_decoder_layers": d})
# Kwargs to instantiate student: teacher kwargs with updated layer numbers + **extra_config_kwargs # Kwargs to instantiate student: teacher kwargs with updated layer numbers + **extra_config_kwargs
init_kwargs.update(extra_config_kwargs) init_kwargs.update(extra_config_kwargs)
...@@ -150,8 +156,14 @@ def create_student_by_copying_alternating_layers( ...@@ -150,8 +156,14 @@ def create_student_by_copying_alternating_layers(
d_layers_to_copy: List[int] = pick_layers_to_copy(d, teacher_d) d_layers_to_copy: List[int] = pick_layers_to_copy(d, teacher_d)
try: try:
copy_layers(teacher.model.encoder.layers, student.model.encoder.layers, e_layers_to_copy) if hasattr(
copy_layers(teacher.model.decoder.layers, student.model.decoder.layers, d_layers_to_copy) teacher, "prophetnet"
): # For ProphetNet, student.model.encoder.layers is called student.prophetnet.encoder.layers
copy_layers(teacher.prophetnet.encoder.layers, student.prophetnet.encoder.layers, e_layers_to_copy)
copy_layers(teacher.prophetnet.decoder.layers, student.prophetnet.decoder.layers, d_layers_to_copy)
else:
copy_layers(teacher.model.encoder.layers, student.model.encoder.layers, e_layers_to_copy)
copy_layers(teacher.model.decoder.layers, student.model.decoder.layers, d_layers_to_copy)
except AttributeError: # For t5, student.model.encoder.layers is called student.encoder.block except AttributeError: # For t5, student.model.encoder.layers is called student.encoder.block
copy_layers(teacher.encoder.block, student.encoder.block, e_layers_to_copy) copy_layers(teacher.encoder.block, student.encoder.block, e_layers_to_copy)
copy_layers(teacher.decoder.block, student.decoder.block, d_layers_to_copy) copy_layers(teacher.decoder.block, student.decoder.block, d_layers_to_copy)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment