Unverified Commit b122c63d authored by colorjam's avatar colorjam Committed by GitHub
Browse files

Update finetuning with kd example (#3412)

parent 969f0d99
......@@ -7,24 +7,21 @@ Run basic_pruners_torch.py first to get the masks of the pruned model. Then pass
'''
import argparse
import os
import time
import argparse
from copy import deepcopy
import nni
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR, MultiStepLR
from nni.compression.pytorch import ModelSpeedup
from torch.optim.lr_scheduler import MultiStepLR, StepLR
from torchvision import datasets, transforms
from copy import deepcopy
from models.mnist.lenet import LeNet
from models.cifar10.vgg import VGG
from basic_pruners_torch import get_data
import nni
from nni.compression.pytorch import ModelSpeedup, get_dummy_input
from models.cifar10.vgg import VGG
from models.mnist.lenet import LeNet
class DistillKL(nn.Module):
"""Distilling the Knowledge in a Neural Network"""
......@@ -38,6 +35,13 @@ class DistillKL(nn.Module):
loss = F.kl_div(p_s, p_t, size_average=False) * (self.T**2) / y_s.shape[0]
return loss
def get_dummy_input(args, device):
    """Create a random input batch shaped for the selected dataset.

    Args:
        args: Parsed options; ``args.dataset`` selects the image shape and
            ``args.test_batch_size`` gives the batch dimension.
        device: Target device the tensor is moved to via ``Tensor.to``.

    Returns:
        A ``torch.randn`` tensor of shape ``[test_batch_size, 1, 28, 28]``
        for 'mnist', or ``[test_batch_size, 3, 32, 32]`` for 'cifar10' and
        'imagenet'.

    Raises:
        ValueError: If ``args.dataset`` is not one of the supported names.
    """
    if args.dataset == 'mnist':
        shape = [args.test_batch_size, 1, 28, 28]
    elif args.dataset in ['cifar10', 'imagenet']:
        # NOTE(review): 'imagenet' shares the 32x32 shape here, as in the
        # original code — confirm this matches the model's expected input.
        shape = [args.test_batch_size, 3, 32, 32]
    else:
        # Previously fell through and failed with UnboundLocalError on return.
        raise ValueError('unsupported dataset: {!r}'.format(args.dataset))
    return torch.randn(shape).to(device)
def get_model_optimizer_scheduler(args, device, test_loader, criterion):
if args.model == 'LeNet':
model = LeNet().to(device)
......@@ -51,7 +55,6 @@ def get_model_optimizer_scheduler(args, device, test_loader, criterion):
# In this example, we set the architecture of teacher and student to be the same. It is feasible to set a different teacher architecture.
if args.teacher_model_dir is None:
raise NotImplementedError('please load pretrained teacher model first')
else:
model.load_state_dict(torch.load(args.teacher_model_dir))
best_acc = test(args, model, device, criterion, test_loader)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment