Unverified Commit f95d9ed8 authored by QuanluZhang's avatar QuanluZhang Committed by GitHub
Browse files

proxylessnas update (#2069)

parent 9c0af6a7
...@@ -60,4 +60,4 @@ ProxylessNasMutator also implements the forward logic of the mutables (i.e., Lay ...@@ -60,4 +60,4 @@ ProxylessNasMutator also implements the forward logic of the mutables (i.e., Lay
## Reproduce Results ## Reproduce Results
Ongoing... To reproduce the result, we first run the search, we found that though it runs many epochs the chosen architecture converges at the first several epochs. This is probably induced by hyper-parameters or the implementation, we are working on it. The test accuracy of the found architecture is top1: 72.31, top5: 90.26.
...@@ -34,6 +34,7 @@ if __name__ == "__main__": ...@@ -34,6 +34,7 @@ if __name__ == "__main__":
# configurations for search # configurations for search
parser.add_argument("--checkpoint_path", default='./search_mobile_net.pt', type=str) parser.add_argument("--checkpoint_path", default='./search_mobile_net.pt', type=str)
parser.add_argument("--arch_path", default='./arch_path.pt', type=str) parser.add_argument("--arch_path", default='./arch_path.pt', type=str)
parser.add_argument("--no-warmup", dest='warmup', action='store_false')
# configurations for retrain # configurations for retrain
parser.add_argument("--exported_arch_path", default=None, type=str) parser.add_argument("--exported_arch_path", default=None, type=str)
...@@ -54,7 +55,7 @@ if __name__ == "__main__": ...@@ -54,7 +55,7 @@ if __name__ == "__main__":
# move network to GPU if available # move network to GPU if available
if torch.cuda.is_available(): if torch.cuda.is_available():
device = torch.device('cuda:0') device = torch.device('cuda')
else: else:
device = torch.device('cpu') device = torch.device('cpu')
...@@ -86,7 +87,7 @@ if __name__ == "__main__": ...@@ -86,7 +87,7 @@ if __name__ == "__main__":
train_loader=data_provider.train, train_loader=data_provider.train,
valid_loader=data_provider.valid, valid_loader=data_provider.valid,
device=device, device=device,
warmup=True, warmup=args.warmup,
ckpt_path=args.checkpoint_path, ckpt_path=args.checkpoint_path,
arch_path=args.arch_path) arch_path=args.arch_path)
...@@ -102,4 +103,4 @@ if __name__ == "__main__": ...@@ -102,4 +103,4 @@ if __name__ == "__main__":
"exported_arch_path {} should be a file.".format(args.exported_arch_path) "exported_arch_path {} should be a file.".format(args.exported_arch_path)
apply_fixed_architecture(model, args.exported_arch_path, device=device) apply_fixed_architecture(model, args.exported_arch_path, device=device)
trainer = Retrain(model, optimizer, device, data_provider, n_epochs=300) trainer = Retrain(model, optimizer, device, data_provider, n_epochs=300)
trainer.run() trainer.run()
\ No newline at end of file
...@@ -116,8 +116,6 @@ class AverageMeter: ...@@ -116,8 +116,6 @@ class AverageMeter:
n : int n : int
The weight of the new value. The weight of the new value.
""" """
if not isinstance(val, float) and not isinstance(val, int):
_logger.warning("Values passed to AverageMeter must be number, not %s.", type(val))
self.val = val self.val = val
self.sum += val * n self.sum += val * n
self.count += n self.count += n
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment