Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
e5cb4ed8
"examples/vscode:/vscode.git/clone" did not exist on "ae05050db9d37d5af48a6cd0d6510a5ffb1c1cd4"
Commit
e5cb4ed8
authored
Nov 29, 2019
by
chicm-ms
Browse files
refine model compression examples (#1804)
parent
962d9aee
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
10 additions
and
351 deletions
+10
-351
examples/model_compress/fpgm_torch_mnist.py
examples/model_compress/fpgm_torch_mnist.py
+7
-9
examples/model_compress/lottery_torch_mnist_fc.py
examples/model_compress/lottery_torch_mnist_fc.py
+3
-3
examples/model_compress/main_tf_pruner.py
examples/model_compress/main_tf_pruner.py
+0
-132
examples/model_compress/main_tf_quantizer.py
examples/model_compress/main_tf_quantizer.py
+0
-119
examples/model_compress/main_torch_pruner.py
examples/model_compress/main_torch_pruner.py
+0
-2
examples/model_compress/main_torch_quantizer.py
examples/model_compress/main_torch_quantizer.py
+0
-86
No files found.
examples/model_compress/fpgm_torch_mnist.py
View file @
e5cb4ed8
from
nni.compression.torch
import
FPGMPruner
import
torch
import
torch
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
from
torchvision
import
datasets
,
transforms
from
torchvision
import
datasets
,
transforms
from
nni.compression.torch
import
FPGMPruner
class
Mnist
(
torch
.
nn
.
Module
):
class
Mnist
(
torch
.
nn
.
Module
):
def
__init__
(
self
):
def
__init__
(
self
):
...
@@ -23,8 +22,8 @@ class Mnist(torch.nn.Module):
...
@@ -23,8 +22,8 @@ class Mnist(torch.nn.Module):
return
F
.
log_softmax
(
x
,
dim
=
1
)
return
F
.
log_softmax
(
x
,
dim
=
1
)
def
_get_conv_weight_sparsity
(
self
,
conv_layer
):
def
_get_conv_weight_sparsity
(
self
,
conv_layer
):
num_zero_filters
=
(
conv_layer
.
weight
.
data
.
sum
((
2
,
3
))
==
0
).
sum
()
num_zero_filters
=
(
conv_layer
.
weight
.
data
.
sum
((
1
,
2
,
3
))
==
0
).
sum
()
num_filters
=
conv_layer
.
weight
.
data
.
size
(
0
)
*
conv_layer
.
weight
.
data
.
size
(
1
)
num_filters
=
conv_layer
.
weight
.
data
.
size
(
0
)
return
num_zero_filters
,
num_filters
,
float
(
num_zero_filters
)
/
num_filters
return
num_zero_filters
,
num_filters
,
float
(
num_zero_filters
)
/
num_filters
def
print_conv_filter_sparsity
(
self
):
def
print_conv_filter_sparsity
(
self
):
...
@@ -41,7 +40,8 @@ def train(model, device, train_loader, optimizer):
...
@@ -41,7 +40,8 @@ def train(model, device, train_loader, optimizer):
output
=
model
(
data
)
output
=
model
(
data
)
loss
=
F
.
nll_loss
(
output
,
target
)
loss
=
F
.
nll_loss
(
output
,
target
)
if
batch_idx
%
100
==
0
:
if
batch_idx
%
100
==
0
:
print
(
'{:2.0f}% Loss {}'
.
format
(
100
*
batch_idx
/
len
(
train_loader
),
loss
.
item
()))
print
(
'{:.2f}% Loss {:.4f}'
.
format
(
100
*
batch_idx
/
len
(
train_loader
),
loss
.
item
()))
if
batch_idx
==
0
:
model
.
print_conv_filter_sparsity
()
model
.
print_conv_filter_sparsity
()
loss
.
backward
()
loss
.
backward
()
optimizer
.
step
()
optimizer
.
step
()
...
@@ -59,7 +59,7 @@ def test(model, device, test_loader):
...
@@ -59,7 +59,7 @@ def test(model, device, test_loader):
correct
+=
pred
.
eq
(
target
.
view_as
(
pred
)).
sum
().
item
()
correct
+=
pred
.
eq
(
target
.
view_as
(
pred
)).
sum
().
item
()
test_loss
/=
len
(
test_loader
.
dataset
)
test_loss
/=
len
(
test_loader
.
dataset
)
print
(
'Loss: {} Accuracy: {}%)
\n
'
.
format
(
print
(
'Loss: {
:.4f
} Accuracy: {}%)
\n
'
.
format
(
test_loss
,
100
*
correct
/
len
(
test_loader
.
dataset
)))
test_loss
,
100
*
correct
/
len
(
test_loader
.
dataset
)))
...
@@ -78,9 +78,6 @@ def main():
...
@@ -78,9 +78,6 @@ def main():
model
=
Mnist
()
model
=
Mnist
()
model
.
print_conv_filter_sparsity
()
model
.
print_conv_filter_sparsity
()
'''you can change this to LevelPruner to implement it
pruner = LevelPruner(configure_list)
'''
configure_list
=
[{
configure_list
=
[{
'sparsity'
:
0.5
,
'sparsity'
:
0.5
,
'op_types'
:
[
'Conv2d'
]
'op_types'
:
[
'Conv2d'
]
...
@@ -96,6 +93,7 @@ def main():
...
@@ -96,6 +93,7 @@ def main():
train
(
model
,
device
,
train_loader
,
optimizer
)
train
(
model
,
device
,
train_loader
,
optimizer
)
test
(
model
,
device
,
test_loader
)
test
(
model
,
device
,
test_loader
)
pruner
.
export_model
(
'model.pth'
,
'mask.pth'
)
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
main
()
main
()
examples/model_compress/lottery_torch_mnist_fc.py
View file @
e5cb4ed8
...
@@ -26,7 +26,7 @@ class fc1(nn.Module):
...
@@ -26,7 +26,7 @@ class fc1(nn.Module):
def
train
(
model
,
train_loader
,
optimizer
,
criterion
):
def
train
(
model
,
train_loader
,
optimizer
,
criterion
):
device
=
torch
.
device
(
"cuda"
if
torch
.
cuda
.
is_available
()
else
"cpu"
)
device
=
torch
.
device
(
"cuda"
if
torch
.
cuda
.
is_available
()
else
"cpu"
)
model
.
train
()
model
.
train
()
for
batch_idx
,
(
imgs
,
targets
)
in
enumerate
(
train_loader
)
:
for
imgs
,
targets
in
train_loader
:
optimizer
.
zero_grad
()
optimizer
.
zero_grad
()
imgs
,
targets
=
imgs
.
to
(
device
),
targets
.
to
(
device
)
imgs
,
targets
=
imgs
.
to
(
device
),
targets
.
to
(
device
)
output
=
model
(
imgs
)
output
=
model
(
imgs
)
...
@@ -64,7 +64,7 @@ if __name__ == '__main__':
...
@@ -64,7 +64,7 @@ if __name__ == '__main__':
criterion
=
nn
.
CrossEntropyLoss
()
criterion
=
nn
.
CrossEntropyLoss
()
configure_list
=
[{
configure_list
=
[{
'prune_iterations'
:
10
,
'prune_iterations'
:
5
,
'sparsity'
:
0.96
,
'sparsity'
:
0.96
,
'op_types'
:
[
'default'
]
'op_types'
:
[
'default'
]
}]
}]
...
@@ -75,7 +75,7 @@ if __name__ == '__main__':
...
@@ -75,7 +75,7 @@ if __name__ == '__main__':
pruner
.
prune_iteration_start
()
pruner
.
prune_iteration_start
()
loss
=
0
loss
=
0
accuracy
=
0
accuracy
=
0
for
epoch
in
range
(
5
0
):
for
epoch
in
range
(
1
0
):
loss
=
train
(
model
,
train_loader
,
optimizer
,
criterion
)
loss
=
train
(
model
,
train_loader
,
optimizer
,
criterion
)
accuracy
=
test
(
model
,
test_loader
,
criterion
)
accuracy
=
test
(
model
,
test_loader
,
criterion
)
print
(
'current epoch: {0}, loss: {1}, accuracy: {2}'
.
format
(
epoch
,
loss
,
accuracy
))
print
(
'current epoch: {0}, loss: {1}, accuracy: {2}'
.
format
(
epoch
,
loss
,
accuracy
))
...
...
examples/model_compress/main_tf_pruner.py
deleted
100644 → 0
View file @
962d9aee
from
nni.compression.tensorflow
import
AGP_Pruner
import
tensorflow
as
tf
from
tensorflow.examples.tutorials.mnist
import
input_data
def
weight_variable
(
shape
):
return
tf
.
Variable
(
tf
.
truncated_normal
(
shape
,
stddev
=
0.1
))
def
bias_variable
(
shape
):
return
tf
.
Variable
(
tf
.
constant
(
0.1
,
shape
=
shape
))
def
conv2d
(
x_input
,
w_matrix
):
return
tf
.
nn
.
conv2d
(
x_input
,
w_matrix
,
strides
=
[
1
,
1
,
1
,
1
],
padding
=
'SAME'
)
def
max_pool
(
x_input
,
pool_size
):
size
=
[
1
,
pool_size
,
pool_size
,
1
]
return
tf
.
nn
.
max_pool
(
x_input
,
ksize
=
size
,
strides
=
size
,
padding
=
'SAME'
)
class
Mnist
:
def
__init__
(
self
):
images
=
tf
.
placeholder
(
tf
.
float32
,
[
None
,
784
],
name
=
'input_x'
)
labels
=
tf
.
placeholder
(
tf
.
float32
,
[
None
,
10
],
name
=
'input_y'
)
keep_prob
=
tf
.
placeholder
(
tf
.
float32
,
name
=
'keep_prob'
)
self
.
images
=
images
self
.
labels
=
labels
self
.
keep_prob
=
keep_prob
self
.
train_step
=
None
self
.
accuracy
=
None
self
.
w1
=
None
self
.
b1
=
None
self
.
fcw1
=
None
self
.
cross
=
None
with
tf
.
name_scope
(
'reshape'
):
x_image
=
tf
.
reshape
(
images
,
[
-
1
,
28
,
28
,
1
])
with
tf
.
name_scope
(
'conv1'
):
w_conv1
=
weight_variable
([
5
,
5
,
1
,
32
])
self
.
w1
=
w_conv1
b_conv1
=
bias_variable
([
32
])
self
.
b1
=
b_conv1
h_conv1
=
tf
.
nn
.
relu
(
conv2d
(
x_image
,
w_conv1
)
+
b_conv1
)
with
tf
.
name_scope
(
'pool1'
):
h_pool1
=
max_pool
(
h_conv1
,
2
)
with
tf
.
name_scope
(
'conv2'
):
w_conv2
=
weight_variable
([
5
,
5
,
32
,
64
])
b_conv2
=
bias_variable
([
64
])
h_conv2
=
tf
.
nn
.
relu
(
conv2d
(
h_pool1
,
w_conv2
)
+
b_conv2
)
with
tf
.
name_scope
(
'pool2'
):
h_pool2
=
max_pool
(
h_conv2
,
2
)
with
tf
.
name_scope
(
'fc1'
):
w_fc1
=
weight_variable
([
7
*
7
*
64
,
1024
])
self
.
fcw1
=
w_fc1
b_fc1
=
bias_variable
([
1024
])
h_pool2_flat
=
tf
.
reshape
(
h_pool2
,
[
-
1
,
7
*
7
*
64
])
h_fc1
=
tf
.
nn
.
relu
(
tf
.
matmul
(
h_pool2_flat
,
w_fc1
)
+
b_fc1
)
with
tf
.
name_scope
(
'dropout'
):
h_fc1_drop
=
tf
.
nn
.
dropout
(
h_fc1
,
0.5
)
with
tf
.
name_scope
(
'fc2'
):
w_fc2
=
weight_variable
([
1024
,
10
])
b_fc2
=
bias_variable
([
10
])
y_conv
=
tf
.
matmul
(
h_fc1_drop
,
w_fc2
)
+
b_fc2
with
tf
.
name_scope
(
'loss'
):
cross_entropy
=
tf
.
reduce_mean
(
tf
.
nn
.
softmax_cross_entropy_with_logits
(
labels
=
labels
,
logits
=
y_conv
))
self
.
cross
=
cross_entropy
with
tf
.
name_scope
(
'adam_optimizer'
):
self
.
train_step
=
tf
.
train
.
AdamOptimizer
(
0.0001
).
minimize
(
cross_entropy
)
with
tf
.
name_scope
(
'accuracy'
):
correct_prediction
=
tf
.
equal
(
tf
.
argmax
(
y_conv
,
1
),
tf
.
argmax
(
labels
,
1
))
self
.
accuracy
=
tf
.
reduce_mean
(
tf
.
cast
(
correct_prediction
,
tf
.
float32
))
def
main
():
tf
.
set_random_seed
(
0
)
data
=
input_data
.
read_data_sets
(
'data'
,
one_hot
=
True
)
model
=
Mnist
()
'''you can change this to LevelPruner to implement it
pruner = LevelPruner(configure_list)
'''
configure_list
=
[{
'initial_sparsity'
:
0
,
'final_sparsity'
:
0.8
,
'start_epoch'
:
0
,
'end_epoch'
:
10
,
'frequency'
:
1
,
'op_types'
:
[
'default'
]
}]
pruner
=
AGP_Pruner
(
tf
.
get_default_graph
(),
configure_list
)
# if you want to load from yaml file
# configure_file = nni.compressors.tf_compressor._nnimc_tf._tf_default_load_configure_file('configure_example.yaml','AGPruner')
# configure_list = configure_file.get('config',[])
# pruner.load_configure(configure_list)
# you can also handle it yourself and input an configure list in json
pruner
.
compress
()
with
tf
.
Session
()
as
sess
:
sess
.
run
(
tf
.
global_variables_initializer
())
for
batch_idx
in
range
(
2000
):
if
batch_idx
%
10
==
0
:
pruner
.
update_epoch
(
batch_idx
/
10
,
sess
)
batch
=
data
.
train
.
next_batch
(
2000
)
model
.
train_step
.
run
(
feed_dict
=
{
model
.
images
:
batch
[
0
],
model
.
labels
:
batch
[
1
],
model
.
keep_prob
:
0.5
})
if
batch_idx
%
10
==
0
:
test_acc
=
model
.
accuracy
.
eval
(
feed_dict
=
{
model
.
images
:
data
.
test
.
images
,
model
.
labels
:
data
.
test
.
labels
,
model
.
keep_prob
:
1.0
})
print
(
'test accuracy'
,
test_acc
)
test_acc
=
model
.
accuracy
.
eval
(
feed_dict
=
{
model
.
images
:
data
.
test
.
images
,
model
.
labels
:
data
.
test
.
labels
,
model
.
keep_prob
:
1.0
})
print
(
'final result is'
,
test_acc
)
if
__name__
==
'__main__'
:
main
()
examples/model_compress/main_tf_quantizer.py
deleted
100644 → 0
View file @
962d9aee
from
nni.compression.tensorflow
import
QAT_Quantizer
import
tensorflow
as
tf
from
tensorflow.examples.tutorials.mnist
import
input_data
def
weight_variable
(
shape
):
return
tf
.
Variable
(
tf
.
truncated_normal
(
shape
,
stddev
=
0.1
))
def
bias_variable
(
shape
):
return
tf
.
Variable
(
tf
.
constant
(
0.1
,
shape
=
shape
))
def
conv2d
(
x_input
,
w_matrix
):
return
tf
.
nn
.
conv2d
(
x_input
,
w_matrix
,
strides
=
[
1
,
1
,
1
,
1
],
padding
=
'SAME'
)
def
max_pool
(
x_input
,
pool_size
):
size
=
[
1
,
pool_size
,
pool_size
,
1
]
return
tf
.
nn
.
max_pool
(
x_input
,
ksize
=
size
,
strides
=
size
,
padding
=
'SAME'
)
class
Mnist
:
def
__init__
(
self
):
images
=
tf
.
placeholder
(
tf
.
float32
,
[
None
,
784
],
name
=
'input_x'
)
labels
=
tf
.
placeholder
(
tf
.
float32
,
[
None
,
10
],
name
=
'input_y'
)
keep_prob
=
tf
.
placeholder
(
tf
.
float32
,
name
=
'keep_prob'
)
self
.
images
=
images
self
.
labels
=
labels
self
.
keep_prob
=
keep_prob
self
.
train_step
=
None
self
.
accuracy
=
None
self
.
w1
=
None
self
.
b1
=
None
self
.
fcw1
=
None
self
.
cross
=
None
with
tf
.
name_scope
(
'reshape'
):
x_image
=
tf
.
reshape
(
images
,
[
-
1
,
28
,
28
,
1
])
with
tf
.
name_scope
(
'conv1'
):
w_conv1
=
weight_variable
([
5
,
5
,
1
,
32
])
self
.
w1
=
w_conv1
b_conv1
=
bias_variable
([
32
])
self
.
b1
=
b_conv1
h_conv1
=
tf
.
nn
.
relu
(
conv2d
(
x_image
,
w_conv1
)
+
b_conv1
)
with
tf
.
name_scope
(
'pool1'
):
h_pool1
=
max_pool
(
h_conv1
,
2
)
with
tf
.
name_scope
(
'conv2'
):
w_conv2
=
weight_variable
([
5
,
5
,
32
,
64
])
b_conv2
=
bias_variable
([
64
])
h_conv2
=
tf
.
nn
.
relu
(
conv2d
(
h_pool1
,
w_conv2
)
+
b_conv2
)
with
tf
.
name_scope
(
'pool2'
):
h_pool2
=
max_pool
(
h_conv2
,
2
)
with
tf
.
name_scope
(
'fc1'
):
w_fc1
=
weight_variable
([
7
*
7
*
64
,
1024
])
self
.
fcw1
=
w_fc1
b_fc1
=
bias_variable
([
1024
])
h_pool2_flat
=
tf
.
reshape
(
h_pool2
,
[
-
1
,
7
*
7
*
64
])
h_fc1
=
tf
.
nn
.
relu
(
tf
.
matmul
(
h_pool2_flat
,
w_fc1
)
+
b_fc1
)
with
tf
.
name_scope
(
'dropout'
):
h_fc1_drop
=
tf
.
nn
.
dropout
(
h_fc1
,
0.5
)
with
tf
.
name_scope
(
'fc2'
):
w_fc2
=
weight_variable
([
1024
,
10
])
b_fc2
=
bias_variable
([
10
])
y_conv
=
tf
.
matmul
(
h_fc1_drop
,
w_fc2
)
+
b_fc2
with
tf
.
name_scope
(
'loss'
):
cross_entropy
=
tf
.
reduce_mean
(
tf
.
nn
.
softmax_cross_entropy_with_logits
(
labels
=
labels
,
logits
=
y_conv
))
self
.
cross
=
cross_entropy
with
tf
.
name_scope
(
'adam_optimizer'
):
self
.
train_step
=
tf
.
train
.
AdamOptimizer
(
0.0001
).
minimize
(
cross_entropy
)
with
tf
.
name_scope
(
'accuracy'
):
correct_prediction
=
tf
.
equal
(
tf
.
argmax
(
y_conv
,
1
),
tf
.
argmax
(
labels
,
1
))
self
.
accuracy
=
tf
.
reduce_mean
(
tf
.
cast
(
correct_prediction
,
tf
.
float32
))
def
main
():
tf
.
set_random_seed
(
0
)
data
=
input_data
.
read_data_sets
(
'data'
,
one_hot
=
True
)
model
=
Mnist
()
'''you can change this to DoReFaQuantizer to implement it
DoReFaQuantizer(configure_list).compress(tf.get_default_graph())
'''
configure_list
=
[{
'q_bits'
:
8
,
'op_types'
:[
'default'
]}]
quantizer
=
QAT_Quantizer
(
tf
.
get_default_graph
(),
configure_list
)
quantizer
.
compress
()
# you can also use compress(model) or compress_default_graph()
# method like QATquantizer(q_bits = 8).compress_default_graph()
with
tf
.
Session
()
as
sess
:
sess
.
run
(
tf
.
global_variables_initializer
())
for
batch_idx
in
range
(
2000
):
batch
=
data
.
train
.
next_batch
(
2000
)
model
.
train_step
.
run
(
feed_dict
=
{
model
.
images
:
batch
[
0
],
model
.
labels
:
batch
[
1
],
model
.
keep_prob
:
0.5
})
if
batch_idx
%
10
==
0
:
test_acc
=
model
.
accuracy
.
eval
(
feed_dict
=
{
model
.
images
:
data
.
test
.
images
,
model
.
labels
:
data
.
test
.
labels
,
model
.
keep_prob
:
1.0
})
print
(
'test accuracy'
,
test_acc
)
test_acc
=
model
.
accuracy
.
eval
(
feed_dict
=
{
model
.
images
:
data
.
test
.
images
,
model
.
labels
:
data
.
test
.
labels
,
model
.
keep_prob
:
1.0
})
print
(
'final result is'
,
test_acc
)
if
__name__
==
'__main__'
:
main
()
examples/model_compress/main_torch_pruner.py
View file @
e5cb4ed8
...
@@ -82,8 +82,6 @@ def main():
...
@@ -82,8 +82,6 @@ def main():
pruner
=
AGP_Pruner
(
model
,
configure_list
)
pruner
=
AGP_Pruner
(
model
,
configure_list
)
model
=
pruner
.
compress
()
model
=
pruner
.
compress
()
# you can also use compress(model) method
# like that pruner.compress(model)
optimizer
=
torch
.
optim
.
SGD
(
model
.
parameters
(),
lr
=
0.01
,
momentum
=
0.5
)
optimizer
=
torch
.
optim
.
SGD
(
model
.
parameters
(),
lr
=
0.01
,
momentum
=
0.5
)
for
epoch
in
range
(
10
):
for
epoch
in
range
(
10
):
...
...
examples/model_compress/main_torch_quantizer.py
deleted
100644 → 0
View file @
962d9aee
from
nni.compression.torch
import
QAT_Quantizer
import
torch
import
torch.nn.functional
as
F
from
torchvision
import
datasets
,
transforms
class
Mnist
(
torch
.
nn
.
Module
):
def
__init__
(
self
):
super
().
__init__
()
self
.
conv1
=
torch
.
nn
.
Conv2d
(
1
,
20
,
5
,
1
)
self
.
conv2
=
torch
.
nn
.
Conv2d
(
20
,
50
,
5
,
1
)
self
.
fc1
=
torch
.
nn
.
Linear
(
4
*
4
*
50
,
500
)
self
.
fc2
=
torch
.
nn
.
Linear
(
500
,
10
)
def
forward
(
self
,
x
):
x
=
F
.
relu
(
self
.
conv1
(
x
))
x
=
F
.
max_pool2d
(
x
,
2
,
2
)
x
=
F
.
relu
(
self
.
conv2
(
x
))
x
=
F
.
max_pool2d
(
x
,
2
,
2
)
x
=
x
.
view
(
-
1
,
4
*
4
*
50
)
x
=
F
.
relu
(
self
.
fc1
(
x
))
x
=
self
.
fc2
(
x
)
return
F
.
log_softmax
(
x
,
dim
=
1
)
def
train
(
model
,
device
,
train_loader
,
optimizer
):
model
.
train
()
for
batch_idx
,
(
data
,
target
)
in
enumerate
(
train_loader
):
data
,
target
=
data
.
to
(
device
),
target
.
to
(
device
)
optimizer
.
zero_grad
()
output
=
model
(
data
)
loss
=
F
.
nll_loss
(
output
,
target
)
loss
.
backward
()
optimizer
.
step
()
if
batch_idx
%
100
==
0
:
print
(
'{:2.0f}% Loss {}'
.
format
(
100
*
batch_idx
/
len
(
train_loader
),
loss
.
item
()))
def
test
(
model
,
device
,
test_loader
):
model
.
eval
()
test_loss
=
0
correct
=
0
with
torch
.
no_grad
():
for
data
,
target
in
test_loader
:
data
,
target
=
data
.
to
(
device
),
target
.
to
(
device
)
output
=
model
(
data
)
test_loss
+=
F
.
nll_loss
(
output
,
target
,
reduction
=
'sum'
).
item
()
pred
=
output
.
argmax
(
dim
=
1
,
keepdim
=
True
)
correct
+=
pred
.
eq
(
target
.
view_as
(
pred
)).
sum
().
item
()
test_loss
/=
len
(
test_loader
.
dataset
)
print
(
'Loss: {} Accuracy: {}%)
\n
'
.
format
(
test_loss
,
100
*
correct
/
len
(
test_loader
.
dataset
)))
def
main
():
torch
.
manual_seed
(
0
)
device
=
torch
.
device
(
'cpu'
)
trans
=
transforms
.
Compose
([
transforms
.
ToTensor
(),
transforms
.
Normalize
((
0.1307
,),
(
0.3081
,))])
train_loader
=
torch
.
utils
.
data
.
DataLoader
(
datasets
.
MNIST
(
'data'
,
train
=
True
,
download
=
True
,
transform
=
trans
),
batch_size
=
64
,
shuffle
=
True
)
test_loader
=
torch
.
utils
.
data
.
DataLoader
(
datasets
.
MNIST
(
'data'
,
train
=
False
,
transform
=
trans
),
batch_size
=
1000
,
shuffle
=
True
)
model
=
Mnist
()
'''you can change this to DoReFaQuantizer to implement it
DoReFaQuantizer(configure_list).compress(model)
'''
configure_list
=
[{
'q_bits'
:
8
,
'op_types'
:[
'default'
]}]
quantizer
=
QAT_Quantizer
(
model
,
configure_list
)
quantizer
.
compress
()
# you can also use compress(model) method
# like thaht quantizer.compress(model)
optimizer
=
torch
.
optim
.
SGD
(
model
.
parameters
(),
lr
=
0.01
,
momentum
=
0.5
)
for
epoch
in
range
(
10
):
print
(
'# Epoch {} #'
.
format
(
epoch
))
train
(
model
,
device
,
train_loader
,
optimizer
)
test
(
model
,
device
,
test_loader
)
if
__name__
==
'__main__'
:
main
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment