Unverified Commit 131fb2c1 authored by colorjam's avatar colorjam Committed by GitHub
Browse files

update pbt example to master (#2515)

parent 5a911b30
...@@ -13,12 +13,12 @@ from torchvision import datasets, transforms ...@@ -13,12 +13,12 @@ from torchvision import datasets, transforms
logger = logging.getLogger('mnist_pbt_tuner_pytorch_AutoML') logger = logging.getLogger('mnist_pbt_tuner_pytorch_AutoML')
class Net(nn.Module): class Net(nn.Module):
def __init__(self, hidden_size): def __init__(self):
super(Net, self).__init__() super(Net, self).__init__()
self.conv1 = nn.Conv2d(1, 20, 5, 1) self.conv1 = nn.Conv2d(1, 20, 5, 1)
self.conv2 = nn.Conv2d(20, 50, 5, 1) self.conv2 = nn.Conv2d(20, 50, 5, 1)
self.fc1 = nn.Linear(4*4*50, hidden_size) self.fc1 = nn.Linear(4*4*50, 512)
self.fc2 = nn.Linear(hidden_size, 10) self.fc2 = nn.Linear(512, 10)
def forward(self, x): def forward(self, x):
x = F.relu(self.conv1(x)) x = F.relu(self.conv1(x))
...@@ -104,9 +104,7 @@ def main(args): ...@@ -104,9 +104,7 @@ def main(args):
])), ])),
batch_size=1000, shuffle=True, **kwargs) batch_size=1000, shuffle=True, **kwargs)
hidden_size = args['hidden_size'] model = Net().to(device)
model = Net(hidden_size=hidden_size).to(device)
save_checkpoint_dir = args['save_checkpoint_dir'] save_checkpoint_dir = args['save_checkpoint_dir']
save_checkpoint_path = os.path.join(save_checkpoint_dir, 'model.pth') save_checkpoint_path = os.path.join(save_checkpoint_dir, 'model.pth')
...@@ -146,11 +144,9 @@ def get_params(): ...@@ -146,11 +144,9 @@ def get_params():
# Training settings # Training settings
parser = argparse.ArgumentParser(description='PyTorch MNIST Example') parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
parser.add_argument("--data_dir", type=str, parser.add_argument("--data_dir", type=str,
default='./tmp/pytorch/mnist/input_data', help="data directory") default='/tmp/pytorch/mnist/input_data', help="data directory")
parser.add_argument('--batch_size', type=int, default=64, metavar='N', parser.add_argument('--batch_size', type=int, default=64, metavar='N',
help='input batch size for training (default: 64)') help='input batch size for training (default: 64)')
parser.add_argument("--hidden_size", type=int, default=512, metavar='N',
help='hidden layer size (default: 512)')
parser.add_argument('--lr', type=float, default=0.01, metavar='LR', parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
help='learning rate (default: 0.01)') help='learning rate (default: 0.01)')
parser.add_argument('--momentum', type=float, default=0.5, metavar='M', parser.add_argument('--momentum', type=float, default=0.5, metavar='M',
......
{ {
"batch_size": {"_type":"choice", "_value": [16, 32, 64, 128]}, "batch_size": {"_type":"choice", "_value": [16, 32, 64, 128]},
"hidden_size":{"_type":"choice","_value":[128, 256, 512, 1024]},
"lr":{"_type":"choice","_value":[0.0001, 0.001, 0.01, 0.1]}, "lr":{"_type":"choice","_value":[0.0001, 0.001, 0.01, 0.1]},
"momentum":{"_type":"uniform","_value":[0, 1]} "momentum":{"_type":"uniform","_value":[0, 1]}
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment