"test/ut/vscode:/vscode.git/clone" did not exist on "9444e27508ab9d6e4184436af2661cd4efd387bd"
Unverified commit 4784cc6c authored by liuzhe-lz, committed by GitHub

Merge pull request #3302 from microsoft/v2.0-merge

Merge branch v2.0 into master (no squash)
parents 25db55ca 349ead41
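
In brief, this merge carries the v2.0 branch into master: NAS algorithm implementations move from the nni.nas.* namespaces to nni.algorithms.nas.*, a knowledge-distillation fine-tuning example for L1-norm filter pruning of VGG-16 on CIFAR-10 is added, and there are assorted small fixes (trailing newlines, training-progress logging, a tensor reshape compatibility fix, and architecture checkpoint export).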
@@ -119,3 +119,4 @@ def main():
 if __name__ == '__main__':
     main()
\ No newline at end of file
@@ -28,13 +28,17 @@ class fc1(nn.Module):
 def train(model, train_loader, optimizer, criterion):
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     model.train()
-    for imgs, targets in train_loader:
+    for batch_idx, (imgs, targets) in enumerate(train_loader):
         optimizer.zero_grad()
         imgs, targets = imgs.to(device), targets.to(device)
         output = model(imgs)
         train_loss = criterion(output, targets)
         train_loss.backward()
         optimizer.step()
+        if batch_idx % 100 == 0:
+            print('{:2.0f}% Loss {}'.format(
+                100 * batch_idx / len(train_loader), train_loss.item()))
     return train_loss.item()

 def test(model, test_loader, criterion):
......
@@ -143,3 +143,4 @@ def main():
 if __name__ == '__main__':
     main()
\ No newline at end of file
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms

from nni.algorithms.compression.pytorch.pruning import L1FilterPruner
from knowledge_distill.knowledge_distill import KnowledgeDistill

from models.cifar10.vgg import VGG


def train(model, device, train_loader, optimizer, kd=None):
    alpha = 1
    beta = 0.8
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        student_loss = F.cross_entropy(output, target)
        if kd is not None:
            kd_loss = kd.loss(data=data, student_out=output)
            loss = alpha * student_loss + beta * kd_loss
        else:
            loss = student_loss
        loss.backward()
        optimizer.step()
        if batch_idx % 100 == 0:
            print('{:2.0f}% Loss {}'.format(100 * batch_idx / len(train_loader), loss.item()))


def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            # the model outputs raw logits, so accumulate cross entropy (not nll_loss) here
            test_loss += F.cross_entropy(output, target, reduction='sum').item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
    test_loss /= len(test_loader.dataset)
    acc = 100 * correct / len(test_loader.dataset)
    print('Loss: {}  Accuracy: {}%\n'.format(test_loss, acc))
    return acc


def main():
    torch.manual_seed(0)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    train_loader = torch.utils.data.DataLoader(
        datasets.CIFAR10('./data.cifar10', train=True, download=True,
                         transform=transforms.Compose([
                             transforms.Pad(4),
                             transforms.RandomCrop(32),
                             transforms.RandomHorizontalFlip(),
                             transforms.ToTensor(),
                             transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
                         ])),
        batch_size=64, shuffle=True)
    test_loader = torch.utils.data.DataLoader(
        datasets.CIFAR10('./data.cifar10', train=False, transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
        ])),
        batch_size=200, shuffle=False)

    model = VGG(depth=16)
    model.to(device)

    # Train the base VGG-16 model
    print('=' * 10 + 'Train the unpruned base model' + '=' * 10)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4)
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, 160, 0)
    for epoch in range(160):
        print('# Epoch {} #'.format(epoch))
        train(model, device, train_loader, optimizer)
        test(model, device, test_loader)
        lr_scheduler.step(epoch)
    torch.save(model.state_dict(), 'vgg16_cifar10.pth')

    # Test base model accuracy
    print('=' * 10 + 'Test on the original model' + '=' * 10)
    model.load_state_dict(torch.load('vgg16_cifar10.pth'))
    test(model, device, test_loader)
    # top1 = 93.51%

    # Pruning configuration: prune 80% of the filters in every convolution layer, ranked by L1 norm
    configure_list = [{
        'sparsity': 0.8,
        'op_types': ['Conv2d'],
    }]

    # Prune the model and test accuracy without fine-tuning
    print('=' * 10 + 'Test on the pruned model before fine tune' + '=' * 10)
    pruner = L1FilterPruner(model, configure_list)
    model = pruner.compress()
    test(model, device, test_loader)
    # top1 = 10.00%

    # Fine-tune the pruned model for 40 epochs and test accuracy
    print('=' * 10 + 'Fine tuning' + '=' * 10)
    optimizer_finetune = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-4)
    best_top1 = 0
    kd_teacher_model = VGG(depth=16)
    kd_teacher_model.to(device)
    kd_teacher_model.load_state_dict(torch.load('vgg16_cifar10.pth'))
    kd = KnowledgeDistill(kd_teacher_model, kd_T=5)
    for epoch in range(40):
        pruner.update_epoch(epoch)
        print('# Epoch {} #'.format(epoch))
        train(model, device, train_loader, optimizer_finetune, kd)
        top1 = test(model, device, test_loader)
        if top1 > best_top1:
            best_top1 = top1
            # Export the best model: 'model_path' stores the state_dict of the pruned model,
            # 'mask_path' stores the mask_dict of the pruned model
            pruner.export_model(model_path='pruned_vgg16_cifar10.pth', mask_path='mask_vgg16_cifar10.pth')

    # Test the exported model
    print('=' * 10 + 'Test on the pruned model after fine tune' + '=' * 10)
    new_model = VGG(depth=16)
    new_model.to(device)
    new_model.load_state_dict(torch.load('pruned_vgg16_cifar10.pth'))
    test(new_model, device, test_loader)
    # top1 = 85.43% with kd, top1 = 85.04% without kd


if __name__ == '__main__':
    main()
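
The KnowledgeDistill helper used above is imported from the example's local knowledge_distill module, which this diff does not show. For orientation only, a temperature-based distillation loss of the kind the example blends with the student's cross entropy (loss = alpha * student_loss + beta * kd_loss) is typically computed as in the sketch below; kd_T matches the constructor argument above, while the function itself is illustrative and may differ from the real module:

import torch.nn.functional as F

def distillation_loss(student_out, teacher_out, kd_T=5):
    # Illustrative sketch only -- not the actual knowledge_distill module.
    # Soften both distributions with temperature kd_T, then take the KL
    # divergence; the kd_T ** 2 factor keeps the gradient magnitude
    # comparable to the hard-label cross entropy it is blended with.
    log_p_student = F.log_softmax(student_out / kd_T, dim=1)
    p_teacher = F.softmax(teacher_out / kd_T, dim=1)
    return F.kl_div(log_p_student, p_teacher, reduction='batchmean') * (kd_T ** 2)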
@@ -14,7 +14,7 @@ import torch.nn as nn
 from genotypes import Genotype
 from ops import PRIMITIVES
-from nni.nas.pytorch.cdarts.utils import *
+from nni.algorithms.nas.pytorch.cdarts.utils import *

 def get_logger(file_path):
......
@@ -7,7 +7,7 @@ from tensorflow.keras.optimizers import SGD
 import nni
 from nni.nas.tensorflow.mutables import LayerChoice, InputChoice
-from nni.nas.tensorflow.classic_nas import get_and_apply_next_architecture
+from nni.algorithms.nas.tensorflow.classic_nas import get_and_apply_next_architecture

 tf.get_logger().setLevel('ERROR')
......
@@ -5,7 +5,7 @@
 from tensorflow.keras.losses import Reduction, SparseCategoricalCrossentropy
 from tensorflow.keras.optimizers import SGD
-from nni.nas.tensorflow import enas
+from nni.algorithms.nas.tensorflow import enas

 import datasets
 from macro import GeneralNetwork
......
@@ -8,7 +8,7 @@ from tensorflow.keras.losses import Reduction, SparseCategoricalCrossentropy
 from tensorflow.keras.optimizers import SGD
 from nni.nas.tensorflow.mutables import LayerChoice, InputChoice
-from nni.nas.tensorflow.enas import EnasTrainer
+from nni.algorithms.nas.tensorflow.enas import EnasTrainer

 class Net(Model):
@@ -55,7 +55,7 @@ class Net(Model):
 def accuracy(truth, logits):
-    truth = tf.reshape(truth, -1)
+    truth = tf.reshape(truth, (-1, ))
     predicted = tf.cast(tf.math.argmax(logits, axis=1), truth.dtype)
     equal = tf.cast(predicted == truth, tf.int32)
     return tf.math.reduce_sum(equal).numpy() / equal.shape[0]
......
+import json
 import logging
 import os
 import sys
@@ -102,6 +103,7 @@ if __name__ == "__main__":
                           log_frequency=10)
         trainer.fit()
         print('Final architecture:', trainer.export())
+        json.dump(trainer.export(), open('checkpoint.json', 'w'))
     elif args.train_mode == 'search_v1':
         # this is architecture search
         logger.info('Creating ProxylessNasTrainer...')
......
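
The added json.dump line persists the searched architecture to checkpoint.json. Reading it back is symmetric; a minimal sketch (using a with-block so the file handle is closed promptly, unlike the one-liner above):

import json

with open('checkpoint.json') as f:
    arch = json.load(f)  # the dict returned by trainer.export() above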
@@ -85,7 +85,7 @@ def accuracy(output, target, topk=(1,)):
     res = dict()
     for k in topk:
-        correct_k = correct[:k].view(-1).float().sum(0)
+        correct_k = correct[:k].reshape(-1).float().sum(0)
         res["acc{}".format(k)] = correct_k.mul_(1.0 / batch_size).item()
     return res
......
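
Why .view(-1) became .reshape(-1): Tensor.view requires contiguous memory, and the comparison matrix sliced in accuracy() is typically non-contiguous (the elided part of the function builds it via topk and a transpose), so view can raise a RuntimeError; reshape behaves like view when possible and silently copies otherwise. A minimal reproduction of the difference:

import torch

t = torch.arange(6).reshape(2, 3).t()  # transposing makes t non-contiguous
# t.view(-1) would raise a RuntimeError here, because view needs contiguity
flat = t.reshape(-1)                   # fine: reshape falls back to a copy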
@@ -10,7 +10,7 @@ import torch.nn as nn
 import datasets
 from nni.nas.pytorch.callbacks import ArchitectureCheckpoint, LRSchedulerCallback
-from nni.nas.pytorch.darts import DartsTrainer
+from nni.algorithms.nas.pytorch.darts import DartsTrainer
 from utils import accuracy
 from nni.nas.pytorch.search_space_zoo import DartsCell
......
@@ -8,7 +8,7 @@ from torchvision import transforms
 from torchvision.datasets import CIFAR10
 from nni.nas.pytorch import mutables
-from nni.nas.pytorch import enas
+from nni.algorithms.nas.pytorch import enas
 from utils import accuracy, reward_accuracy
 from nni.nas.pytorch.callbacks import (ArchitectureCheckpoint,
                                        LRSchedulerCallback)
......
@@ -7,7 +7,7 @@ from argparse import ArgumentParser
 from torchvision import transforms
 from torchvision.datasets import CIFAR10
-from nni.nas.pytorch import enas
+from nni.algorithms.nas.pytorch import enas
 from utils import accuracy, reward_accuracy
 from nni.nas.pytorch.callbacks import (ArchitectureCheckpoint,
                                        LRSchedulerCallback)
......
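
The common thread in the NAS hunks above is one mechanical migration: in NNI 2.0, algorithm implementations (trainers such as DartsTrainer and EnasTrainer, classic_nas, the cdarts utils) moved under the nni.algorithms package, while building blocks like mutables, callbacks and search_space_zoo kept their old paths. Updating user code is a one-line change per import, e.g.:

# pre-2.0 path, removed by this merge:
#   from nni.nas.pytorch.darts import DartsTrainer
# NNI 2.0 path, as shown in the hunks above:
from nni.algorithms.nas.pytorch.darts import DartsTrainer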