Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
585b5c53
Unverified
Commit
585b5c53
authored
Feb 21, 2020
by
Yuge Zhang
Committed by
GitHub
Feb 21, 2020
Browse files
Fix apply_fixed_architecture device error and ENAS micro mask device error (#2088)
parent
89de4061
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
3 additions
and
3 deletions
+3
-3
examples/nas/darts/retrain.py
examples/nas/darts/retrain.py
+1
-1
examples/nas/enas/micro.py
examples/nas/enas/micro.py
+1
-1
examples/nas/proxylessnas/main.py
examples/nas/proxylessnas/main.py
+1
-1
No files found.
examples/nas/darts/retrain.py
View file @
585b5c53
...
...
@@ -120,7 +120,7 @@ if __name__ == "__main__":
dataset_train
,
dataset_valid
=
datasets
.
get_dataset
(
"cifar10"
,
cutout_length
=
16
)
model
=
CNN
(
32
,
3
,
36
,
10
,
args
.
layers
,
auxiliary
=
True
)
apply_fixed_architecture
(
model
,
args
.
arc_checkpoint
,
device
=
device
)
apply_fixed_architecture
(
model
,
args
.
arc_checkpoint
)
criterion
=
nn
.
CrossEntropyLoss
()
model
.
to
(
device
)
...
...
examples/nas/enas/micro.py
View file @
585b5c53
...
...
@@ -115,7 +115,7 @@ class ENASLayer(nn.Module):
nodes_used_mask
=
torch
.
zeros
(
self
.
num_nodes
+
2
,
dtype
=
torch
.
bool
,
device
=
prev
.
device
)
for
i
in
range
(
self
.
num_nodes
):
node_out
,
mask
=
self
.
nodes
[
i
](
prev_nodes_out
)
nodes_used_mask
[:
mask
.
size
(
0
)]
|=
mask
nodes_used_mask
[:
mask
.
size
(
0
)]
|=
mask
.
to
(
node_out
.
device
)
prev_nodes_out
.
append
(
node_out
)
unused_nodes
=
torch
.
cat
([
out
for
used
,
out
in
zip
(
nodes_used_mask
,
prev_nodes_out
)
if
not
used
],
1
)
...
...
examples/nas/proxylessnas/main.py
View file @
585b5c53
...
...
@@ -101,6 +101,6 @@ if __name__ == "__main__":
from
nni.nas.pytorch.fixed
import
apply_fixed_architecture
assert
os
.
path
.
isfile
(
args
.
exported_arch_path
),
\
"exported_arch_path {} should be a file."
.
format
(
args
.
exported_arch_path
)
apply_fixed_architecture
(
model
,
args
.
exported_arch_path
,
device
=
device
)
apply_fixed_architecture
(
model
,
args
.
exported_arch_path
)
trainer
=
Retrain
(
model
,
optimizer
,
device
,
data_provider
,
n_epochs
=
300
)
trainer
.
run
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment