Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
b7045b19
Unverified
Commit
b7045b19
authored
Feb 16, 2020
by
Cjkkkk
Committed by
GitHub
Feb 16, 2020
Browse files
fix buffer transfer bug (#2045)
parent
b8c0fb6e
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
27 additions
and
16 deletions
+27
-16
examples/model_compress/main_torch_pruner.py
examples/model_compress/main_torch_pruner.py
+4
-4
src/sdk/pynni/nni/compression/torch/compressor.py
src/sdk/pynni/nni/compression/torch/compressor.py
+22
-11
src/sdk/pynni/nni/compression/torch/pruners.py
src/sdk/pynni/nni/compression/torch/pruners.py
+1
-1
No files found.
examples/model_compress/main_torch_pruner.py
View file @
b7045b19
...
@@ -55,7 +55,7 @@ def test(model, device, test_loader):
...
@@ -55,7 +55,7 @@ def test(model, device, test_loader):
def
main
():
def
main
():
torch
.
manual_seed
(
0
)
torch
.
manual_seed
(
0
)
device
=
torch
.
device
(
'c
p
u'
)
device
=
torch
.
device
(
'cu
da
'
)
trans
=
transforms
.
Compose
([
transforms
.
ToTensor
(),
transforms
.
Normalize
((
0.1307
,),
(
0.3081
,))])
trans
=
transforms
.
Compose
([
transforms
.
ToTensor
(),
transforms
.
Normalize
((
0.1307
,),
(
0.3081
,))])
train_loader
=
torch
.
utils
.
data
.
DataLoader
(
train_loader
=
torch
.
utils
.
data
.
DataLoader
(
...
@@ -66,7 +66,7 @@ def main():
...
@@ -66,7 +66,7 @@ def main():
batch_size
=
1000
,
shuffle
=
True
)
batch_size
=
1000
,
shuffle
=
True
)
model
=
Mnist
()
model
=
Mnist
()
model
.
to
(
device
)
model
=
model
.
to
(
device
)
'''you can change this to LevelPruner to implement it
'''you can change this to LevelPruner to implement it
pruner = LevelPruner(configure_list)
pruner = LevelPruner(configure_list)
...
@@ -82,14 +82,14 @@ def main():
...
@@ -82,14 +82,14 @@ def main():
pruner
=
AGP_Pruner
(
model
,
configure_list
)
pruner
=
AGP_Pruner
(
model
,
configure_list
)
model
=
pruner
.
compress
()
model
=
pruner
.
compress
()
model
=
model
.
to
(
device
)
optimizer
=
torch
.
optim
.
SGD
(
model
.
parameters
(),
lr
=
0.01
,
momentum
=
0.5
)
optimizer
=
torch
.
optim
.
SGD
(
model
.
parameters
(),
lr
=
0.01
,
momentum
=
0.5
)
for
epoch
in
range
(
10
):
for
epoch
in
range
(
10
):
pruner
.
update_epoch
(
epoch
)
pruner
.
update_epoch
(
epoch
)
print
(
'# Epoch {} #'
.
format
(
epoch
))
print
(
'# Epoch {} #'
.
format
(
epoch
))
train
(
model
,
device
,
train_loader
,
optimizer
)
train
(
model
,
device
,
train_loader
,
optimizer
)
test
(
model
,
device
,
test_loader
)
test
(
model
,
device
,
test_loader
)
pruner
.
export_model
(
'model.pth'
,
'mask.pth'
,
'model.onnx'
,
[
1
,
1
,
28
,
28
])
pruner
.
export_model
(
'model.pth'
,
'mask.pth'
,
'model.onnx'
,
[
1
,
1
,
28
,
28
]
,
device
)
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
...
...
src/sdk/pynni/nni/compression/torch/compressor.py
View file @
b7045b19
...
@@ -226,7 +226,7 @@ class PrunerModuleWrapper(torch.nn.Module):
...
@@ -226,7 +226,7 @@ class PrunerModuleWrapper(torch.nn.Module):
# config and pruner
# config and pruner
self
.
config
=
config
self
.
config
=
config
self
.
pruner
=
pruner
self
.
pruner
=
pruner
self
.
registered_buffers
=
{}
self
.
registered_buffers
=
[]
# register buffer for mask
# register buffer for mask
self
.
register_buffer
(
"weight_mask"
,
torch
.
ones
(
self
.
module
.
weight
.
shape
))
self
.
register_buffer
(
"weight_mask"
,
torch
.
ones
(
self
.
module
.
weight
.
shape
))
...
@@ -234,16 +234,21 @@ class PrunerModuleWrapper(torch.nn.Module):
...
@@ -234,16 +234,21 @@ class PrunerModuleWrapper(torch.nn.Module):
self
.
register_buffer
(
"bias_mask"
,
torch
.
ones
(
self
.
module
.
bias
.
shape
))
self
.
register_buffer
(
"bias_mask"
,
torch
.
ones
(
self
.
module
.
bias
.
shape
))
else
:
else
:
self
.
register_buffer
(
"bias_mask"
,
None
)
self
.
register_buffer
(
"bias_mask"
,
None
)
self
.
registered_buffers
.
append
(
'weight_mask'
)
self
.
registered_buffers
[
'weight_mask'
]
=
self
.
weight_mask
self
.
registered_buffers
.
append
(
'bias_mask'
)
self
.
registered_buffers
[
'bias_mask'
]
=
self
.
bias_mask
# register user specified buffer
# register user specified buffer
for
name
in
self
.
pruner
.
buffers
:
for
name
in
self
.
pruner
.
buffers
:
self
.
register_buffer
(
name
,
self
.
pruner
.
buffers
[
name
].
clone
())
self
.
register_buffer
(
name
,
self
.
pruner
.
buffers
[
name
].
clone
())
self
.
registered_buffers
[
name
]
=
getattr
(
self
,
name
)
self
.
registered_buffers
.
append
(
name
)
def
get_registered_buffers
(
self
):
buffers
=
{}
for
name
in
self
.
registered_buffers
:
buffers
[
name
]
=
getattr
(
self
,
name
)
return
buffers
def
forward
(
self
,
*
inputs
):
def
forward
(
self
,
*
inputs
):
mask
=
self
.
pruner
.
calc_mask
(
LayerInfo
(
self
.
name
,
self
.
module
),
self
.
config
,
**
self
.
registered_buffers
)
mask
=
self
.
pruner
.
calc_mask
(
LayerInfo
(
self
.
name
,
self
.
module
),
self
.
config
,
**
self
.
get_
registered_buffers
()
)
if
mask
is
not
None
:
if
mask
is
not
None
:
self
.
weight_mask
.
copy_
(
mask
[
'weight'
])
self
.
weight_mask
.
copy_
(
mask
[
'weight'
])
# apply mask to weight
# apply mask to weight
...
@@ -399,6 +404,7 @@ class QuantizerModuleWrapper(torch.nn.Module):
...
@@ -399,6 +404,7 @@ class QuantizerModuleWrapper(torch.nn.Module):
# config and pruner
# config and pruner
self
.
config
=
config
self
.
config
=
config
self
.
quantizer
=
quantizer
self
.
quantizer
=
quantizer
self
.
registered_buffers
=
[]
# register buffer and parameter
# register buffer and parameter
# old_weight is used to store origin weight and weight is used to store quantized weight
# old_weight is used to store origin weight and weight is used to store quantized weight
...
@@ -413,10 +419,15 @@ class QuantizerModuleWrapper(torch.nn.Module):
...
@@ -413,10 +419,15 @@ class QuantizerModuleWrapper(torch.nn.Module):
self
.
module
.
register_buffer
(
'weight'
,
self
.
module
.
old_weight
)
self
.
module
.
register_buffer
(
'weight'
,
self
.
module
.
old_weight
)
# register user specified buffer
# register user specified buffer
self
.
registered_buffers
=
{}
for
name
in
self
.
quantizer
.
buffers
:
for
name
in
self
.
quantizer
.
buffers
:
self
.
register_buffer
(
name
,
self
.
quantizer
.
buffers
[
name
].
clone
())
self
.
register_buffer
(
name
,
self
.
quantizer
.
buffers
[
name
].
clone
())
self
.
registered_buffers
[
name
]
=
getattr
(
self
,
name
)
self
.
registered_buffers
.
append
(
name
)
def
get_registered_buffers
(
self
):
buffers
=
{}
for
name
in
self
.
registered_buffers
:
buffers
[
name
]
=
getattr
(
self
,
name
)
return
buffers
def
forward
(
self
,
*
inputs
):
def
forward
(
self
,
*
inputs
):
if
'input'
in
self
.
config
[
'quant_types'
]:
if
'input'
in
self
.
config
[
'quant_types'
]:
...
@@ -426,7 +437,7 @@ class QuantizerModuleWrapper(torch.nn.Module):
...
@@ -426,7 +437,7 @@ class QuantizerModuleWrapper(torch.nn.Module):
self
.
quantizer
.
quantize_input
,
self
.
quantizer
.
quantize_input
,
self
.
config
,
self
.
config
,
LayerInfo
(
self
.
name
,
self
.
module
),
LayerInfo
(
self
.
name
,
self
.
module
),
**
self
.
registered_buffers
)
**
self
.
get_
registered_buffers
()
)
if
'weight'
in
self
.
config
[
'quant_types'
]
and
_check_weight
(
self
.
module
):
if
'weight'
in
self
.
config
[
'quant_types'
]
and
_check_weight
(
self
.
module
):
new_weight
=
self
.
quantizer
.
quant_grad
.
apply
(
new_weight
=
self
.
quantizer
.
quant_grad
.
apply
(
...
@@ -435,7 +446,7 @@ class QuantizerModuleWrapper(torch.nn.Module):
...
@@ -435,7 +446,7 @@ class QuantizerModuleWrapper(torch.nn.Module):
self
.
quantizer
.
quantize_weight
,
self
.
quantizer
.
quantize_weight
,
self
.
config
,
self
.
config
,
LayerInfo
(
self
.
name
,
self
.
module
),
LayerInfo
(
self
.
name
,
self
.
module
),
**
self
.
registered_buffers
)
**
self
.
get_
registered_buffers
()
)
self
.
module
.
weight
=
new_weight
self
.
module
.
weight
=
new_weight
result
=
self
.
module
(
*
inputs
)
result
=
self
.
module
(
*
inputs
)
else
:
else
:
...
@@ -448,7 +459,7 @@ class QuantizerModuleWrapper(torch.nn.Module):
...
@@ -448,7 +459,7 @@ class QuantizerModuleWrapper(torch.nn.Module):
self
.
quantizer
.
quantize_output
,
self
.
quantizer
.
quantize_output
,
self
.
config
,
self
.
config
,
LayerInfo
(
self
.
name
,
self
.
module
),
LayerInfo
(
self
.
name
,
self
.
module
),
**
self
.
registered_buffers
)
**
self
.
get_
registered_buffers
()
)
return
result
return
result
class
Quantizer
(
Compressor
):
class
Quantizer
(
Compressor
):
...
...
src/sdk/pynni/nni/compression/torch/pruners.py
View file @
b7045b19
...
@@ -170,7 +170,7 @@ class AGP_Pruner(Pruner):
...
@@ -170,7 +170,7 @@ class AGP_Pruner(Pruner):
if
epoch
>
0
:
if
epoch
>
0
:
self
.
now_epoch
=
epoch
self
.
now_epoch
=
epoch
for
wrapper
in
self
.
get_modules_wrapper
():
for
wrapper
in
self
.
get_modules_wrapper
():
wrapper
.
registered_buffers
[
'
if_calculated
'
]
.
copy_
(
torch
.
tensor
(
0
))
# pylint: disable=not-callable
wrapper
.
if_calculated
.
copy_
(
torch
.
tensor
(
0
))
# pylint: disable=not-callable
class
SlimPruner
(
Pruner
):
class
SlimPruner
(
Pruner
):
"""
"""
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment