Unverified commit b53bc55b, authored by Vishal Burman and committed by GitHub
Browse files

Fix for making student ProphetNet for Seq2Seq Distillation (#12130)

* make_student.py: fix to make student ProphetNet

* reformat
parent b76850a8
...@@ -118,11 +118,17 @@ def create_student_by_copying_alternating_layers( ...@@ -118,11 +118,17 @@ def create_student_by_copying_alternating_layers(
d = teacher_d d = teacher_d
init_kwargs.update({"encoder_layers": e, "decoder_layers": d}) init_kwargs.update({"encoder_layers": e, "decoder_layers": d})
except AttributeError: # T5 except AttributeError: # T5
if hasattr(teacher.config, "num_encoder_layers"):
teacher_e, teacher_d = teacher.config.num_encoder_layers, teacher.config.num_decoder_layers
else:
teacher_e, teacher_d = teacher.config.num_layers, teacher.config.num_decoder_layers teacher_e, teacher_d = teacher.config.num_layers, teacher.config.num_decoder_layers
if e is None: if e is None:
e = teacher_e e = teacher_e
if d is None: if d is None:
d = teacher_d d = teacher_d
if hasattr(teacher.config, "num_encoder_layers"):
init_kwargs.update({"num_encoder_layers": e, "num_decoder_layers": d})
else:
init_kwargs.update({"num_layers": e, "num_decoder_layers": d}) init_kwargs.update({"num_layers": e, "num_decoder_layers": d})
# Kwargs to instantiate student: teacher kwargs with updated layer numbers + **extra_config_kwargs # Kwargs to instantiate student: teacher kwargs with updated layer numbers + **extra_config_kwargs
...@@ -150,6 +156,12 @@ def create_student_by_copying_alternating_layers( ...@@ -150,6 +156,12 @@ def create_student_by_copying_alternating_layers(
d_layers_to_copy: List[int] = pick_layers_to_copy(d, teacher_d) d_layers_to_copy: List[int] = pick_layers_to_copy(d, teacher_d)
try: try:
if hasattr(
teacher, "prophetnet"
): # For ProphetNet, student.model.encoder.layers is called student.prophetnet.encoder.layers
copy_layers(teacher.prophetnet.encoder.layers, student.prophetnet.encoder.layers, e_layers_to_copy)
copy_layers(teacher.prophetnet.decoder.layers, student.prophetnet.decoder.layers, d_layers_to_copy)
else:
copy_layers(teacher.model.encoder.layers, student.model.encoder.layers, e_layers_to_copy) copy_layers(teacher.model.encoder.layers, student.model.encoder.layers, e_layers_to_copy)
copy_layers(teacher.model.decoder.layers, student.model.decoder.layers, d_layers_to_copy) copy_layers(teacher.model.decoder.layers, student.model.decoder.layers, d_layers_to_copy)
except AttributeError: # For t5, student.model.encoder.layers is called student.encoder.block except AttributeError: # For t5, student.model.encoder.layers is called student.encoder.block
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment