Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
617d9f32
Unverified
Commit
617d9f32
authored
Sep 06, 2021
by
Ningxin Zheng
Committed by
GitHub
Sep 06, 2021
Browse files
support directly load the mask (#4144)
parent
acb627cf
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
10 additions
and
6 deletions
+10
-6
nni/compression/pytorch/speedup/compressor.py
nni/compression/pytorch/speedup/compressor.py
+10
-6
No files found.
nni/compression/pytorch/speedup/compressor.py
View file @
617d9f32
# Copyright (c) Microsoft Corporation.
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# Licensed under the MIT license.
import
os
import
queue
import
queue
import
logging
import
logging
import
copy
import
copy
...
@@ -35,8 +35,8 @@ class ModelSpeedup:
...
@@ -35,8 +35,8 @@ class ModelSpeedup:
Note: The first dimension of the dummy_input should be the batchsize.
Note: The first dimension of the dummy_input should be the batchsize.
The dummy input for ```jit.trace```, users should put it on the right
The dummy input for ```jit.trace```, users should put it on the right
device.
device.
masks_file : str
masks_file : str
/dict
The path of user provided mask file
The path of user provided mask file
, or the mask object
map_location : str
map_location : str
the device on which masks are placed, same to map_location in ```torch.load```
the device on which masks are placed, same to map_location in ```torch.load```
batch_dim : int
batch_dim : int
...
@@ -63,9 +63,13 @@ class ModelSpeedup:
...
@@ -63,9 +63,13 @@ class ModelSpeedup:
# load the mask tensor to the same device with the dummy_input
# load the mask tensor to the same device with the dummy_input
# self.masks save the mask tensors pruned by the user and the infered
# self.masks save the mask tensors pruned by the user and the infered
# masks of the others modules
# masks of the others modules
if
isinstance
(
masks_file
,
str
)
and
os
.
path
.
exists
(
masks_file
):
self
.
masks
=
torch
.
load
(
self
.
masks
=
torch
.
load
(
masks_file
,
map_location
if
map_location
is
not
None
else
str
(
self
.
device
))
masks_file
,
map_location
if
map_location
is
not
None
else
str
(
self
.
device
))
elif
isinstance
(
masks_file
,
dict
):
self
.
masks
=
masks_file
else
:
raise
Exception
(
'Please provide the mask or the path of the mask file'
)
self
.
constant
=
{}
self
.
constant
=
{}
# self.internal_result save the internal output of the submodules
# self.internal_result save the internal output of the submodules
self
.
internal_result
=
{}
self
.
internal_result
=
{}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment