Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
43de0118
Unverified
Commit
43de0118
authored
Feb 17, 2020
by
QuanluZhang
Committed by
GitHub
Feb 17, 2020
Browse files
compression speedup: small code refactor (#2065)
parent
e6cedb89
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
24 additions
and
26 deletions
+24
-26
src/sdk/pynni/nni/compression/speedup/torch/compress_modules.py
...k/pynni/nni/compression/speedup/torch/compress_modules.py
+7
-4
src/sdk/pynni/nni/compression/speedup/torch/compressor.py
src/sdk/pynni/nni/compression/speedup/torch/compressor.py
+17
-22
No files found.
src/sdk/pynni/nni/compression/speedup/torch/compress_modules.py
View file @
43de0118
# Copyright (c) Microsoft Corporation.
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# Licensed under the MIT license.
import
logging
import
torch
import
torch
from
.infer_shape
import
ModuleMasks
from
.infer_shape
import
ModuleMasks
_logger
=
logging
.
getLogger
(
__name__
)
replace_module
=
{
replace_module
=
{
'BatchNorm2d'
:
lambda
module
,
mask
:
replace_batchnorm2d
(
module
,
mask
),
'BatchNorm2d'
:
lambda
module
,
mask
:
replace_batchnorm2d
(
module
,
mask
),
'Conv2d'
:
lambda
module
,
mask
:
replace_conv2d
(
module
,
mask
),
'Conv2d'
:
lambda
module
,
mask
:
replace_conv2d
(
module
,
mask
),
...
@@ -16,6 +19,7 @@ def no_replace(module, mask):
...
@@ -16,6 +19,7 @@ def no_replace(module, mask):
"""
"""
No need to replace
No need to replace
"""
"""
_logger
.
debug
(
"no need to replace"
)
return
module
return
module
def
replace_linear
(
linear
,
mask
):
def
replace_linear
(
linear
,
mask
):
...
@@ -37,9 +41,8 @@ def replace_linear(linear, mask):
...
@@ -37,9 +41,8 @@ def replace_linear(linear, mask):
assert
mask
.
output_mask
is
None
assert
mask
.
output_mask
is
None
assert
not
mask
.
param_masks
assert
not
mask
.
param_masks
index
=
mask
.
input_mask
.
mask_index
[
-
1
]
index
=
mask
.
input_mask
.
mask_index
[
-
1
]
print
(
mask
.
input_mask
.
mask_index
)
in_features
=
index
.
size
()[
0
]
in_features
=
index
.
size
()[
0
]
print
(
'linear: '
,
in_features
)
_logger
.
debug
(
"replace linear with new in_features: %d"
,
in_features
)
new_linear
=
torch
.
nn
.
Linear
(
in_features
=
in_features
,
new_linear
=
torch
.
nn
.
Linear
(
in_features
=
in_features
,
out_features
=
linear
.
out_features
,
out_features
=
linear
.
out_features
,
bias
=
linear
.
bias
is
not
None
)
bias
=
linear
.
bias
is
not
None
)
...
@@ -67,7 +70,7 @@ def replace_batchnorm2d(norm, mask):
...
@@ -67,7 +70,7 @@ def replace_batchnorm2d(norm, mask):
assert
'weight'
in
mask
.
param_masks
and
'bias'
in
mask
.
param_masks
assert
'weight'
in
mask
.
param_masks
and
'bias'
in
mask
.
param_masks
index
=
mask
.
param_masks
[
'weight'
].
mask_index
[
0
]
index
=
mask
.
param_masks
[
'weight'
].
mask_index
[
0
]
num_features
=
index
.
size
()[
0
]
num_features
=
index
.
size
()[
0
]
print
(
"replace batchnorm2d
: "
,
num_features
,
index
)
_logger
.
debug
(
"replace batchnorm2d
with
num_features
: %d"
,
num_features
)
new_norm
=
torch
.
nn
.
BatchNorm2d
(
num_features
=
num_features
,
new_norm
=
torch
.
nn
.
BatchNorm2d
(
num_features
=
num_features
,
eps
=
norm
.
eps
,
eps
=
norm
.
eps
,
momentum
=
norm
.
momentum
,
momentum
=
norm
.
momentum
,
...
@@ -106,6 +109,7 @@ def replace_conv2d(conv, mask):
...
@@ -106,6 +109,7 @@ def replace_conv2d(conv, mask):
else
:
else
:
out_channels_index
=
mask
.
output_mask
.
mask_index
[
1
]
out_channels_index
=
mask
.
output_mask
.
mask_index
[
1
]
out_channels
=
out_channels_index
.
size
()[
0
]
out_channels
=
out_channels_index
.
size
()[
0
]
_logger
.
debug
(
"replace conv2d with in_channels: %d, out_channels: %d"
,
in_channels
,
out_channels
)
new_conv
=
torch
.
nn
.
Conv2d
(
in_channels
=
in_channels
,
new_conv
=
torch
.
nn
.
Conv2d
(
in_channels
=
in_channels
,
out_channels
=
out_channels
,
out_channels
=
out_channels
,
kernel_size
=
conv
.
kernel_size
,
kernel_size
=
conv
.
kernel_size
,
...
@@ -128,6 +132,5 @@ def replace_conv2d(conv, mask):
...
@@ -128,6 +132,5 @@ def replace_conv2d(conv, mask):
assert
tmp_weight_data
is
not
None
,
"Conv2d weight should be updated based on masks"
assert
tmp_weight_data
is
not
None
,
"Conv2d weight should be updated based on masks"
new_conv
.
weight
.
data
.
copy_
(
tmp_weight_data
)
new_conv
.
weight
.
data
.
copy_
(
tmp_weight_data
)
if
conv
.
bias
is
not
None
:
if
conv
.
bias
is
not
None
:
print
(
'final conv.bias is not None'
)
new_conv
.
bias
.
data
.
copy_
(
conv
.
bias
.
data
if
tmp_bias_data
is
None
else
tmp_bias_data
)
new_conv
.
bias
.
data
.
copy_
(
conv
.
bias
.
data
if
tmp_bias_data
is
None
else
tmp_bias_data
)
return
new_conv
return
new_conv
src/sdk/pynni/nni/compression/speedup/torch/compressor.py
View file @
43de0118
...
@@ -158,7 +158,7 @@ class ModelSpeedup:
...
@@ -158,7 +158,7 @@ class ModelSpeedup:
"""
"""
# TODO: scope name could be empty
# TODO: scope name could be empty
node_name
=
'.'
.
join
([
node
.
scopeName
(),
node
.
kind
(),
str
(
self
.
global_count
)])
node_name
=
'.'
.
join
([
node
.
scopeName
(),
node
.
kind
(),
str
(
self
.
global_count
)])
#print('
node
_
name:
'
, node_name)
_logger
.
debug
(
"expand non-prim node,
node
name:
%s"
,
node_name
)
self
.
global_count
+=
1
self
.
global_count
+=
1
op_type
=
node
.
kind
()
op_type
=
node
.
kind
()
...
@@ -173,7 +173,6 @@ class ModelSpeedup:
...
@@ -173,7 +173,6 @@ class ModelSpeedup:
input_name
=
_input
.
debugName
()
input_name
=
_input
.
debugName
()
if
input_name
in
output_to_node
and
output_to_node
[
input_name
]
in
nodes
:
if
input_name
in
output_to_node
and
output_to_node
[
input_name
]
in
nodes
:
predecessor_node
=
output_to_node
[
input_name
]
predecessor_node
=
output_to_node
[
input_name
]
#print("predecessor_node: ", predecessor_node)
if
predecessor_node
.
kind
().
startswith
(
'prim::'
):
if
predecessor_node
.
kind
().
startswith
(
'prim::'
):
node_group
.
append
(
predecessor_node
)
node_group
.
append
(
predecessor_node
)
node_queue
.
put
(
predecessor_node
)
node_queue
.
put
(
predecessor_node
)
...
@@ -231,7 +230,7 @@ class ModelSpeedup:
...
@@ -231,7 +230,7 @@ class ModelSpeedup:
"""
"""
graph
=
self
.
trace_graph
.
graph
graph
=
self
.
trace_graph
.
graph
# if torch 1.4.0 is used, consider run torch._C._jit_pass_inline(graph) here
# if torch 1.4.0 is used, consider run torch._C._jit_pass_inline(graph) here
#
print
(graph)
#
_logger.debug
(graph)
# build output mapping, from output debugName to its node
# build output mapping, from output debugName to its node
output_to_node
=
dict
()
output_to_node
=
dict
()
# build input mapping, from input debugName to its node
# build input mapping, from input debugName to its node
...
@@ -301,10 +300,8 @@ class ModelSpeedup:
...
@@ -301,10 +300,8 @@ class ModelSpeedup:
m_inputs
.
append
(
_input
)
m_inputs
.
append
(
_input
)
elif
not
output_to_node
[
_input
]
in
nodes
:
elif
not
output_to_node
[
_input
]
in
nodes
:
m_inputs
.
append
(
_input
)
m_inputs
.
append
(
_input
)
print
(
"module node_name: "
,
module_name
)
if
module_name
==
''
:
if
module_name
==
''
:
for
n
in
nodes
:
_logger
.
warning
(
"module_name is empty string"
)
print
(
n
)
g_node
=
GNode
(
module_name
,
'module'
,
module_to_type
[
module_name
],
m_inputs
,
m_outputs
,
nodes
)
g_node
=
GNode
(
module_name
,
'module'
,
module_to_type
[
module_name
],
m_inputs
,
m_outputs
,
nodes
)
self
.
g_nodes
.
append
(
g_node
)
self
.
g_nodes
.
append
(
g_node
)
...
@@ -345,10 +342,7 @@ class ModelSpeedup:
...
@@ -345,10 +342,7 @@ class ModelSpeedup:
predecessors
=
[]
predecessors
=
[]
for
_input
in
self
.
name_to_gnode
[
module_name
].
inputs
:
for
_input
in
self
.
name_to_gnode
[
module_name
].
inputs
:
if
not
_input
in
self
.
output_to_gnode
:
if
not
_input
in
self
.
output_to_gnode
:
print
(
_input
)
_logger
.
debug
(
"cannot find gnode with %s as its output"
,
_input
)
if
not
_input
in
self
.
output_to_gnode
:
# TODO: check _input which does not have node
print
(
"output with no gnode: "
,
_input
)
else
:
else
:
g_node
=
self
.
output_to_gnode
[
_input
]
g_node
=
self
.
output_to_gnode
[
_input
]
predecessors
.
append
(
g_node
.
name
)
predecessors
.
append
(
g_node
.
name
)
...
@@ -407,15 +401,15 @@ class ModelSpeedup:
...
@@ -407,15 +401,15 @@ class ModelSpeedup:
self
.
inferred_masks
[
module_name
]
=
module_masks
self
.
inferred_masks
[
module_name
]
=
module_masks
m_type
=
self
.
name_to_gnode
[
module_name
].
op_type
m_type
=
self
.
name_to_gnode
[
module_name
].
op_type
print
(
"infer_module_mask: {}, module type: {}"
.
format
(
module_name
,
m_type
)
)
_logger
.
debug
(
"infer mask of module %s with op_type %s"
,
module_name
,
m_type
)
if
mask
is
not
None
:
if
mask
is
not
None
:
#print
("mask is not None")
_logger
.
debug
(
"mask is not None"
)
if
not
m_type
in
infer_from_mask
:
if
not
m_type
in
infer_from_mask
:
raise
RuntimeError
(
"Has not supported infering
\
raise
RuntimeError
(
"Has not supported infering
\
input/output shape from mask for module/function: `{}`"
.
format
(
m_type
))
input/output shape from mask for module/function: `{}`"
.
format
(
m_type
))
input_cmask
,
output_cmask
=
infer_from_mask
[
m_type
](
module_masks
,
mask
)
input_cmask
,
output_cmask
=
infer_from_mask
[
m_type
](
module_masks
,
mask
)
if
in_shape
is
not
None
:
if
in_shape
is
not
None
:
#print
("in_shape is not None")
_logger
.
debug
(
"in_shape is not None"
)
if
not
m_type
in
infer_from_inshape
:
if
not
m_type
in
infer_from_inshape
:
raise
RuntimeError
(
"Has not supported infering
\
raise
RuntimeError
(
"Has not supported infering
\
output shape from input shape for module/function: `{}`"
.
format
(
m_type
))
output shape from input shape for module/function: `{}`"
.
format
(
m_type
))
...
@@ -426,23 +420,19 @@ class ModelSpeedup:
...
@@ -426,23 +420,19 @@ class ModelSpeedup:
else
:
else
:
output_cmask
=
infer_from_inshape
[
m_type
](
module_masks
,
in_shape
)
output_cmask
=
infer_from_inshape
[
m_type
](
module_masks
,
in_shape
)
if
out_shape
is
not
None
:
if
out_shape
is
not
None
:
#print
("out_shape is not None")
_logger
.
debug
(
"out_shape is not None"
)
if
not
m_type
in
infer_from_outshape
:
if
not
m_type
in
infer_from_outshape
:
raise
RuntimeError
(
"Has not supported infering
\
raise
RuntimeError
(
"Has not supported infering
\
input shape from output shape for module/function: `{}`"
.
format
(
m_type
))
input shape from output shape for module/function: `{}`"
.
format
(
m_type
))
input_cmask
=
infer_from_outshape
[
m_type
](
module_masks
,
out_shape
)
input_cmask
=
infer_from_outshape
[
m_type
](
module_masks
,
out_shape
)
if
input_cmask
:
if
input_cmask
:
#print("input_cmask is not None")
predecessors
=
self
.
_find_predecessors
(
module_name
)
predecessors
=
self
.
_find_predecessors
(
module_name
)
for
_module_name
in
predecessors
:
for
_module_name
in
predecessors
:
print
(
"input_cmask, module_name: "
,
_module_name
)
self
.
infer_module_mask
(
_module_name
,
out_shape
=
input_cmask
)
self
.
infer_module_mask
(
_module_name
,
out_shape
=
input_cmask
)
if
output_cmask
:
if
output_cmask
:
#print("output_cmask is not None")
successors
=
self
.
_find_successors
(
module_name
)
successors
=
self
.
_find_successors
(
module_name
)
for
_module_name
in
successors
:
for
_module_name
in
successors
:
print
(
"output_cmask, module_name: "
,
_module_name
)
self
.
infer_module_mask
(
_module_name
,
in_shape
=
output_cmask
)
self
.
infer_module_mask
(
_module_name
,
in_shape
=
output_cmask
)
def
infer_modules_masks
(
self
):
def
infer_modules_masks
(
self
):
...
@@ -463,16 +453,19 @@ class ModelSpeedup:
...
@@ -463,16 +453,19 @@ class ModelSpeedup:
"""
"""
for
module_name
in
self
.
inferred_masks
:
for
module_name
in
self
.
inferred_masks
:
g_node
=
self
.
name_to_gnode
[
module_name
]
g_node
=
self
.
name_to_gnode
[
module_name
]
print
(
module_name
,
g_node
.
op_type
)
_logger
.
debug
(
"replace %s, in %s type, with op_type %s"
,
module_name
,
g_node
.
type
,
g_node
.
op_type
)
if
g_node
.
type
==
'module'
:
if
g_node
.
type
==
'module'
:
super_module
,
leaf_module
=
get_module_by_name
(
self
.
bound_model
,
module_name
)
super_module
,
leaf_module
=
get_module_by_name
(
self
.
bound_model
,
module_name
)
m_type
=
g_node
.
op_type
m_type
=
g_node
.
op_type
if
not
m_type
in
replace_module
:
if
not
m_type
in
replace_module
:
raise
RuntimeError
(
"Has not supported replacing the module: `{}`"
.
format
(
m_type
))
raise
RuntimeError
(
"Has not supported replacing the module: `{}`"
.
format
(
m_type
))
_logger
.
info
(
"replace module (name: %s, op_type: %s)"
,
module_name
,
m_type
)
compressed_module
=
replace_module
[
m_type
](
leaf_module
,
self
.
inferred_masks
[
module_name
])
compressed_module
=
replace_module
[
m_type
](
leaf_module
,
self
.
inferred_masks
[
module_name
])
setattr
(
super_module
,
module_name
.
split
(
'.'
)[
-
1
],
compressed_module
)
setattr
(
super_module
,
module_name
.
split
(
'.'
)[
-
1
],
compressed_module
)
elif
g_node
.
type
==
'func'
:
elif
g_node
.
type
==
'func'
:
print
(
"Warning: Cannot replace func..."
)
_logger
.
info
(
"Warning: cannot replace (name: %s, op_type: %s) which is func type"
,
module_name
,
g_node
.
op_type
)
else
:
else
:
raise
RuntimeError
(
"Unsupported GNode type: {}"
.
format
(
g_node
.
type
))
raise
RuntimeError
(
"Unsupported GNode type: {}"
.
format
(
g_node
.
type
))
...
@@ -482,10 +475,12 @@ class ModelSpeedup:
...
@@ -482,10 +475,12 @@ class ModelSpeedup:
first, do mask/shape inference,
first, do mask/shape inference,
second, replace modules
second, replace modules
"""
"""
#print("start to compress")
_logger
.
info
(
"start to speed up the model"
)
_logger
.
info
(
"infer module masks..."
)
self
.
infer_modules_masks
()
self
.
infer_modules_masks
()
_logger
.
info
(
"replace compressed modules..."
)
self
.
replace_compressed_modules
()
self
.
replace_compressed_modules
()
#print("finished compressing
")
_logger
.
info
(
"speedup done
"
)
# resume the model mode to that before the model is speed up
# resume the model mode to that before the model is speed up
if
self
.
is_training
:
if
self
.
is_training
:
self
.
bound_model
.
train
()
self
.
bound_model
.
train
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment