Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
b8c0fb6e
Unverified
Commit
b8c0fb6e
authored
Feb 15, 2020
by
QuanluZhang
Committed by
GitHub
Feb 15, 2020
Browse files
compression speedup: add init file (#2063)
parent
b4ab371b
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
7 additions
and
8 deletions
+7
-8
src/sdk/pynni/nni/compression/speedup/__init__.py
src/sdk/pynni/nni/compression/speedup/__init__.py
+0
-0
src/sdk/pynni/nni/compression/speedup/torch/compress_modules.py
...k/pynni/nni/compression/speedup/torch/compress_modules.py
+1
-1
src/sdk/pynni/nni/compression/speedup/torch/compressor.py
src/sdk/pynni/nni/compression/speedup/torch/compressor.py
+1
-1
src/sdk/pynni/nni/compression/speedup/torch/infer_shape.py
src/sdk/pynni/nni/compression/speedup/torch/infer_shape.py
+5
-6
No files found.
src/sdk/pynni/nni/compression/speedup/__init__.py
0 → 100644
View file @
b8c0fb6e
src/sdk/pynni/nni/compression/speedup/torch/compress_modules.py
View file @
b8c0fb6e
...
...
@@ -2,7 +2,7 @@
# Licensed under the MIT license.
import
torch
from
.infer_shape
import
CoarseMask
,
ModuleMasks
from
.infer_shape
import
ModuleMasks
replace_module
=
{
'BatchNorm2d'
:
lambda
module
,
mask
:
replace_batchnorm2d
(
module
,
mask
),
...
...
src/sdk/pynni/nni/compression/speedup/torch/compressor.py
View file @
b8c0fb6e
...
...
@@ -379,7 +379,7 @@ class ModelSpeedup:
def
infer_module_mask
(
self
,
module_name
,
mask
=
None
,
in_shape
=
None
,
out_shape
=
None
):
"""
Infer input shape / output shape based on the module's weight mask / input shape / output shape.
For a module:
Infer its input and output shape from its weight mask
Infer its output shape from its input shape
...
...
src/sdk/pynni/nni/compression/speedup/torch/infer_shape.py
View file @
b8c0fb6e
...
...
@@ -56,7 +56,7 @@ class CoarseMask:
s
.
add
(
num
)
for
num
in
index_b
:
s
.
add
(
num
)
return
torch
.
tensor
(
sorted
(
s
))
return
torch
.
tensor
(
sorted
(
s
))
# pylint: disable=not-callable
def
merge
(
self
,
cmask
):
"""
...
...
@@ -98,7 +98,7 @@ class ModuleMasks:
self
.
param_masks
=
dict
()
self
.
input_mask
=
None
self
.
output_mask
=
None
def
set_param_masks
(
self
,
name
,
mask
):
"""
Parameters
...
...
@@ -217,7 +217,7 @@ def view_inshape(module_masks, mask, shape):
TODO: consider replace tensor.view with nn.Flatten, because tensor.view is not
included in module, thus, cannot be replaced by our framework.
Parameters
----------
module_masks : ModuleMasks
...
...
@@ -250,7 +250,7 @@ def view_inshape(module_masks, mask, shape):
step_size
=
shape
[
'in_shape'
][
2
]
*
shape
[
'in_shape'
][
3
]
for
loc
in
mask
.
mask_index
[
1
]:
index
.
extend
([
loc
*
step_size
+
i
for
i
in
range
(
step_size
)])
output_cmask
.
add_index_mask
(
dim
=
1
,
index
=
torch
.
tensor
(
index
))
output_cmask
.
add_index_mask
(
dim
=
1
,
index
=
torch
.
tensor
(
index
))
# pylint: disable=not-callable
module_masks
.
set_output_mask
(
output_cmask
)
return
output_cmask
...
...
@@ -373,7 +373,6 @@ def conv2d_mask(module_masks, mask):
"""
assert
'weight'
in
mask
assert
isinstance
(
mask
[
'weight'
],
torch
.
Tensor
)
cmask
=
None
weight_mask
=
mask
[
'weight'
]
shape
=
weight_mask
.
size
()
ones
=
torch
.
ones
(
shape
[
1
:]).
to
(
weight_mask
.
device
)
...
...
@@ -451,7 +450,7 @@ def conv2d_outshape(module_masks, mask):
The ModuleMasks instance of the conv2d
mask : CoarseMask
The mask of its output tensor
Returns
-------
CoarseMask
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment