Unverified Commit b122c63d authored by colorjam, committed by GitHub

Update finetuning with kd example (#3412)

parent 969f0d99
@@ -7,24 +7,21 @@ Run basic_pruners_torch.py first to get the masks of the pruned model. Then pass
 '''
 import argparse
 import os
 import time
-import argparse
+from copy import deepcopy
+import nni
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import torch.optim as optim
-from torch.optim.lr_scheduler import StepLR, MultiStepLR
+from nni.compression.pytorch import ModelSpeedup
+from torch.optim.lr_scheduler import MultiStepLR, StepLR
 from torchvision import datasets, transforms
-from copy import deepcopy
-from models.mnist.lenet import LeNet
-from models.cifar10.vgg import VGG
 from basic_pruners_torch import get_data
-import nni
-from nni.compression.pytorch import ModelSpeedup, get_dummy_input
+from models.cifar10.vgg import VGG
+from models.mnist.lenet import LeNet
 
 class DistillKL(nn.Module):
     """Distilling the Knowledge in a Neural Network"""
@@ -38,6 +35,13 @@ class DistillKL(nn.Module):
         loss = F.kl_div(p_s, p_t, size_average=False) * (self.T**2) / y_s.shape[0]
         return loss
 
+def get_dummy_input(args, device):
+    if args.dataset == 'mnist':
+        dummy_input = torch.randn([args.test_batch_size, 1, 28, 28]).to(device)
+    elif args.dataset in ['cifar10', 'imagenet']:
+        dummy_input = torch.randn([args.test_batch_size, 3, 32, 32]).to(device)
+    return dummy_input
+
 def get_model_optimizer_scheduler(args, device, test_loader, criterion):
     if args.model == 'LeNet':
         model = LeNet().to(device)
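The `DistillKL` module above and the new local `get_dummy_input` helper (which replaces the one previously imported from NNI) are the two pieces the finetuning loop needs. For context, here is a sketch of one distillation training step, assuming a pruned `student`, a frozen pretrained `teacher`, and an illustrative loss weight `alpha`; the weighting scheme is an assumption for demonstration, not taken from this diff:

```python
# Hypothetical KD finetuning step. criterion_ce is e.g. nn.CrossEntropyLoss(),
# criterion_kd is the DistillKL module defined in this file; alpha is an
# illustrative mixing weight, not a value from this commit.
import torch

def kd_train_step(student, teacher, criterion_ce, criterion_kd,
                  optimizer, data, target, alpha=0.9):
    student.train()
    teacher.eval()
    optimizer.zero_grad()
    out_s = student(data)
    with torch.no_grad():  # the teacher only supplies soft targets; no gradients needed
        out_t = teacher(data)
    # Mix the hard-label loss with the softened teacher/student KL loss
    loss = (1 - alpha) * criterion_ce(out_s, target) + alpha * criterion_kd(out_s, out_t)
    loss.backward()
    optimizer.step()
    return loss.item()
```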
@@ -51,7 +55,6 @@ def get_model_optimizer_scheduler(args, device, test_loader, criterion):
     # In this example, we set the architecture of teacher and student to be the same. It is feasible to set a different teacher architecture.
     if args.teacher_model_dir is None:
         raise NotImplementedError('please load pretrained teacher model first')
     else:
         model.load_state_dict(torch.load(args.teacher_model_dir))
     best_acc = test(args, model, device, criterion, test_loader)
...
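The hunk above loads pretrained weights from `--teacher_model_dir` into a model with the same architecture as the student. A hedged sketch of how such a teacher is typically materialized in this setup; `build_teacher` is a hypothetical helper, and the copy is assumed to be taken before the student is pruned:

```python
# Hypothetical helper: build a frozen teacher that shares the student's
# architecture, restoring weights from args.teacher_model_dir (an argument
# name used by this example script).
from copy import deepcopy
import torch

def build_teacher(model, args, device):
    teacher = deepcopy(model).to(device)  # assumes model is not yet pruned
    teacher.load_state_dict(torch.load(args.teacher_model_dir, map_location=device))
    teacher.eval()  # teacher weights stay frozen during KD finetuning
    for p in teacher.parameters():
        p.requires_grad = False
    return teacher
```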