import logging

import torch
import torch.nn.functional as F

_logger = logging.getLogger(__name__)


class KnowledgeDistill:
    """
    Knowledge distillation support while fine-tuning the compressed model.

    Geoffrey Hinton, Oriol Vinyals, Jeff Dean
    "Distilling the Knowledge in a Neural Network" https://arxiv.org/abs/1503.02531
    """

    def __init__(self, teacher_model, kd_T=1):
        """
        Parameters
        ----------
        teacher_model : pytorch model
            The teacher model used to teach the student model; it should be pretrained.
        kd_T : float
            Temperature parameter. When kd_T=1 we get the standard softmax function.
            As kd_T grows, the probability distribution produced by the softmax becomes softer.
        """
        self.teacher_model = teacher_model
        self.kd_T = kd_T

    def _get_kd_loss(self, data, student_out, teacher_out_preprocess=None):
        """
        Parameters
        ----------
        data : torch.Tensor
            The input training data.
        student_out : torch.Tensor
            Output of the student network.
        teacher_out_preprocess : function
            A function for pre-processing the teacher model's output,
            e.g. teacher_out_preprocess=lambda x: x[0] extracts the first element
            of a teacher output of the form (tensor1, tensor2).

        Returns
        -------
        torch.Tensor
            Distillation loss.
        """
        # Run the teacher forward pass without tracking gradients.
        with torch.no_grad():
            kd_out = self.teacher_model(data)
            if teacher_out_preprocess is not None:
                kd_out = teacher_out_preprocess(kd_out)
        assert isinstance(kd_out, torch.Tensor)
        assert isinstance(student_out, torch.Tensor)
        assert kd_out.shape == student_out.shape

        # Soften both distributions with the temperature, then compute the KL divergence
        # between the student's log-probabilities and the teacher's probabilities.
        soft_log_out = F.log_softmax(student_out / self.kd_T, dim=1)
        soft_t = F.softmax(kd_out / self.kd_T, dim=1)
        loss_kd = F.kl_div(soft_log_out, soft_t.detach(), reduction='batchmean')
        return loss_kd

    def loss(self, data, student_out):
        """
        Parameters
        ----------
        data : torch.Tensor
            Input of the student model.
        student_out : torch.Tensor
            Output of the student model.

        Returns
        -------
        torch.Tensor
            Distillation loss between the student output and the teacher output.
        """
        return self._get_kd_loss(data, student_out)
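

# A minimal usage sketch, not part of the original module: the names
# `student_model`, `teacher_model`, `train_loader`, `optimizer`, `device`,
# and the weighting factor `alpha` below are assumptions for illustration.
# KnowledgeDistill.loss() returns only the distillation term, so the
# hard-label cross-entropy loss is combined with it explicitly here.
def train_one_epoch_with_kd(student_model, teacher_model, train_loader,
                            optimizer, device, kd_T=4.0, alpha=0.5):
    """Example training loop combining hard-label loss with the KD loss."""
    kd = KnowledgeDistill(teacher_model, kd_T=kd_T)
    student_model.train()
    teacher_model.eval()  # keep the teacher frozen in eval mode
    for data, target in train_loader:
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        student_out = student_model(data)
        ce_loss = F.cross_entropy(student_out, target)  # loss on ground-truth labels
        kd_loss = kd.loss(data, student_out)            # KL divergence to the teacher
        loss = (1 - alpha) * ce_loss + alpha * kd_loss  # weighted combination
        loss.backward()
        optimizer.step()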