Unverified Commit ae1cb4ec authored by Shichao Sun's avatar Shichao Sun Committed by GitHub
Browse files

[s2s/distill] hparams.tokenizer_name = hparams.teacher (#8382)

parent aec51e56
...@@ -45,6 +45,7 @@ class BartSummarizationDistiller(SummarizationModule): ...@@ -45,6 +45,7 @@ class BartSummarizationDistiller(SummarizationModule):
) )
if hparams.length_penalty != -1: if hparams.length_penalty != -1:
student.config.length_penalty = hparams.length_penalty student.config.length_penalty = hparams.length_penalty
hparams.tokenizer_name = hparams.teacher # Use teacher's tokenizer
super().__init__(hparams, model=student, config=student.config) super().__init__(hparams, model=student, config=student.config)
model_type = student.config.model_type model_type = student.config.model_type
self.e_layer_ids, self.d_layer_ids = e_layer_ids, d_layer_ids # type: List[int], List[int] self.e_layer_ids, self.d_layer_ids = e_layer_ids, d_layer_ids # type: List[int], List[int]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment