Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
b7045b19
Unverified
Commit
b7045b19
authored
Feb 16, 2020
by
Cjkkkk
Committed by
GitHub
Feb 16, 2020
Browse files
fix buffer transfer bug (#2045)
parent
b8c0fb6e
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
27 additions
and
16 deletions
+27
-16
examples/model_compress/main_torch_pruner.py
examples/model_compress/main_torch_pruner.py
+4
-4
src/sdk/pynni/nni/compression/torch/compressor.py
src/sdk/pynni/nni/compression/torch/compressor.py
+22
-11
src/sdk/pynni/nni/compression/torch/pruners.py
src/sdk/pynni/nni/compression/torch/pruners.py
+1
-1
No files found.
examples/model_compress/main_torch_pruner.py
View file @
b7045b19
...
...
@@ -55,7 +55,7 @@ def test(model, device, test_loader):
def
main
():
torch
.
manual_seed
(
0
)
device
=
torch
.
device
(
'c
p
u'
)
device
=
torch
.
device
(
'cu
da
'
)
trans
=
transforms
.
Compose
([
transforms
.
ToTensor
(),
transforms
.
Normalize
((
0.1307
,),
(
0.3081
,))])
train_loader
=
torch
.
utils
.
data
.
DataLoader
(
...
...
@@ -66,7 +66,7 @@ def main():
batch_size
=
1000
,
shuffle
=
True
)
model
=
Mnist
()
model
.
to
(
device
)
model
=
model
.
to
(
device
)
'''you can change this to LevelPruner to implement it
pruner = LevelPruner(configure_list)
...
...
@@ -82,14 +82,14 @@ def main():
pruner
=
AGP_Pruner
(
model
,
configure_list
)
model
=
pruner
.
compress
()
model
=
model
.
to
(
device
)
optimizer
=
torch
.
optim
.
SGD
(
model
.
parameters
(),
lr
=
0.01
,
momentum
=
0.5
)
for
epoch
in
range
(
10
):
pruner
.
update_epoch
(
epoch
)
print
(
'# Epoch {} #'
.
format
(
epoch
))
train
(
model
,
device
,
train_loader
,
optimizer
)
test
(
model
,
device
,
test_loader
)
pruner
.
export_model
(
'model.pth'
,
'mask.pth'
,
'model.onnx'
,
[
1
,
1
,
28
,
28
])
pruner
.
export_model
(
'model.pth'
,
'mask.pth'
,
'model.onnx'
,
[
1
,
1
,
28
,
28
]
,
device
)
if
__name__
==
'__main__'
:
...
...
src/sdk/pynni/nni/compression/torch/compressor.py
View file @
b7045b19
...
...
@@ -226,7 +226,7 @@ class PrunerModuleWrapper(torch.nn.Module):
# config and pruner
self
.
config
=
config
self
.
pruner
=
pruner
self
.
registered_buffers
=
{}
self
.
registered_buffers
=
[]
# register buffer for mask
self
.
register_buffer
(
"weight_mask"
,
torch
.
ones
(
self
.
module
.
weight
.
shape
))
...
...
@@ -234,16 +234,21 @@ class PrunerModuleWrapper(torch.nn.Module):
self
.
register_buffer
(
"bias_mask"
,
torch
.
ones
(
self
.
module
.
bias
.
shape
))
else
:
self
.
register_buffer
(
"bias_mask"
,
None
)
self
.
registered_buffers
[
'weight_mask'
]
=
self
.
weight_mask
self
.
registered_buffers
[
'bias_mask'
]
=
self
.
bias_mask
self
.
registered_buffers
.
append
(
'weight_mask'
)
self
.
registered_buffers
.
append
(
'bias_mask'
)
# register user specified buffer
for
name
in
self
.
pruner
.
buffers
:
self
.
register_buffer
(
name
,
self
.
pruner
.
buffers
[
name
].
clone
())
self
.
registered_buffers
[
name
]
=
getattr
(
self
,
name
)
self
.
registered_buffers
.
append
(
name
)
def
get_registered_buffers
(
self
):
buffers
=
{}
for
name
in
self
.
registered_buffers
:
buffers
[
name
]
=
getattr
(
self
,
name
)
return
buffers
def
forward
(
self
,
*
inputs
):
mask
=
self
.
pruner
.
calc_mask
(
LayerInfo
(
self
.
name
,
self
.
module
),
self
.
config
,
**
self
.
registered_buffers
)
mask
=
self
.
pruner
.
calc_mask
(
LayerInfo
(
self
.
name
,
self
.
module
),
self
.
config
,
**
self
.
get_
registered_buffers
()
)
if
mask
is
not
None
:
self
.
weight_mask
.
copy_
(
mask
[
'weight'
])
# apply mask to weight
...
...
@@ -399,6 +404,7 @@ class QuantizerModuleWrapper(torch.nn.Module):
# config and pruner
self
.
config
=
config
self
.
quantizer
=
quantizer
self
.
registered_buffers
=
[]
# register buffer and parameter
# old_weight is used to store origin weight and weight is used to store quantized weight
...
...
@@ -413,10 +419,15 @@ class QuantizerModuleWrapper(torch.nn.Module):
self
.
module
.
register_buffer
(
'weight'
,
self
.
module
.
old_weight
)
# register user specified buffer
self
.
registered_buffers
=
{}
for
name
in
self
.
quantizer
.
buffers
:
self
.
register_buffer
(
name
,
self
.
quantizer
.
buffers
[
name
].
clone
())
self
.
registered_buffers
[
name
]
=
getattr
(
self
,
name
)
self
.
registered_buffers
.
append
(
name
)
def
get_registered_buffers
(
self
):
buffers
=
{}
for
name
in
self
.
registered_buffers
:
buffers
[
name
]
=
getattr
(
self
,
name
)
return
buffers
def
forward
(
self
,
*
inputs
):
if
'input'
in
self
.
config
[
'quant_types'
]:
...
...
@@ -426,7 +437,7 @@ class QuantizerModuleWrapper(torch.nn.Module):
self
.
quantizer
.
quantize_input
,
self
.
config
,
LayerInfo
(
self
.
name
,
self
.
module
),
**
self
.
registered_buffers
)
**
self
.
get_
registered_buffers
()
)
if
'weight'
in
self
.
config
[
'quant_types'
]
and
_check_weight
(
self
.
module
):
new_weight
=
self
.
quantizer
.
quant_grad
.
apply
(
...
...
@@ -435,7 +446,7 @@ class QuantizerModuleWrapper(torch.nn.Module):
self
.
quantizer
.
quantize_weight
,
self
.
config
,
LayerInfo
(
self
.
name
,
self
.
module
),
**
self
.
registered_buffers
)
**
self
.
get_
registered_buffers
()
)
self
.
module
.
weight
=
new_weight
result
=
self
.
module
(
*
inputs
)
else
:
...
...
@@ -448,7 +459,7 @@ class QuantizerModuleWrapper(torch.nn.Module):
self
.
quantizer
.
quantize_output
,
self
.
config
,
LayerInfo
(
self
.
name
,
self
.
module
),
**
self
.
registered_buffers
)
**
self
.
get_
registered_buffers
()
)
return
result
class
Quantizer
(
Compressor
):
...
...
src/sdk/pynni/nni/compression/torch/pruners.py
View file @
b7045b19
...
...
@@ -170,7 +170,7 @@ class AGP_Pruner(Pruner):
if
epoch
>
0
:
self
.
now_epoch
=
epoch
for
wrapper
in
self
.
get_modules_wrapper
():
wrapper
.
registered_buffers
[
'
if_calculated
'
]
.
copy_
(
torch
.
tensor
(
0
))
# pylint: disable=not-callable
wrapper
.
if_calculated
.
copy_
(
torch
.
tensor
(
0
))
# pylint: disable=not-callable
class
SlimPruner
(
Pruner
):
"""
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment