Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
b7045b19
"vscode:/vscode.git/clone" did not exist on "c66b74720e355a4252e58cd586269fabe376cf89"
Unverified
Commit
b7045b19
authored
Feb 16, 2020
by
Cjkkkk
Committed by
GitHub
Feb 16, 2020
Browse files
fix buffer transfer bug (#2045)
parent
b8c0fb6e
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
27 additions
and
16 deletions
+27
-16
examples/model_compress/main_torch_pruner.py
examples/model_compress/main_torch_pruner.py
+4
-4
src/sdk/pynni/nni/compression/torch/compressor.py
src/sdk/pynni/nni/compression/torch/compressor.py
+22
-11
src/sdk/pynni/nni/compression/torch/pruners.py
src/sdk/pynni/nni/compression/torch/pruners.py
+1
-1
No files found.
examples/model_compress/main_torch_pruner.py
View file @
b7045b19
...
@@ -55,7 +55,7 @@ def test(model, device, test_loader):
...
@@ -55,7 +55,7 @@ def test(model, device, test_loader):
def
main
():
def
main
():
torch
.
manual_seed
(
0
)
torch
.
manual_seed
(
0
)
device
=
torch
.
device
(
'c
p
u'
)
device
=
torch
.
device
(
'cu
da
'
)
trans
=
transforms
.
Compose
([
transforms
.
ToTensor
(),
transforms
.
Normalize
((
0.1307
,),
(
0.3081
,))])
trans
=
transforms
.
Compose
([
transforms
.
ToTensor
(),
transforms
.
Normalize
((
0.1307
,),
(
0.3081
,))])
train_loader
=
torch
.
utils
.
data
.
DataLoader
(
train_loader
=
torch
.
utils
.
data
.
DataLoader
(
...
@@ -66,7 +66,7 @@ def main():
...
@@ -66,7 +66,7 @@ def main():
batch_size
=
1000
,
shuffle
=
True
)
batch_size
=
1000
,
shuffle
=
True
)
model
=
Mnist
()
model
=
Mnist
()
model
.
to
(
device
)
model
=
model
.
to
(
device
)
'''you can change this to LevelPruner to implement it
'''you can change this to LevelPruner to implement it
pruner = LevelPruner(configure_list)
pruner = LevelPruner(configure_list)
...
@@ -82,14 +82,14 @@ def main():
...
@@ -82,14 +82,14 @@ def main():
pruner
=
AGP_Pruner
(
model
,
configure_list
)
pruner
=
AGP_Pruner
(
model
,
configure_list
)
model
=
pruner
.
compress
()
model
=
pruner
.
compress
()
model
=
model
.
to
(
device
)
optimizer
=
torch
.
optim
.
SGD
(
model
.
parameters
(),
lr
=
0.01
,
momentum
=
0.5
)
optimizer
=
torch
.
optim
.
SGD
(
model
.
parameters
(),
lr
=
0.01
,
momentum
=
0.5
)
for
epoch
in
range
(
10
):
for
epoch
in
range
(
10
):
pruner
.
update_epoch
(
epoch
)
pruner
.
update_epoch
(
epoch
)
print
(
'# Epoch {} #'
.
format
(
epoch
))
print
(
'# Epoch {} #'
.
format
(
epoch
))
train
(
model
,
device
,
train_loader
,
optimizer
)
train
(
model
,
device
,
train_loader
,
optimizer
)
test
(
model
,
device
,
test_loader
)
test
(
model
,
device
,
test_loader
)
pruner
.
export_model
(
'model.pth'
,
'mask.pth'
,
'model.onnx'
,
[
1
,
1
,
28
,
28
])
pruner
.
export_model
(
'model.pth'
,
'mask.pth'
,
'model.onnx'
,
[
1
,
1
,
28
,
28
]
,
device
)
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
...
...
src/sdk/pynni/nni/compression/torch/compressor.py
View file @
b7045b19
...
@@ -226,7 +226,7 @@ class PrunerModuleWrapper(torch.nn.Module):
...
@@ -226,7 +226,7 @@ class PrunerModuleWrapper(torch.nn.Module):
# config and pruner
# config and pruner
self
.
config
=
config
self
.
config
=
config
self
.
pruner
=
pruner
self
.
pruner
=
pruner
self
.
registered_buffers
=
{}
self
.
registered_buffers
=
[]
# register buffer for mask
# register buffer for mask
self
.
register_buffer
(
"weight_mask"
,
torch
.
ones
(
self
.
module
.
weight
.
shape
))
self
.
register_buffer
(
"weight_mask"
,
torch
.
ones
(
self
.
module
.
weight
.
shape
))
...
@@ -234,16 +234,21 @@ class PrunerModuleWrapper(torch.nn.Module):
...
@@ -234,16 +234,21 @@ class PrunerModuleWrapper(torch.nn.Module):
self
.
register_buffer
(
"bias_mask"
,
torch
.
ones
(
self
.
module
.
bias
.
shape
))
self
.
register_buffer
(
"bias_mask"
,
torch
.
ones
(
self
.
module
.
bias
.
shape
))
else
:
else
:
self
.
register_buffer
(
"bias_mask"
,
None
)
self
.
register_buffer
(
"bias_mask"
,
None
)
self
.
registered_buffers
.
append
(
'weight_mask'
)
self
.
registered_buffers
[
'weight_mask'
]
=
self
.
weight_mask
self
.
registered_buffers
.
append
(
'bias_mask'
)
self
.
registered_buffers
[
'bias_mask'
]
=
self
.
bias_mask
# register user specified buffer
# register user specified buffer
for
name
in
self
.
pruner
.
buffers
:
for
name
in
self
.
pruner
.
buffers
:
self
.
register_buffer
(
name
,
self
.
pruner
.
buffers
[
name
].
clone
())
self
.
register_buffer
(
name
,
self
.
pruner
.
buffers
[
name
].
clone
())
self
.
registered_buffers
[
name
]
=
getattr
(
self
,
name
)
self
.
registered_buffers
.
append
(
name
)
def
get_registered_buffers
(
self
):
buffers
=
{}
for
name
in
self
.
registered_buffers
:
buffers
[
name
]
=
getattr
(
self
,
name
)
return
buffers
def
forward
(
self
,
*
inputs
):
def
forward
(
self
,
*
inputs
):
mask
=
self
.
pruner
.
calc_mask
(
LayerInfo
(
self
.
name
,
self
.
module
),
self
.
config
,
**
self
.
registered_buffers
)
mask
=
self
.
pruner
.
calc_mask
(
LayerInfo
(
self
.
name
,
self
.
module
),
self
.
config
,
**
self
.
get_
registered_buffers
()
)
if
mask
is
not
None
:
if
mask
is
not
None
:
self
.
weight_mask
.
copy_
(
mask
[
'weight'
])
self
.
weight_mask
.
copy_
(
mask
[
'weight'
])
# apply mask to weight
# apply mask to weight
...
@@ -399,6 +404,7 @@ class QuantizerModuleWrapper(torch.nn.Module):
...
@@ -399,6 +404,7 @@ class QuantizerModuleWrapper(torch.nn.Module):
# config and pruner
# config and pruner
self
.
config
=
config
self
.
config
=
config
self
.
quantizer
=
quantizer
self
.
quantizer
=
quantizer
self
.
registered_buffers
=
[]
# register buffer and parameter
# register buffer and parameter
# old_weight is used to store origin weight and weight is used to store quantized weight
# old_weight is used to store origin weight and weight is used to store quantized weight
...
@@ -413,10 +419,15 @@ class QuantizerModuleWrapper(torch.nn.Module):
...
@@ -413,10 +419,15 @@ class QuantizerModuleWrapper(torch.nn.Module):
self
.
module
.
register_buffer
(
'weight'
,
self
.
module
.
old_weight
)
self
.
module
.
register_buffer
(
'weight'
,
self
.
module
.
old_weight
)
# register user specified buffer
# register user specified buffer
self
.
registered_buffers
=
{}
for
name
in
self
.
quantizer
.
buffers
:
for
name
in
self
.
quantizer
.
buffers
:
self
.
register_buffer
(
name
,
self
.
quantizer
.
buffers
[
name
].
clone
())
self
.
register_buffer
(
name
,
self
.
quantizer
.
buffers
[
name
].
clone
())
self
.
registered_buffers
[
name
]
=
getattr
(
self
,
name
)
self
.
registered_buffers
.
append
(
name
)
def
get_registered_buffers
(
self
):
buffers
=
{}
for
name
in
self
.
registered_buffers
:
buffers
[
name
]
=
getattr
(
self
,
name
)
return
buffers
def
forward
(
self
,
*
inputs
):
def
forward
(
self
,
*
inputs
):
if
'input'
in
self
.
config
[
'quant_types'
]:
if
'input'
in
self
.
config
[
'quant_types'
]:
...
@@ -426,7 +437,7 @@ class QuantizerModuleWrapper(torch.nn.Module):
...
@@ -426,7 +437,7 @@ class QuantizerModuleWrapper(torch.nn.Module):
self
.
quantizer
.
quantize_input
,
self
.
quantizer
.
quantize_input
,
self
.
config
,
self
.
config
,
LayerInfo
(
self
.
name
,
self
.
module
),
LayerInfo
(
self
.
name
,
self
.
module
),
**
self
.
registered_buffers
)
**
self
.
get_
registered_buffers
()
)
if
'weight'
in
self
.
config
[
'quant_types'
]
and
_check_weight
(
self
.
module
):
if
'weight'
in
self
.
config
[
'quant_types'
]
and
_check_weight
(
self
.
module
):
new_weight
=
self
.
quantizer
.
quant_grad
.
apply
(
new_weight
=
self
.
quantizer
.
quant_grad
.
apply
(
...
@@ -435,7 +446,7 @@ class QuantizerModuleWrapper(torch.nn.Module):
...
@@ -435,7 +446,7 @@ class QuantizerModuleWrapper(torch.nn.Module):
self
.
quantizer
.
quantize_weight
,
self
.
quantizer
.
quantize_weight
,
self
.
config
,
self
.
config
,
LayerInfo
(
self
.
name
,
self
.
module
),
LayerInfo
(
self
.
name
,
self
.
module
),
**
self
.
registered_buffers
)
**
self
.
get_
registered_buffers
()
)
self
.
module
.
weight
=
new_weight
self
.
module
.
weight
=
new_weight
result
=
self
.
module
(
*
inputs
)
result
=
self
.
module
(
*
inputs
)
else
:
else
:
...
@@ -448,7 +459,7 @@ class QuantizerModuleWrapper(torch.nn.Module):
...
@@ -448,7 +459,7 @@ class QuantizerModuleWrapper(torch.nn.Module):
self
.
quantizer
.
quantize_output
,
self
.
quantizer
.
quantize_output
,
self
.
config
,
self
.
config
,
LayerInfo
(
self
.
name
,
self
.
module
),
LayerInfo
(
self
.
name
,
self
.
module
),
**
self
.
registered_buffers
)
**
self
.
get_
registered_buffers
()
)
return
result
return
result
class
Quantizer
(
Compressor
):
class
Quantizer
(
Compressor
):
...
...
src/sdk/pynni/nni/compression/torch/pruners.py
View file @
b7045b19
...
@@ -170,7 +170,7 @@ class AGP_Pruner(Pruner):
...
@@ -170,7 +170,7 @@ class AGP_Pruner(Pruner):
if
epoch
>
0
:
if
epoch
>
0
:
self
.
now_epoch
=
epoch
self
.
now_epoch
=
epoch
for
wrapper
in
self
.
get_modules_wrapper
():
for
wrapper
in
self
.
get_modules_wrapper
():
wrapper
.
registered_buffers
[
'
if_calculated
'
]
.
copy_
(
torch
.
tensor
(
0
))
# pylint: disable=not-callable
wrapper
.
if_calculated
.
copy_
(
torch
.
tensor
(
0
))
# pylint: disable=not-callable
class
SlimPruner
(
Pruner
):
class
SlimPruner
(
Pruner
):
"""
"""
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment