Knowledge Distillation on NNI
=============================

KnowledgeDistill
----------------

Knowledge Distillation (KD) was proposed in `Distilling the Knowledge in a Neural Network <https://arxiv.org/abs/1503.02531>`__. In KD, a compressed model is trained to mimic a pre-trained, larger model. This training setting is also referred to as "teacher-student", where the large model is the teacher and the small model is the student. KD is often used to fine-tune a pruned model.
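
Concretely, the student is trained with the sum of the usual cross-entropy loss on the hard labels and a temperature-scaled KL-divergence term between the softened outputs of the teacher and the student. This is exactly the loss computed in the code snippet below, where :math:`T` denotes the distillation temperature (``kd_T`` in the code):

.. math::

   \mathcal{L} = \mathcal{L}_{CE}\left(y_s, \text{target}\right) + T^{2} \cdot \mathrm{KL}\left(\mathrm{softmax}\left(\frac{y_t}{T}\right) \,\Big\|\, \mathrm{softmax}\left(\frac{y_s}{T}\right)\right)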


.. image:: ../../img/distill.png
   :target: ../../img/distill.png
   :alt: Knowledge distillation: the student model is trained to mimic the teacher model

Usage
^^^^^

PyTorch code for adding the KD loss to a training loop:

.. code-block:: python

      for batch_idx, (data, target) in enumerate(train_loader):
         data, target = data.to(device), target.to(device)
         optimizer.zero_grad()
         y_s = model_s(data)
         with torch.no_grad():
            # the teacher is frozen, so no gradients are needed for its forward pass
            y_t = model_t(data)

         # hard-label (cross-entropy) loss
         loss_cri = F.cross_entropy(y_s, target)

         # kd loss: KL divergence between the temperature-softened outputs
         p_s = F.log_softmax(y_s / kd_T, dim=1)
         p_t = F.softmax(y_t / kd_T, dim=1)
         loss_kd = F.kl_div(p_s, p_t, reduction='sum') * (kd_T ** 2) / y_s.shape[0]

         # total loss
         loss = loss_cri + loss_kd
         loss.backward()
         optimizer.step()
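
If you want to reuse the distillation term across scripts, the KD loss can also be wrapped in a small module. The following is a minimal sketch (the class name ``DistillKL`` and the default temperature are illustrative choices, not part of NNI's API); ``reduction='batchmean'`` divides by the batch size, which matches the manual division by ``y_s.shape[0]`` above:

.. code-block:: python

      import torch
      import torch.nn as nn
      import torch.nn.functional as F

      class DistillKL(nn.Module):
         """KL-divergence distillation loss with temperature scaling (a sketch)."""

         def __init__(self, T=4.0):
            super().__init__()
            self.T = T  # distillation temperature, i.e. kd_T in the snippet above

         def forward(self, y_s, y_t):
            # soften both distributions with the temperature T
            p_s = F.log_softmax(y_s / self.T, dim=1)
            p_t = F.softmax(y_t / self.T, dim=1)
            # scale by T**2 to keep gradient magnitudes comparable across temperatures
            return F.kl_div(p_s, p_t, reduction='batchmean') * (self.T ** 2)

With such a module, the loss in the loop above reduces to ``loss = F.cross_entropy(y_s, target) + criterion_kd(y_s, y_t)``, where ``criterion_kd = DistillKL(T=kd_T)``.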


The complete code for fine-tuning the pruned model can be found :githublink:`here <examples/model_compress/pruning/finetune_kd_torch.py>`. It can be run as follows:

.. code-block:: bash

      python finetune_kd_torch.py --model [model name] --teacher-model-dir [pretrained checkpoint path]  --student-model-dir [pruned checkpoint path] --mask-path [mask file path]

Note that before fine-tuning a pruned model, you need to run :githublink:`basic_pruners_torch.py <examples/model_compress/pruning/basic_pruners_torch.py>` first to generate the mask file, and then pass the mask path to the script via ``--mask-path``.