OpenDAS / nni · Commits · 43de0118

Unverified commit 43de0118, authored Feb 17, 2020 by QuanluZhang, committed via GitHub Feb 17, 2020

compression speedup: small code refactor (#2065)

parent e6cedb89
Changes: 2 changed files, with 24 additions and 26 deletions (+24 -26)

  src/sdk/pynni/nni/compression/speedup/torch/compress_modules.py   +7  -4
  src/sdk/pynni/nni/compression/speedup/torch/compressor.py         +17 -22
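The refactor is mechanical: ad-hoc print() calls become module-level logging with %-style arguments, which are only formatted when the level is enabled. A minimal, self-contained sketch of the pattern (the function name and values are illustrative, not from the diff):

    import logging

    _logger = logging.getLogger(__name__)  # one logger per module, as in compress_modules.py

    def replace_example(in_features):
        # %-style arguments are formatted lazily: the message is only built
        # when DEBUG is enabled, unlike print() or str.format().
        _logger.debug("replace linear with new in_features: %d", in_features)

    if __name__ == "__main__":
        logging.basicConfig(level=logging.DEBUG)  # opt in to the new debug output
        replace_example(128)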
src/sdk/pynni/nni/compression/speedup/torch/compress_modules.py (view file @ 43de0118)
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.
+import logging
 import torch
 from .infer_shape import ModuleMasks
+_logger = logging.getLogger(__name__)
 replace_module = {
     'BatchNorm2d': lambda module, mask: replace_batchnorm2d(module, mask),
     'Conv2d': lambda module, mask: replace_conv2d(module, mask),
@@ -16,6 +19,7 @@ def no_replace(module, mask):
     """
     No need to replace
     """
+    _logger.debug("no need to replace")
     return module
 
 def replace_linear(linear, mask):
@@ -37,9 +41,8 @@ def replace_linear(linear, mask):
     assert mask.output_mask is None
     assert not mask.param_masks
     index = mask.input_mask.mask_index[-1]
-    print(mask.input_mask.mask_index)
     in_features = index.size()[0]
-    print('linear: ', in_features)
+    _logger.debug("replace linear with new in_features: %d", in_features)
     new_linear = torch.nn.Linear(in_features=in_features,
                                  out_features=linear.out_features,
                                  bias=linear.bias is not None)
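A standalone sketch of what this hunk constructs, assuming the input mask keeps kept_in features (the helper name is hypothetical):

    import torch

    def make_pruned_linear(linear, kept_in):
        # Build a narrower Linear, preserving out_features and the presence
        # of a bias, exactly as replace_linear does above.
        return torch.nn.Linear(in_features=kept_in,
                               out_features=linear.out_features,
                               bias=linear.bias is not None)

    pruned = make_pruned_linear(torch.nn.Linear(256, 10), kept_in=128)
    assert pruned.in_features == 128 and pruned.out_features == 10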
@@ -67,7 +70,7 @@ def replace_batchnorm2d(norm, mask):
     assert 'weight' in mask.param_masks and 'bias' in mask.param_masks
     index = mask.param_masks['weight'].mask_index[0]
     num_features = index.size()[0]
-    print("replace batchnorm2d: ", num_features, index)
+    _logger.debug("replace batchnorm2d with num_features: %d", num_features)
     new_norm = torch.nn.BatchNorm2d(num_features=num_features,
                                     eps=norm.eps,
                                     momentum=norm.momentum,
@@ -106,6 +109,7 @@ def replace_conv2d(conv, mask):
     else:
         out_channels_index = mask.output_mask.mask_index[1]
         out_channels = out_channels_index.size()[0]
+    _logger.debug("replace conv2d with in_channels: %d, out_channels: %d", in_channels, out_channels)
     new_conv = torch.nn.Conv2d(in_channels=in_channels,
                                out_channels=out_channels,
                                kernel_size=conv.kernel_size,
@@ -128,6 +132,5 @@ def replace_conv2d(conv, mask):
     assert tmp_weight_data is not None, "Conv2d weight should be updated based on masks"
     new_conv.weight.data.copy_(tmp_weight_data)
     if conv.bias is not None:
-        print('final conv.bias is not None')
         new_conv.bias.data.copy_(conv.bias.data if tmp_bias_data is None else tmp_bias_data)
     return new_conv
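tmp_weight_data and tmp_bias_data are computed earlier in replace_conv2d from the masks (not shown in this hunk). A hypothetical illustration of how selecting kept output channels could produce such tensors:

    import torch

    conv = torch.nn.Conv2d(in_channels=3, out_channels=8, kernel_size=3)
    kept = torch.tensor([0, 2, 5])  # indices of surviving output channels (illustrative)

    # Conv2d weight has shape (out_channels, in_channels, kH, kW); select along dim 0.
    tmp_weight_data = conv.weight.data.index_select(0, kept)
    tmp_bias_data = conv.bias.data.index_select(0, kept)

    new_conv = torch.nn.Conv2d(in_channels=3, out_channels=len(kept), kernel_size=3)
    new_conv.weight.data.copy_(tmp_weight_data)
    if conv.bias is not None:
        new_conv.bias.data.copy_(conv.bias.data if tmp_bias_data is None else tmp_bias_data)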
src/sdk/pynni/nni/compression/speedup/torch/compressor.py (view file @ 43de0118)
@@ -158,7 +158,7 @@ class ModelSpeedup:
         """
         # TODO: scope name could be empty
         node_name = '.'.join([node.scopeName(), node.kind(), str(self.global_count)])
-        #print('node_name: ', node_name)
+        _logger.debug("expand non-prim node, node name: %s", node_name)
         self.global_count += 1
         op_type = node.kind()
@@ -173,7 +173,6 @@ class ModelSpeedup:
             input_name = _input.debugName()
             if input_name in output_to_node and output_to_node[input_name] in nodes:
                 predecessor_node = output_to_node[input_name]
-                #print("predecessor_node: ", predecessor_node)
                 if predecessor_node.kind().startswith('prim::'):
                     node_group.append(predecessor_node)
                     node_queue.put(predecessor_node)
@@ -231,7 +230,7 @@ class ModelSpeedup:
         """
         graph = self.trace_graph.graph
         # if torch 1.4.0 is used, consider run torch._C._jit_pass_inline(graph) here
-        #print(graph)
+        #_logger.debug(graph)
         # build output mapping, from output debugName to its node
         output_to_node = dict()
         # build input mapping, from input debugName to its node
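The comment in this hunk points at a TorchScript change: from torch 1.4.0, traced graphs keep submodule calls as prim::CallMethod nodes unless they are inlined. A minimal sketch of the suggested pass, assuming torch >= 1.4 (note that _jit_pass_inline is a private API):

    import torch

    model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.ReLU())
    traced = torch.jit.trace(model, torch.randn(1, 3, 32, 32))
    graph = traced.graph
    torch._C._jit_pass_inline(graph)  # inline prim::CallMethod nodes into the graph
    print(graph)  # aten ops of the submodules now appear directly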
@@ -301,10 +300,8 @@ class ModelSpeedup:
                     m_inputs.append(_input)
                 elif not output_to_node[_input] in nodes:
                     m_inputs.append(_input)
-            print("module node_name: ", module_name)
             if module_name == '':
-                for n in nodes:
-                    print(n)
+                _logger.warning("module_name is empty string")
             g_node = GNode(module_name, 'module', module_to_type[module_name], m_inputs, m_outputs, nodes)
             self.g_nodes.append(g_node)
@@ -345,10 +342,7 @@ class ModelSpeedup:
         predecessors = []
         for _input in self.name_to_gnode[module_name].inputs:
             if not _input in self.output_to_gnode:
-                print(_input)
-            if not _input in self.output_to_gnode:
-                # TODO: check _input which does not have node
-                print("output with no gnode: ", _input)
+                _logger.debug("cannot find gnode with %s as its output", _input)
             else:
                 g_node = self.output_to_gnode[_input]
                 predecessors.append(g_node.name)
@@ -407,15 +401,15 @@ class ModelSpeedup:
         self.inferred_masks[module_name] = module_masks
 
         m_type = self.name_to_gnode[module_name].op_type
-        print("infer_module_mask: {}, module type: {}".format(module_name, m_type))
+        _logger.debug("infer mask of module %s with op_type %s", module_name, m_type)
         if mask is not None:
-            #print("mask is not None")
+            _logger.debug("mask is not None")
             if not m_type in infer_from_mask:
                 raise RuntimeError("Has not supported infering \
                     input/output shape from mask for module/function: `{}`".format(m_type))
             input_cmask, output_cmask = infer_from_mask[m_type](module_masks, mask)
         if in_shape is not None:
-            #print("in_shape is not None")
+            _logger.debug("in_shape is not None")
             if not m_type in infer_from_inshape:
                 raise RuntimeError("Has not supported infering \
                     output shape from input shape for module/function: `{}`".format(m_type))
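infer_from_mask, infer_from_inshape, and infer_from_outshape are dispatch tables keyed by op_type. A stripped-down sketch of the pattern with a hypothetical handler (a real handler updates module_masks and returns the mask to propagate):

    def _linear_inshape(module_masks, in_shape):
        # Placeholder handler: echo the incoming mask/shape unchanged.
        return in_shape

    infer_from_inshape = {'Linear': _linear_inshape}

    def infer(m_type, module_masks, in_shape):
        if not m_type in infer_from_inshape:
            raise RuntimeError("Has not supported infering \
                output shape from input shape for module/function: `{}`".format(m_type))
        return infer_from_inshape[m_type](module_masks, in_shape)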
@@ -426,23 +420,19 @@ class ModelSpeedup:
         else:
             output_cmask = infer_from_inshape[m_type](module_masks, in_shape)
         if out_shape is not None:
-            #print("out_shape is not None")
+            _logger.debug("out_shape is not None")
             if not m_type in infer_from_outshape:
                 raise RuntimeError("Has not supported infering \
                     input shape from output shape for module/function: `{}`".format(m_type))
             input_cmask = infer_from_outshape[m_type](module_masks, out_shape)
 
         if input_cmask:
-            #print("input_cmask is not None")
             predecessors = self._find_predecessors(module_name)
             for _module_name in predecessors:
-                print("input_cmask, module_name: ", _module_name)
                 self.infer_module_mask(_module_name, out_shape=input_cmask)
         if output_cmask:
-            #print("output_cmask is not None")
             successors = self._find_successors(module_name)
             for _module_name in successors:
-                print("output_cmask, module_name: ", _module_name)
                 self.infer_module_mask(_module_name, in_shape=output_cmask)
 
     def infer_modules_masks(self):
@@ -463,16 +453,19 @@ class ModelSpeedup:
         """
         for module_name in self.inferred_masks:
             g_node = self.name_to_gnode[module_name]
-            print(module_name, g_node.op_type)
+            _logger.debug("replace %s, in %s type, with op_type %s", module_name, g_node.type, g_node.op_type)
             if g_node.type == 'module':
                 super_module, leaf_module = get_module_by_name(self.bound_model, module_name)
                 m_type = g_node.op_type
                 if not m_type in replace_module:
                     raise RuntimeError("Has not supported replacing the module: `{}`".format(m_type))
+                _logger.info("replace module (name: %s, op_type: %s)", module_name, m_type)
                 compressed_module = replace_module[m_type](leaf_module, self.inferred_masks[module_name])
                 setattr(super_module, module_name.split('.')[-1], compressed_module)
             elif g_node.type == 'func':
-                print("Warning: Cannot replace func...")
+                _logger.info("Warning: cannot replace (name: %s, op_type: %s) which is func type", module_name, g_node.op_type)
             else:
                 raise RuntimeError("Unsupported GNode type: {}".format(g_node.type))
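get_module_by_name returns both the parent module and the leaf so that setattr can swap the compressed module in place. A hypothetical re-implementation showing the mechanics (the real helper lives elsewhere in nni):

    import torch

    def get_module_by_name(model, module_name):
        # Walk the dotted path down to the parent, then return (parent, leaf).
        parts = module_name.split('.')
        super_module = model
        for name in parts[:-1]:
            super_module = getattr(super_module, name)
        return super_module, getattr(super_module, parts[-1])

    model = torch.nn.Sequential(torch.nn.Linear(8, 4))
    super_module, leaf_module = get_module_by_name(model, '0')
    setattr(super_module, '0', torch.nn.Linear(8, 2))  # swap in the compressed module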
@@ -482,10 +475,12 @@ class ModelSpeedup:
         first, do mask/shape inference,
         second, replace modules
         """
-        #print("start to compress")
+        _logger.info("start to speed up the model")
+        _logger.info("infer module masks...")
         self.infer_modules_masks()
+        _logger.info("replace compressed modules...")
         self.replace_compressed_modules()
-        #print("finished compressing")
+        _logger.info("speedup done")
         # resume the model mode to that before the model is speed up
         if self.is_training:
             self.bound_model.train()
...
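After the refactor, callers observe progress through the nni loggers rather than unconditional prints. A hypothetical usage sketch; the constructor arguments and entry-point name are assumptions, as this diff only shows the method body:

    import logging

    logging.basicConfig(level=logging.INFO)  # surfaces the INFO messages added above

    # from nni.compression.speedup.torch import ModelSpeedup
    # m_speedup = ModelSpeedup(model, dummy_input, masks_file)
    # m_speedup.speedup_model()
    #
    # Expected log lines, per the last hunk:
    #   INFO: start to speed up the model
    #   INFO: infer module masks...
    #   INFO: replace compressed modules...
    #   INFO: speedup done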