Unverified Commit f95d9ed8 authored by QuanluZhang's avatar QuanluZhang Committed by GitHub
Browse files

proxylessnas update (#2069)

parent 9c0af6a7
......@@ -60,4 +60,4 @@ ProxylessNasMutator also implements the forward logic of the mutables (i.e., Lay
## Reproduce Results
Ongoing...
To reproduce the result, we first run the search, we found that though it runs many epochs the chosen architecture converges at the first several epochs. This is probably induced by hyper-parameters or the implementation, we are working on it. The test accuracy of the found architecture is top1: 72.31, top5: 90.26.
......@@ -34,6 +34,7 @@ if __name__ == "__main__":
# configurations for search
parser.add_argument("--checkpoint_path", default='./search_mobile_net.pt', type=str)
parser.add_argument("--arch_path", default='./arch_path.pt', type=str)
parser.add_argument("--no-warmup", dest='warmup', action='store_false')
# configurations for retrain
parser.add_argument("--exported_arch_path", default=None, type=str)
......@@ -54,7 +55,7 @@ if __name__ == "__main__":
# move network to GPU if available
if torch.cuda.is_available():
device = torch.device('cuda:0')
device = torch.device('cuda')
else:
device = torch.device('cpu')
......@@ -86,7 +87,7 @@ if __name__ == "__main__":
train_loader=data_provider.train,
valid_loader=data_provider.valid,
device=device,
warmup=True,
warmup=args.warmup,
ckpt_path=args.checkpoint_path,
arch_path=args.arch_path)
......
......@@ -116,8 +116,6 @@ class AverageMeter:
n : int
The weight of the new value.
"""
if not isinstance(val, float) and not isinstance(val, int):
_logger.warning("Values passed to AverageMeter must be number, not %s.", type(val))
self.val = val
self.sum += val * n
self.count += n
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment