Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
fb5ef932
Unverified
Commit
fb5ef932
authored
Feb 14, 2020
by
Cjkkkk
Committed by
GitHub
Feb 14, 2020
Browse files
fix model compression config validation (#2033)
parent
50e425f2
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
19 additions
and
16 deletions
+19
-16
src/sdk/pynni/nni/compression/torch/compressor.py
src/sdk/pynni/nni/compression/torch/compressor.py
+19
-16
No files found.
src/sdk/pynni/nni/compression/torch/compressor.py
View file @
fb5ef932
...
@@ -149,13 +149,24 @@ class Compressor:
...
@@ -149,13 +149,24 @@ class Compressor:
ret
=
None
ret
=
None
for
config
in
self
.
config_list
:
for
config
in
self
.
config_list
:
config
=
config
.
copy
()
config
=
config
.
copy
()
config
[
'op_types'
]
=
self
.
_expand_config_op_types
(
config
)
# expand config if key `default` is in config['op_types']
if
layer
.
type
not
in
config
[
'op_types'
]:
if
'op_types'
in
config
and
'default'
in
config
[
'op_types'
]:
expanded_op_types
=
[]
for
op_type
in
config
[
'op_types'
]:
if
op_type
==
'default'
:
expanded_op_types
.
extend
(
default_layers
.
weighted_modules
)
else
:
expanded_op_types
.
append
(
op_type
)
config
[
'op_types'
]
=
expanded_op_types
# check if condition is satisified
if
'op_types'
in
config
and
layer
.
type
not
in
config
[
'op_types'
]:
continue
continue
if
config
.
get
(
'op_names'
)
and
layer
.
name
not
in
config
[
'op_names'
]:
if
'op_names'
in
config
and
layer
.
name
not
in
config
[
'op_names'
]:
continue
continue
ret
=
config
ret
=
config
if
ret
is
None
or
ret
.
get
(
'exclude'
)
:
if
ret
is
None
or
'exclude'
in
ret
:
return
None
return
None
return
ret
return
ret
...
@@ -188,16 +199,6 @@ class Compressor:
...
@@ -188,16 +199,6 @@ class Compressor:
"""
"""
raise
NotImplementedError
()
raise
NotImplementedError
()
def
_expand_config_op_types
(
self
,
config
):
if
config
is
None
:
return
[]
expanded_op_types
=
[]
for
op_type
in
config
.
get
(
'op_types'
,
[]):
if
op_type
==
'default'
:
expanded_op_types
.
extend
(
default_layers
.
weighted_modules
)
else
:
expanded_op_types
.
append
(
op_type
)
return
expanded_op_types
class
PrunerModuleWrapper
(
torch
.
nn
.
Module
):
class
PrunerModuleWrapper
(
torch
.
nn
.
Module
):
def
__init__
(
self
,
module
,
module_name
,
module_type
,
config
,
pruner
):
def
__init__
(
self
,
module
,
module_name
,
module_type
,
config
,
pruner
):
...
@@ -229,11 +230,12 @@ class PrunerModuleWrapper(torch.nn.Module):
...
@@ -229,11 +230,12 @@ class PrunerModuleWrapper(torch.nn.Module):
# register buffer for mask
# register buffer for mask
self
.
register_buffer
(
"weight_mask"
,
torch
.
ones
(
self
.
module
.
weight
.
shape
))
self
.
register_buffer
(
"weight_mask"
,
torch
.
ones
(
self
.
module
.
weight
.
shape
))
self
.
registered_buffers
[
'weight_mask'
]
=
self
.
weight_mask
if
hasattr
(
self
.
module
,
'bias'
)
and
self
.
module
.
bias
is
not
None
:
if
hasattr
(
self
.
module
,
'bias'
)
and
self
.
module
.
bias
is
not
None
:
self
.
register_buffer
(
"bias_mask"
,
torch
.
ones
(
self
.
module
.
bias
.
shape
))
self
.
register_buffer
(
"bias_mask"
,
torch
.
ones
(
self
.
module
.
bias
.
shape
))
else
:
else
:
self
.
register_buffer
(
"bias_mask"
,
None
)
self
.
register_buffer
(
"bias_mask"
,
None
)
self
.
registered_buffers
[
'weight_mask'
]
=
self
.
weight_mask
self
.
registered_buffers
[
'bias_mask'
]
=
self
.
bias_mask
self
.
registered_buffers
[
'bias_mask'
]
=
self
.
bias_mask
# register user specified buffer
# register user specified buffer
for
name
in
self
.
pruner
.
buffers
:
for
name
in
self
.
pruner
.
buffers
:
...
@@ -297,7 +299,8 @@ class Pruner(Compressor):
...
@@ -297,7 +299,8 @@ class Pruner(Compressor):
"""
"""
_logger
.
info
(
"compressing module %s."
,
layer
.
name
)
_logger
.
info
(
"compressing module %s."
,
layer
.
name
)
wrapper
=
PrunerModuleWrapper
(
layer
.
module
,
layer
.
name
,
layer
.
type
,
config
,
self
)
wrapper
=
PrunerModuleWrapper
(
layer
.
module
,
layer
.
name
,
layer
.
type
,
config
,
self
)
assert
hasattr
(
layer
.
module
,
'weight'
)
assert
hasattr
(
layer
.
module
,
'weight'
),
"module %s does not have 'weight' attribute"
%
layer
.
name
# move newly registered buffers to the same device of weight
wrapper
.
to
(
layer
.
module
.
weight
.
device
)
wrapper
.
to
(
layer
.
module
.
weight
.
device
)
return
wrapper
return
wrapper
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment