Unverified commit fb5ef932, authored Feb 14, 2020 by Cjkkkk, committed by GitHub on Feb 14, 2020

fix model compression config validation (#2033)

parent: 50e425f2
Changes: 1 changed file with 19 additions and 16 deletions

src/sdk/pynni/nni/compression/torch/compressor.py (+19, -16)
@@ -149,13 +149,24 @@ class Compressor:
         ret = None
         for config in self.config_list:
             config = config.copy()
-            config['op_types'] = self._expand_config_op_types(config)
-            if layer.type not in config['op_types']:
+            # expand config if key `default` is in config['op_types']
+            if 'op_types' in config and 'default' in config['op_types']:
+                expanded_op_types = []
+                for op_type in config['op_types']:
+                    if op_type == 'default':
+                        expanded_op_types.extend(default_layers.weighted_modules)
+                    else:
+                        expanded_op_types.append(op_type)
+                config['op_types'] = expanded_op_types
+
+            # check if condition is satisfied
+            if 'op_types' in config and layer.type not in config['op_types']:
                 continue
-            if config.get('op_names') and layer.name not in config['op_names']:
+            if 'op_names' in config and layer.name not in config['op_names']:
                 continue
+
             ret = config

-        if ret is None or ret.get('exclude'):
+        if ret is None or 'exclude' in ret:
             return None
         return ret
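The patched selection logic treats a missing 'op_types' or 'op_names' key as "no constraint", checks key presence with `in` rather than `.get()`, and expands the 'default' sentinel inline. Below is a minimal, self-contained sketch of the same control flow; the LayerInfo tuple, the abbreviated WEIGHTED_MODULES list, and the sample config values are hypothetical stand-ins, not taken from the commit.

from collections import namedtuple

LayerInfo = namedtuple('LayerInfo', ['name', 'type'])
# stand-in for default_layers.weighted_modules (abbreviated, illustrative)
WEIGHTED_MODULES = ['Conv1d', 'Conv2d', 'Conv3d', 'Linear']

def select_config(layer, config_list):
    # same control flow as the patched Compressor method
    ret = None
    for config in config_list:
        config = config.copy()
        # expand the 'default' sentinel into concrete module types
        if 'op_types' in config and 'default' in config['op_types']:
            expanded_op_types = []
            for op_type in config['op_types']:
                if op_type == 'default':
                    expanded_op_types.extend(WEIGHTED_MODULES)
                else:
                    expanded_op_types.append(op_type)
            config['op_types'] = expanded_op_types
        # a missing key now means "no constraint" on that dimension
        if 'op_types' in config and layer.type not in config['op_types']:
            continue
        if 'op_names' in config and layer.name not in config['op_names']:
            continue
        ret = config
    if ret is None or 'exclude' in ret:
        return None
    return ret

config_list = [{'sparsity': 0.5, 'op_types': ['default']},
               {'exclude': True, 'op_names': ['fc2']}]
print(select_config(LayerInfo('conv1', 'Conv2d'), config_list))  # matched, expanded
print(select_config(LayerInfo('fc2', 'Linear'), config_list))    # None (excluded)

Note that later configs override earlier ones (ret is overwritten on each match), which is how the trailing exclude entry wins for fc2.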
@@ -188,16 +199,6 @@ class Compressor:
         """
         raise NotImplementedError()

-    def _expand_config_op_types(self, config):
-        if config is None:
-            return []
-        expanded_op_types = []
-        for op_type in config.get('op_types', []):
-            if op_type == 'default':
-                expanded_op_types.extend(default_layers.weighted_modules)
-            else:
-                expanded_op_types.append(op_type)
-        return expanded_op_types

 class PrunerModuleWrapper(torch.nn.Module):
     def __init__(self, module, module_name, module_type, config, pruner):
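Why the helper could be dropped: its only caller overwrote config['op_types'] unconditionally, so a config that specified just op_names got op_types == [] back and then failed the `layer.type not in config['op_types']` test for every layer. A short sketch of that old failure mode, with a hypothetical config value:

old_config = {'sparsity': 0.6, 'op_names': ['fc1']}   # no 'op_types' key
op_types = old_config.get('op_types', [])             # the helper expanded this to []
print('Linear' in op_types)   # False -- under the old check, every layer was skipped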
@@ -229,11 +230,12 @@ class PrunerModuleWrapper(torch.nn.Module):
         # register buffer for mask
         self.register_buffer("weight_mask", torch.ones(self.module.weight.shape))
-        self.registered_buffers['weight_mask'] = self.weight_mask
         if hasattr(self.module, 'bias') and self.module.bias is not None:
             self.register_buffer("bias_mask", torch.ones(self.module.bias.shape))
         else:
             self.register_buffer("bias_mask", None)
+
+        self.registered_buffers['weight_mask'] = self.weight_mask
         self.registered_buffers['bias_mask'] = self.bias_mask

         # register user specified buffer
         for name in self.pruner.buffers:
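For context, register_buffer is what keeps the masks out of the parameter list while still letting them move with the module across devices. A minimal sketch of the same registration pattern; the MaskedLinear class and the Linear layer are illustrative, not part of the diff, and torch is assumed to be installed:

import torch

class MaskedLinear(torch.nn.Module):
    def __init__(self, module):
        super().__init__()
        self.module = module
        # all-ones masks, same shape as the wrapped parameters
        self.register_buffer("weight_mask", torch.ones(self.module.weight.shape))
        if hasattr(self.module, 'bias') and self.module.bias is not None:
            self.register_buffer("bias_mask", torch.ones(self.module.bias.shape))
        else:
            self.register_buffer("bias_mask", None)   # PyTorch allows a None buffer

wrapper = MaskedLinear(torch.nn.Linear(4, 2))
print(wrapper.weight_mask.shape)   # torch.Size([2, 4]); buffers follow .to(device)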
@@ -297,7 +299,8 @@ class Pruner(Compressor):
         """
         _logger.info("compressing module %s.", layer.name)
         wrapper = PrunerModuleWrapper(layer.module, layer.name, layer.type, config, self)
-        assert hasattr(layer.module, 'weight')
+        assert hasattr(layer.module, 'weight'), "module %s does not have 'weight' attribute" % layer.name
+        # move newly registered buffers to the same device as the weight
         wrapper.to(layer.module.weight.device)
         return wrapper
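The two added lines are quality-of-life fixes: the assertion now reports which module lacks a weight, and the new comment records why the wrapper is moved to the weight's device (so the freshly registered mask buffers land on the GPU when the model lives there). A small sketch of the improved assertion message, with a hypothetical layer name and module:

import torch

layer_name, module = 'relu1', torch.nn.ReLU()   # a module without a 'weight'
try:
    assert hasattr(module, 'weight'), \
        "module %s does not have 'weight' attribute" % layer_name
except AssertionError as err:
    print(err)   # -> module relu1 does not have 'weight' attribute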