Unverified Commit b869dd48 authored by chicm-ms's avatar chicm-ms Committed by GitHub
Browse files

fix mnist-pytorch example (#1596)

parent f60bf1d9
...@@ -5,6 +5,7 @@ This file is a modification of the official pytorch mnist example: ...@@ -5,6 +5,7 @@ This file is a modification of the official pytorch mnist example:
https://github.com/pytorch/examples/blob/master/mnist/main.py https://github.com/pytorch/examples/blob/master/mnist/main.py
""" """
import os
import argparse import argparse
import logging import logging
import nni import nni
...@@ -84,15 +85,18 @@ def main(args): ...@@ -84,15 +85,18 @@ def main(args):
device = torch.device("cuda" if use_cuda else "cpu") device = torch.device("cuda" if use_cuda else "cpu")
kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {} kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
data_dir = os.path.join(args['data_dir'], nni.get_trial_id())
train_loader = torch.utils.data.DataLoader( train_loader = torch.utils.data.DataLoader(
datasets.MNIST(args['data_dir'], train=True, download=True, datasets.MNIST(data_dir, train=True, download=True,
transform=transforms.Compose([ transform=transforms.Compose([
transforms.ToTensor(), transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,)) transforms.Normalize((0.1307,), (0.3081,))
])), ])),
batch_size=args['batch_size'], shuffle=True, **kwargs) batch_size=args['batch_size'], shuffle=True, **kwargs)
test_loader = torch.utils.data.DataLoader( test_loader = torch.utils.data.DataLoader(
datasets.MNIST(args['data_dir'], train=False, transform=transforms.Compose([ datasets.MNIST(data_dir, train=False, transform=transforms.Compose([
transforms.ToTensor(), transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,)) transforms.Normalize((0.1307,), (0.3081,))
])), ])),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment