Unverified Commit e6604247 authored by Anthony Susevski's avatar Anthony Susevski Committed by GitHub
Browse files

fixed typos (issue 27919) (#27920)



* fixed typos (issue 27919)

* Update docs/source/en/tasks/knowledge_distillation_for_image_classification.md
Co-authored-by: default avataramyeroberts <22614925+amyeroberts@users.noreply.github.com>

---------
Co-authored-by: default avataramyeroberts <22614925+amyeroberts@users.noreply.github.com>
parent e5079b0b
...@@ -61,8 +61,8 @@ import torch.nn.functional as F ...@@ -61,8 +61,8 @@ import torch.nn.functional as F
class ImageDistilTrainer(Trainer): class ImageDistilTrainer(Trainer):
def __init__(self, *args, teacher_model=None, **kwargs): def __init__(self, teacher_model=None, student_model=None, temperature=None, lambda_param=None, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(model=student_model, *args, **kwargs)
self.teacher = teacher_model self.teacher = teacher_model
self.student = student_model self.student = student_model
self.loss_function = nn.KLDivLoss(reduction="batchmean") self.loss_function = nn.KLDivLoss(reduction="batchmean")
...@@ -164,7 +164,7 @@ trainer = ImageDistilTrainer( ...@@ -164,7 +164,7 @@ trainer = ImageDistilTrainer(
train_dataset=processed_datasets["train"], train_dataset=processed_datasets["train"],
eval_dataset=processed_datasets["validation"], eval_dataset=processed_datasets["validation"],
data_collator=data_collator, data_collator=data_collator,
tokenizer=teacher_extractor, tokenizer=teacher_processor,
compute_metrics=compute_metrics, compute_metrics=compute_metrics,
temperature=5, temperature=5,
lambda_param=0.5 lambda_param=0.5
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment