Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
03a02232
"include/ck/utility/array.hpp" did not exist on "05e046654c9a226444091806a418a77fe0e4a4c2"
Unverified
Commit
03a02232
authored
Jul 26, 2021
by
thomasschmied
Committed by
GitHub
Jul 26, 2021
Browse files
Change .to() to to_device() (#3963)
parent
5bf2cb19
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
9 additions
and
9 deletions
+9
-9
nni/retiarii/oneshot/pytorch/darts.py
nni/retiarii/oneshot/pytorch/darts.py
+3
-3
nni/retiarii/oneshot/pytorch/proxyless.py
nni/retiarii/oneshot/pytorch/proxyless.py
+3
-3
nni/retiarii/oneshot/pytorch/random.py
nni/retiarii/oneshot/pytorch/random.py
+3
-3
No files found.
nni/retiarii/oneshot/pytorch/darts.py
View file @
03a02232
...
...
@@ -9,7 +9,7 @@ import torch.nn as nn
import
torch.nn.functional
as
F
from
..interface
import
BaseOneShotTrainer
from
.utils
import
AverageMeterGroup
,
replace_layer_choice
,
replace_input_choice
from
.utils
import
AverageMeterGroup
,
replace_layer_choice
,
replace_input_choice
,
to_device
_logger
=
logging
.
getLogger
(
__name__
)
...
...
@@ -160,8 +160,8 @@ class DartsTrainer(BaseOneShotTrainer):
self
.
model
.
train
()
meters
=
AverageMeterGroup
()
for
step
,
((
trn_X
,
trn_y
),
(
val_X
,
val_y
))
in
enumerate
(
zip
(
self
.
train_loader
,
self
.
valid_loader
)):
trn_X
,
trn_y
=
t
rn_X
.
to
(
self
.
device
),
trn_y
.
to
(
self
.
device
)
val_X
,
val_y
=
val_X
.
to
(
self
.
device
),
val_y
.
to
(
self
.
device
)
trn_X
,
trn_y
=
t
o_device
(
trn_X
,
self
.
device
),
to_device
(
trn_y
,
self
.
device
)
val_X
,
val_y
=
to_device
(
val_X
,
self
.
device
),
to_device
(
val_y
,
self
.
device
)
# phase 1. architecture step
self
.
ctrl_optim
.
zero_grad
()
...
...
nni/retiarii/oneshot/pytorch/proxyless.py
View file @
03a02232
...
...
@@ -8,7 +8,7 @@ import torch.nn as nn
import
torch.nn.functional
as
F
from
..interface
import
BaseOneShotTrainer
from
.utils
import
AverageMeterGroup
,
replace_layer_choice
,
replace_input_choice
from
.utils
import
AverageMeterGroup
,
replace_layer_choice
,
replace_input_choice
,
to_device
_logger
=
logging
.
getLogger
(
__name__
)
...
...
@@ -181,8 +181,8 @@ class ProxylessTrainer(BaseOneShotTrainer):
self
.
model
.
train
()
meters
=
AverageMeterGroup
()
for
step
,
((
trn_X
,
trn_y
),
(
val_X
,
val_y
))
in
enumerate
(
zip
(
self
.
train_loader
,
self
.
valid_loader
)):
trn_X
,
trn_y
=
t
rn_X
.
to
(
self
.
device
),
trn_y
.
to
(
self
.
device
)
val_X
,
val_y
=
val_X
.
to
(
self
.
device
),
val_y
.
to
(
self
.
device
)
trn_X
,
trn_y
=
t
o_device
(
trn_X
,
self
.
device
),
to_device
(
trn_y
,
self
.
device
)
val_X
,
val_y
=
to_device
(
val_X
,
self
.
device
),
to_device
(
val_y
,
self
.
device
)
if
epoch
>=
self
.
warmup_epochs
:
# 1) train architecture parameters
...
...
nni/retiarii/oneshot/pytorch/random.py
View file @
03a02232
...
...
@@ -8,7 +8,7 @@ import torch
import
torch.nn
as
nn
from
..interface
import
BaseOneShotTrainer
from
.utils
import
AverageMeterGroup
,
replace_layer_choice
,
replace_input_choice
from
.utils
import
AverageMeterGroup
,
replace_layer_choice
,
replace_input_choice
,
to_device
_logger
=
logging
.
getLogger
(
__name__
)
...
...
@@ -160,7 +160,7 @@ class SinglePathTrainer(BaseOneShotTrainer):
self
.
model
.
train
()
meters
=
AverageMeterGroup
()
for
step
,
(
x
,
y
)
in
enumerate
(
self
.
train_loader
):
x
,
y
=
x
.
to
(
self
.
device
),
y
.
to
(
self
.
device
)
x
,
y
=
to_device
(
x
,
self
.
device
),
to_device
(
y
,
self
.
device
)
self
.
optimizer
.
zero_grad
()
self
.
_resample
()
logits
=
self
.
model
(
x
)
...
...
@@ -180,7 +180,7 @@ class SinglePathTrainer(BaseOneShotTrainer):
meters
=
AverageMeterGroup
()
with
torch
.
no_grad
():
for
step
,
(
x
,
y
)
in
enumerate
(
self
.
valid_loader
):
x
,
y
=
x
.
to
(
self
.
device
),
y
.
to
(
self
.
device
)
x
,
y
=
to_device
(
x
,
self
.
device
),
to_device
(
y
,
self
.
device
)
self
.
_resample
()
logits
=
self
.
model
(
x
)
loss
=
self
.
loss
(
logits
,
y
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment