"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "6112b1c6442aaf7affd2b0676a1cd4eee30c45cf"
Unverified Commit b53bc55b authored by Vishal Burman's avatar Vishal Burman Committed by GitHub
Browse files

Fix for making student ProphetNet for Seq2Seq Distillation (#12130)

* make_student.py: fix to make student ProphetNet

* reformat
parent b76850a8
......@@ -118,12 +118,18 @@ def create_student_by_copying_alternating_layers(
d = teacher_d
init_kwargs.update({"encoder_layers": e, "decoder_layers": d})
except AttributeError: # T5
teacher_e, teacher_d = teacher.config.num_layers, teacher.config.num_decoder_layers
if hasattr(teacher.config, "num_encoder_layers"):
teacher_e, teacher_d = teacher.config.num_encoder_layers, teacher.config.num_decoder_layers
else:
teacher_e, teacher_d = teacher.config.num_layers, teacher.config.num_decoder_layers
if e is None:
e = teacher_e
if d is None:
d = teacher_d
init_kwargs.update({"num_layers": e, "num_decoder_layers": d})
if hasattr(teacher.config, "num_encoder_layers"):
init_kwargs.update({"num_encoder_layers": e, "num_decoder_layers": d})
else:
init_kwargs.update({"num_layers": e, "num_decoder_layers": d})
# Kwargs to instantiate student: teacher kwargs with updated layer numbers + **extra_config_kwargs
init_kwargs.update(extra_config_kwargs)
......@@ -150,8 +156,14 @@ def create_student_by_copying_alternating_layers(
d_layers_to_copy: List[int] = pick_layers_to_copy(d, teacher_d)
try:
copy_layers(teacher.model.encoder.layers, student.model.encoder.layers, e_layers_to_copy)
copy_layers(teacher.model.decoder.layers, student.model.decoder.layers, d_layers_to_copy)
if hasattr(
teacher, "prophetnet"
): # For ProphetNet, student.model.encoder.layers is called student.prophetnet.encoder.layers
copy_layers(teacher.prophetnet.encoder.layers, student.prophetnet.encoder.layers, e_layers_to_copy)
copy_layers(teacher.prophetnet.decoder.layers, student.prophetnet.decoder.layers, d_layers_to_copy)
else:
copy_layers(teacher.model.encoder.layers, student.model.encoder.layers, e_layers_to_copy)
copy_layers(teacher.model.decoder.layers, student.model.decoder.layers, d_layers_to_copy)
except AttributeError: # For t5, student.model.encoder.layers is called student.encoder.block
copy_layers(teacher.encoder.block, student.encoder.block, e_layers_to_copy)
copy_layers(teacher.decoder.block, student.decoder.block, d_layers_to_copy)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment