Unverified commit 4784cc6c authored by liuzhe-lz, committed by GitHub

Merge pull request #3302 from microsoft/v2.0-merge

Merge branch v2.0 into master (no squash)
parents 25db55ca 349ead41
@@ -119,3 +119,4 @@ def main():
 if __name__ == '__main__':
     main()
\ No newline at end of file
@@ -28,13 +28,17 @@ class fc1(nn.Module):
 def train(model, train_loader, optimizer, criterion):
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     model.train()
-    for imgs, targets in train_loader:
+    for batch_idx, (imgs, targets) in enumerate(train_loader):
         optimizer.zero_grad()
         imgs, targets = imgs.to(device), targets.to(device)
         output = model(imgs)
         train_loss = criterion(output, targets)
         train_loss.backward()
         optimizer.step()
+        if batch_idx % 100 == 0:
+            print('{:2.0f}% Loss {}'.format(
+                100 * batch_idx / len(train_loader), train_loss.item()))
     return train_loss.item()
 
 def test(model, test_loader, criterion):
......
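Note that `train` still returns `train_loss.item()` from the last batch only, which is a noisy per-epoch signal. A minimal variant (a sketch, not part of this diff) that returns the mean loss over the epoch instead:

def train(model, train_loader, optimizer, criterion):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.train()
    total_loss = 0.0
    for imgs, targets in train_loader:
        optimizer.zero_grad()
        imgs, targets = imgs.to(device), targets.to(device)
        train_loss = criterion(model(imgs), targets)
        train_loss.backward()
        optimizer.step()
        total_loss += train_loss.item()
    # Mean loss across all batches, rather than the last batch's loss.
    return total_loss / len(train_loader)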
@@ -143,3 +143,4 @@ def main():
 if __name__ == '__main__':
     main()
\ No newline at end of file
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms

from nni.algorithms.compression.pytorch.pruning import L1FilterPruner
from knowledge_distill.knowledge_distill import KnowledgeDistill

from models.cifar10.vgg import VGG


def train(model, device, train_loader, optimizer, kd=None):
    # Weights for the hard-label loss and the distillation loss.
    alpha = 1
    beta = 0.8
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        student_loss = F.cross_entropy(output, target)
        if kd is not None:
            # Combine the student's cross-entropy loss with the distillation loss.
            kd_loss = kd.loss(data=data, student_out=output)
            loss = alpha * student_loss + beta * kd_loss
        else:
            loss = student_loss
        loss.backward()
        optimizer.step()
        if batch_idx % 100 == 0:
            print('{:2.0f}% Loss {}'.format(100 * batch_idx / len(train_loader), loss.item()))


def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
    test_loss /= len(test_loader.dataset)
    acc = 100 * correct / len(test_loader.dataset)
    print('Loss: {}  Accuracy: {}%\n'.format(test_loss, acc))
    return acc


def main():
    torch.manual_seed(0)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    train_loader = torch.utils.data.DataLoader(
        datasets.CIFAR10('./data.cifar10', train=True, download=True,
                         transform=transforms.Compose([
                             transforms.Pad(4),
                             transforms.RandomCrop(32),
                             transforms.RandomHorizontalFlip(),
                             transforms.ToTensor(),
                             transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
                         ])),
        batch_size=64, shuffle=True)
    test_loader = torch.utils.data.DataLoader(
        datasets.CIFAR10('./data.cifar10', train=False, transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
        ])),
        batch_size=200, shuffle=False)

    model = VGG(depth=16)
    model.to(device)

    # Train the base VGG-16 model
    print('=' * 10 + 'Train the unpruned base model' + '=' * 10)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4)
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, 160, 0)
    for epoch in range(160):
        print('# Epoch {} #'.format(epoch))
        train(model, device, train_loader, optimizer)
        test(model, device, test_loader)
        lr_scheduler.step(epoch)
    torch.save(model.state_dict(), 'vgg16_cifar10.pth')

    # Test base model accuracy
    print('=' * 10 + 'Test on the original model' + '=' * 10)
    model.load_state_dict(torch.load('vgg16_cifar10.pth'))
    test(model, device, test_loader)
    # top1 = 93.51%

    # Pruning configuration: prune 80% of the filters in every convolution layer by L1 norm.
    configure_list = [{
        'sparsity': 0.8,
        'op_types': ['Conv2d'],
    }]

    # Prune the model and test accuracy without fine-tuning.
    print('=' * 10 + 'Test on the pruned model before fine tune' + '=' * 10)
    pruner = L1FilterPruner(model, configure_list)
    model = pruner.compress()
    test(model, device, test_loader)
    # top1 = 10.00%

    # Fine-tune the pruned model for 40 epochs (with knowledge distillation) and test accuracy.
    print('=' * 10 + 'Fine tuning' + '=' * 10)
    optimizer_finetune = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-4)
    best_top1 = 0
    # The unpruned base model serves as the teacher for knowledge distillation.
    kd_teacher_model = VGG(depth=16)
    kd_teacher_model.to(device)
    kd_teacher_model.load_state_dict(torch.load('vgg16_cifar10.pth'))
    kd = KnowledgeDistill(kd_teacher_model, kd_T=5)
    for epoch in range(40):
        pruner.update_epoch(epoch)
        print('# Epoch {} #'.format(epoch))
        train(model, device, train_loader, optimizer_finetune, kd)
        top1 = test(model, device, test_loader)
        if top1 > best_top1:
            best_top1 = top1
            # Export the best model: 'model_path' stores the state_dict of the pruned model,
            # 'mask_path' stores the mask_dict of the pruned model.
            pruner.export_model(model_path='pruned_vgg16_cifar10.pth', mask_path='mask_vgg16_cifar10.pth')

    # Test the exported model
    print('=' * 10 + 'Test on the pruned model after fine tune' + '=' * 10)
    new_model = VGG(depth=16)
    new_model.to(device)
    new_model.load_state_dict(torch.load('pruned_vgg16_cifar10.pth'))
    test(new_model, device, test_loader)
    # top1 = 85.43% with kd, top1 = 85.04% without kd


if __name__ == '__main__':
    main()
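The `KnowledgeDistill` helper imported above is not part of this diff. A minimal sketch consistent with its call sites (`KnowledgeDistill(teacher, kd_T=5)`, `kd.loss(data=..., student_out=...)`) would be Hinton-style soft-label distillation; the internals below are an assumption, not the actual implementation:

import torch
import torch.nn.functional as F

class KnowledgeDistill:
    # Assumed sketch: distill soft targets from a frozen teacher at temperature T.
    def __init__(self, teacher_model, kd_T=5):
        self.teacher = teacher_model
        self.T = kd_T

    def loss(self, data, student_out):
        with torch.no_grad():
            teacher_out = self.teacher(data)
        # KL divergence between temperature-softened distributions,
        # scaled by T^2 to keep gradient magnitudes comparable.
        return F.kl_div(
            F.log_softmax(student_out / self.T, dim=1),
            F.softmax(teacher_out / self.T, dim=1),
            reduction='batchmean') * (self.T ** 2)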
@@ -14,7 +14,7 @@ import torch.nn as nn
 from genotypes import Genotype
 from ops import PRIMITIVES
-from nni.nas.pytorch.cdarts.utils import *
+from nni.algorithms.nas.pytorch.cdarts.utils import *
 
 def get_logger(file_path):
......
@@ -7,7 +7,7 @@ from tensorflow.keras.optimizers import SGD
 import nni
 from nni.nas.tensorflow.mutables import LayerChoice, InputChoice
-from nni.nas.tensorflow.classic_nas import get_and_apply_next_architecture
+from nni.algorithms.nas.tensorflow.classic_nas import get_and_apply_next_architecture
 
 tf.get_logger().setLevel('ERROR')
......
@@ -5,7 +5,7 @@
 from tensorflow.keras.losses import Reduction, SparseCategoricalCrossentropy
 from tensorflow.keras.optimizers import SGD
-from nni.nas.tensorflow import enas
+from nni.algorithms.nas.tensorflow import enas
 
 import datasets
 from macro import GeneralNetwork
......
@@ -8,7 +8,7 @@ from tensorflow.keras.losses import Reduction, SparseCategoricalCrossentropy
 from tensorflow.keras.optimizers import SGD
 from nni.nas.tensorflow.mutables import LayerChoice, InputChoice
-from nni.nas.tensorflow.enas import EnasTrainer
+from nni.algorithms.nas.tensorflow.enas import EnasTrainer
 
 class Net(Model):
@@ -55,7 +55,7 @@ class Net(Model):
 def accuracy(truth, logits):
-    truth = tf.reshape(truth, -1)
+    truth = tf.reshape(truth, (-1, ))
     predicted = tf.cast(tf.math.argmax(logits, axis=1), truth.dtype)
     equal = tf.cast(predicted == truth, tf.int32)
     return tf.math.reduce_sum(equal).numpy() / equal.shape[0]
......
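For context on the hunk above: `tf.reshape` expects a rank-1 `shape` argument, so the bare scalar `-1` raises a rank error, while the one-element tuple flattens as intended. A quick illustration:

import tensorflow as tf

t = tf.constant([[1, 2], [3, 4]])
# tf.reshape(t, -1)          # fails: `shape` must be rank 1, not a scalar
flat = tf.reshape(t, (-1,))  # flattens to [1 2 3 4]
print(flat.numpy())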
+import json
 import logging
 import os
 import sys
@@ -102,6 +103,7 @@ if __name__ == "__main__":
                          log_frequency=10)
         trainer.fit()
         print('Final architecture:', trainer.export())
+        json.dump(trainer.export(), open('checkpoint.json', 'w'))
     elif args.train_mode == 'search_v1':
         # this is architecture search
         logger.info('Creating ProxylessNasTrainer...')
......
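The added `json.dump` call persists the searched architecture to `checkpoint.json`. On the retrain side it would typically be consumed roughly like this (a sketch; `apply_fixed_architecture` lived under `nni.nas.pytorch.fixed` around this release, but treat the exact import path as an assumption):

import json
from nni.nas.pytorch.fixed import apply_fixed_architecture

model = build_search_space_model()  # hypothetical: rebuild the same search-space model
# Fix every LayerChoice/InputChoice to the decisions recorded during search.
apply_fixed_architecture(model, 'checkpoint.json')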
@@ -85,7 +85,7 @@ def accuracy(output, target, topk=(1,)):
     res = dict()
     for k in topk:
-        correct_k = correct[:k].view(-1).float().sum(0)
+        correct_k = correct[:k].reshape(-1).float().sum(0)
         res["acc{}".format(k)] = correct_k.mul_(1.0 / batch_size).item()
     return res
......
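Background on this hunk: `Tensor.view` requires contiguous memory, and `correct[:k]` is typically non-contiguous here because `correct` comes from a transpose earlier in the accuracy helper, so newer PyTorch raises a runtime error; `reshape` silently copies when needed. A small repro:

import torch

x = torch.arange(6).reshape(2, 3).t()  # the transpose makes x non-contiguous
print(x.is_contiguous())               # False
# x.view(-1)                           # RuntimeError: view size is not compatible ...
print(x.reshape(-1))                   # reshape falls back to a copy and succeeds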
@@ -10,7 +10,7 @@ import torch.nn as nn
 import datasets
 from nni.nas.pytorch.callbacks import ArchitectureCheckpoint, LRSchedulerCallback
-from nni.nas.pytorch.darts import DartsTrainer
+from nni.algorithms.nas.pytorch.darts import DartsTrainer
 from utils import accuracy
 from nni.nas.pytorch.search_space_zoo import DartsCell
......
@@ -8,7 +8,7 @@ from torchvision import transforms
 from torchvision.datasets import CIFAR10
 from nni.nas.pytorch import mutables
-from nni.nas.pytorch import enas
+from nni.algorithms.nas.pytorch import enas
 from utils import accuracy, reward_accuracy
 from nni.nas.pytorch.callbacks import (ArchitectureCheckpoint,
                                        LRSchedulerCallback)
......
@@ -7,7 +7,7 @@ from argparse import ArgumentParser
 from torchvision import transforms
 from torchvision.datasets import CIFAR10
-from nni.nas.pytorch import enas
+from nni.algorithms.nas.pytorch import enas
 from utils import accuracy, reward_accuracy
 from nni.nas.pytorch.callbacks import (ArchitectureCheckpoint,
                                        LRSchedulerCallback)
......
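Taken together, these import hunks follow one rule from the v2.0 merge: built-in NAS algorithm implementations move under `nni.algorithms`, while core building blocks such as `mutables` and `callbacks` stay under `nni.nas`. In short:

# Before (v1.x layout):
#   from nni.nas.pytorch.darts import DartsTrainer
#   from nni.nas.tensorflow.enas import EnasTrainer
# After (v2.0 layout):
from nni.algorithms.nas.pytorch.darts import DartsTrainer
from nni.algorithms.nas.tensorflow.enas import EnasTrainer

# Unchanged core imports:
from nni.nas.pytorch import mutables
from nni.nas.pytorch.callbacks import ArchitectureCheckpoint, LRSchedulerCallback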