"src/git@developer.sourcefind.cn:OpenDAS/nni.git" did not exist on "2b300395b30fcabfd092c2feeddf9d6bc86f38f8"
Unverified Commit 7fcd5172 authored by Cjkkkk's avatar Cjkkkk Committed by GitHub
Browse files

update examples (#2082)

parent 585b5c53
...@@ -41,7 +41,7 @@ def test(model, device, test_loader): ...@@ -41,7 +41,7 @@ def test(model, device, test_loader):
def main(): def main():
torch.manual_seed(0) torch.manual_seed(0)
device = torch.device('cuda') device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
train_loader = torch.utils.data.DataLoader( train_loader = torch.utils.data.DataLoader(
datasets.CIFAR10('./data.cifar10', train=True, download=True, datasets.CIFAR10('./data.cifar10', train=True, download=True,
transform=transforms.Compose([ transform=transforms.Compose([
......
...@@ -105,7 +105,7 @@ def adjust_learning_rate(optimizer, epoch): ...@@ -105,7 +105,7 @@ def adjust_learning_rate(optimizer, epoch):
def main(): def main():
torch.manual_seed(0) torch.manual_seed(0)
device = torch.device('cuda') device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
train_loader = torch.utils.data.DataLoader( train_loader = torch.utils.data.DataLoader(
datasets.CIFAR10('./data.cifar10', train=True, download=True, datasets.CIFAR10('./data.cifar10', train=True, download=True,
transform=transforms.Compose([ transform=transforms.Compose([
......
import torch
import torch.nn.functional as F
from torchvision import datasets, transforms
from nni.compression.torch import DoReFaQuantizer
class Mnist(torch.nn.Module):
    """Small LeNet-style CNN for MNIST.

    Uses ReLU6 activations held as separate modules so a compression /
    quantization pass can target each activation individually.
    """

    def __init__(self):
        super().__init__()
        # Two conv stages followed by a two-layer classifier head.
        self.conv1 = torch.nn.Conv2d(1, 20, 5, 1)
        self.conv2 = torch.nn.Conv2d(20, 50, 5, 1)
        self.fc1 = torch.nn.Linear(4 * 4 * 50, 500)
        self.fc2 = torch.nn.Linear(500, 10)
        # One ReLU6 instance per activation site (not shared).
        self.relu1 = torch.nn.ReLU6()
        self.relu2 = torch.nn.ReLU6()
        self.relu3 = torch.nn.ReLU6()

    def forward(self, x):
        """Return per-class log-probabilities for a batch of 1x28x28 images."""
        out = F.max_pool2d(self.relu1(self.conv1(x)), 2, 2)
        out = F.max_pool2d(self.relu2(self.conv2(out)), 2, 2)
        out = out.view(-1, 4 * 4 * 50)
        out = self.relu3(self.fc1(out))
        return F.log_softmax(self.fc2(out), dim=1)
def train(model, quantizer, device, train_loader, optimizer):
    """Run one epoch of negative-log-likelihood training.

    Prints the loss every 100 batches. ``quantizer`` is not referenced in
    this function; it is part of the signature so callers can pass it
    uniformly alongside the other training objects.
    """
    model.train()
    n_batches = len(train_loader)
    for step, (inputs, labels) in enumerate(train_loader):
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        loss = F.nll_loss(model(inputs), labels)
        loss.backward()
        optimizer.step()
        if step % 100 == 0:
            print('{:2.0f}% Loss {}'.format(100 * step / n_batches, loss.item()))
def test(model, device, test_loader):
model.eval()
test_loss = 0
correct = 0
with torch.no_grad():
for data, target in test_loader:
data, target = data.to(device), target.to(device)
output = model(data)
test_loss += F.nll_loss(output, target, reduction='sum').item()
pred = output.argmax(dim=1, keepdim=True)
correct += pred.eq(target.view_as(pred)).sum().item()
test_loss /= len(test_loader.dataset)
print('Loss: {} Accuracy: {}%)\n'.format(
test_loss, 100 * correct / len(test_loader.dataset)))
def main():
    """End-to-end MNIST example: quantize with DoReFa, then train and evaluate."""
    torch.manual_seed(0)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    trans = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST('data', train=True, download=True, transform=trans),
        batch_size=64, shuffle=True)
    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST('data', train=False, transform=trans),
        batch_size=1000, shuffle=True)

    model = Mnist().to(device)

    # Quantize the weights of every Conv2d/Linear layer to 8 bits.
    # ('quant_bits' could also be a plain int, since every entry in
    # 'quant_types' here shares the same bit width.)
    configure_list = [{
        'quant_types': ['weight'],
        'quant_bits': {
            'weight': 8,
        },
        'op_types': ['Conv2d', 'Linear']
    }]
    quantizer = DoReFaQuantizer(model, configure_list)
    quantizer.compress()

    optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.5)
    for epoch in range(10):
        print('# Epoch {} #'.format(epoch))
        train(model, quantizer, device, train_loader, optimizer)
        test(model, device, test_loader)


if __name__ == '__main__':
    main()
\ No newline at end of file
...@@ -41,7 +41,7 @@ def test(model, device, test_loader): ...@@ -41,7 +41,7 @@ def test(model, device, test_loader):
def main(): def main():
torch.manual_seed(0) torch.manual_seed(0)
device = torch.device('cuda') device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
train_loader = torch.utils.data.DataLoader( train_loader = torch.utils.data.DataLoader(
datasets.CIFAR10('./data.cifar10', train=True, download=True, datasets.CIFAR10('./data.cifar10', train=True, download=True,
transform=transforms.Compose([ transform=transforms.Compose([
......
import math import math
import os
import argparse import argparse
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -48,7 +49,7 @@ def main(): ...@@ -48,7 +49,7 @@ def main():
args = parser.parse_args() args = parser.parse_args()
torch.manual_seed(0) torch.manual_seed(0)
device = torch.device('cuda') device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
train_loader = torch.utils.data.DataLoader( train_loader = torch.utils.data.DataLoader(
datasets.CIFAR10('./data.cifar10', train=True, download=True, datasets.CIFAR10('./data.cifar10', train=True, download=True,
transform=transforms.Compose([ transform=transforms.Compose([
...@@ -79,10 +80,11 @@ def main(): ...@@ -79,10 +80,11 @@ def main():
test(model, device, test_loader) test(model, device, test_loader)
lr_scheduler.step(epoch) lr_scheduler.step(epoch)
torch.save(model.state_dict(), 'vgg16_cifar10.pth') torch.save(model.state_dict(), 'vgg16_cifar10.pth')
else:
assert os.path.isfile('vgg16_cifar10.pth'), "can not find checkpoint 'vgg16_cifar10.pth'"
model.load_state_dict(torch.load('vgg16_cifar10.pth'))
# Test base model accuracy # Test base model accuracy
print('=' * 10 + 'Test on the original model' + '=' * 10) print('=' * 10 + 'Test on the original model' + '=' * 10)
model.load_state_dict(torch.load('vgg16_cifar10.pth'))
test(model, device, test_loader) test(model, device, test_loader)
# top1 = 93.51% # top1 = 93.51%
......
...@@ -56,7 +56,7 @@ def test(model, device, test_loader): ...@@ -56,7 +56,7 @@ def test(model, device, test_loader):
def main(): def main():
torch.manual_seed(0) torch.manual_seed(0)
device = torch.device('cpu') device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
trans = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]) trans = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
train_loader = torch.utils.data.DataLoader( train_loader = torch.utils.data.DataLoader(
...@@ -67,7 +67,6 @@ def main(): ...@@ -67,7 +67,6 @@ def main():
batch_size=1000, shuffle=True) batch_size=1000, shuffle=True)
model = Mnist() model = Mnist()
'''you can change this to DoReFaQuantizer to implement it '''you can change this to DoReFaQuantizer to implement it
DoReFaQuantizer(configure_list).compress(model) DoReFaQuantizer(configure_list).compress(model)
''' '''
...@@ -86,6 +85,7 @@ def main(): ...@@ -86,6 +85,7 @@ def main():
quantizer = QAT_Quantizer(model, configure_list) quantizer = QAT_Quantizer(model, configure_list)
quantizer.compress() quantizer.compress()
model.to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5) optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
for epoch in range(10): for epoch in range(10):
print('# Epoch {} #'.format(epoch)) print('# Epoch {} #'.format(epoch))
......
...@@ -72,7 +72,7 @@ def test(model, device, test_loader): ...@@ -72,7 +72,7 @@ def test(model, device, test_loader):
def main(): def main():
torch.manual_seed(0) torch.manual_seed(0)
device = torch.device('cpu') device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
trans = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]) trans = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
train_loader = torch.utils.data.DataLoader( train_loader = torch.utils.data.DataLoader(
...@@ -83,6 +83,7 @@ def main(): ...@@ -83,6 +83,7 @@ def main():
batch_size=1000, shuffle=True) batch_size=1000, shuffle=True)
model = Mnist() model = Mnist()
model.to(device)
model.print_conv_filter_sparsity() model.print_conv_filter_sparsity()
configure_list = [{ configure_list = [{
...@@ -92,7 +93,7 @@ def main(): ...@@ -92,7 +93,7 @@ def main():
pruner = FPGMPruner(model, configure_list) pruner = FPGMPruner(model, configure_list)
pruner.compress() pruner.compress()
model.to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5) optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
for epoch in range(10): for epoch in range(10):
pruner.update_epoch(epoch) pruner.update_epoch(epoch)
......
...@@ -55,7 +55,7 @@ def test(model, device, test_loader): ...@@ -55,7 +55,7 @@ def test(model, device, test_loader):
def main(): def main():
torch.manual_seed(0) torch.manual_seed(0)
device = torch.device('cuda') device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
trans = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]) trans = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
train_loader = torch.utils.data.DataLoader( train_loader = torch.utils.data.DataLoader(
......
...@@ -49,7 +49,7 @@ def test(model, device, test_loader): ...@@ -49,7 +49,7 @@ def test(model, device, test_loader):
def main(): def main():
torch.manual_seed(0) torch.manual_seed(0)
device = torch.device('cuda') device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
train_loader = torch.utils.data.DataLoader( train_loader = torch.utils.data.DataLoader(
datasets.CIFAR10('./data.cifar10', train=True, download=True, datasets.CIFAR10('./data.cifar10', train=True, download=True,
transform=transforms.Compose([ transform=transforms.Compose([
......
import math import math
import os
import argparse import argparse
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -57,7 +58,7 @@ def main(): ...@@ -57,7 +58,7 @@ def main():
args = parser.parse_args() args = parser.parse_args()
torch.manual_seed(0) torch.manual_seed(0)
device = torch.device('cuda') device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
train_loader = torch.utils.data.DataLoader( train_loader = torch.utils.data.DataLoader(
datasets.CIFAR10('./data.cifar10', train=True, download=True, datasets.CIFAR10('./data.cifar10', train=True, download=True,
transform=transforms.Compose([ transform=transforms.Compose([
...@@ -90,10 +91,11 @@ def main(): ...@@ -90,10 +91,11 @@ def main():
train(model, device, train_loader, optimizer, True) train(model, device, train_loader, optimizer, True)
test(model, device, test_loader) test(model, device, test_loader)
torch.save(model.state_dict(), 'vgg19_cifar10.pth') torch.save(model.state_dict(), 'vgg19_cifar10.pth')
else:
assert os.path.isfile('vgg19_cifar10.pth'), "can not find checkpoint 'vgg19_cifar10.pth'"
model.load_state_dict(torch.load('vgg19_cifar10.pth'))
# Test base model accuracy # Test base model accuracy
print('=' * 10 + 'Test the original model' + '=' * 10) print('=' * 10 + 'Test the original model' + '=' * 10)
model.load_state_dict(torch.load('vgg19_cifar10.pth'))
test(model, device, test_loader) test(model, device, test_loader)
# top1 = 93.60% # top1 = 93.60%
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment