Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
b869dd48
Unverified
Commit
b869dd48
authored
Oct 10, 2019
by
chicm-ms
Committed by
GitHub
Oct 10, 2019
Browse files
fix mnist-pytorch example (#1596)
parent
f60bf1d9
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
6 additions
and
2 deletions
+6
-2
examples/trials/mnist-pytorch/mnist.py
examples/trials/mnist-pytorch/mnist.py
+6
-2
No files found.
examples/trials/mnist-pytorch/mnist.py
View file @
b869dd48
...
@@ -5,6 +5,7 @@ This file is a modification of the official pytorch mnist example:
...
@@ -5,6 +5,7 @@ This file is a modification of the official pytorch mnist example:
https://github.com/pytorch/examples/blob/master/mnist/main.py
https://github.com/pytorch/examples/blob/master/mnist/main.py
"""
"""
import
os
import
argparse
import
argparse
import
logging
import
logging
import
nni
import
nni
...
@@ -84,15 +85,18 @@ def main(args):
...
@@ -84,15 +85,18 @@ def main(args):
device
=
torch
.
device
(
"cuda"
if
use_cuda
else
"cpu"
)
device
=
torch
.
device
(
"cuda"
if
use_cuda
else
"cpu"
)
kwargs
=
{
'num_workers'
:
1
,
'pin_memory'
:
True
}
if
use_cuda
else
{}
kwargs
=
{
'num_workers'
:
1
,
'pin_memory'
:
True
}
if
use_cuda
else
{}
data_dir
=
os
.
path
.
join
(
args
[
'data_dir'
],
nni
.
get_trial_id
())
train_loader
=
torch
.
utils
.
data
.
DataLoader
(
train_loader
=
torch
.
utils
.
data
.
DataLoader
(
datasets
.
MNIST
(
args
[
'
data_dir
'
]
,
train
=
True
,
download
=
True
,
datasets
.
MNIST
(
data_dir
,
train
=
True
,
download
=
True
,
transform
=
transforms
.
Compose
([
transform
=
transforms
.
Compose
([
transforms
.
ToTensor
(),
transforms
.
ToTensor
(),
transforms
.
Normalize
((
0.1307
,),
(
0.3081
,))
transforms
.
Normalize
((
0.1307
,),
(
0.3081
,))
])),
])),
batch_size
=
args
[
'batch_size'
],
shuffle
=
True
,
**
kwargs
)
batch_size
=
args
[
'batch_size'
],
shuffle
=
True
,
**
kwargs
)
test_loader
=
torch
.
utils
.
data
.
DataLoader
(
test_loader
=
torch
.
utils
.
data
.
DataLoader
(
datasets
.
MNIST
(
args
[
'
data_dir
'
]
,
train
=
False
,
transform
=
transforms
.
Compose
([
datasets
.
MNIST
(
data_dir
,
train
=
False
,
transform
=
transforms
.
Compose
([
transforms
.
ToTensor
(),
transforms
.
ToTensor
(),
transforms
.
Normalize
((
0.1307
,),
(
0.3081
,))
transforms
.
Normalize
((
0.1307
,),
(
0.3081
,))
])),
])),
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment