Unverified Commit 4b2dcab3 authored by Yuge Zhang's avatar Yuge Zhang Committed by GitHub
Browse files

Minor fixes to SPOS and ProxylessNAS examples (#4420)

parent 64ea284f
...@@ -72,7 +72,7 @@ Step 3. Train for Evaluation ...@@ -72,7 +72,7 @@ Step 3. Train for Evaluation
.. code-block:: bash .. code-block:: bash
python scratch.py python evaluation.py
By default, it will use ``architecture_final.json``. This architecture is provided by the official repo (converted into NNI format). You can use any architecture (e.g., the architecture found in step 2) with ``--fixed-arc`` option. By default, it will use ``architecture_final.json``. This architecture is provided by the official repo (converted into NNI format). You can use any architecture (e.g., the architecture found in step 2) with ``--fixed-arc`` option.
......
...@@ -5,12 +5,11 @@ import sys ...@@ -5,12 +5,11 @@ import sys
from argparse import ArgumentParser from argparse import ArgumentParser
import torch import torch
from nni.algorithms.nas.pytorch.proxylessnas import ProxylessNasTrainer
from torchvision import transforms from torchvision import transforms
from nni.retiarii.fixed import fixed_arch
import datasets import datasets
from model import SearchMobileNet from model import SearchMobileNet
from nni.algorithms.nas.pytorch.proxylessnas import ProxylessNasTrainer
from putils import LabelSmoothingLoss, accuracy, get_parameters from putils import LabelSmoothingLoss, accuracy, get_parameters
from retrain import Retrain from retrain import Retrain
...@@ -40,7 +39,7 @@ if __name__ == "__main__": ...@@ -40,7 +39,7 @@ if __name__ == "__main__":
parser.add_argument("--resize_scale", default=0.08, type=float) parser.add_argument("--resize_scale", default=0.08, type=float)
parser.add_argument("--distort_color", default='normal', type=str, choices=['normal', 'strong', 'None']) parser.add_argument("--distort_color", default='normal', type=str, choices=['normal', 'strong', 'None'])
# configurations for training mode # configurations for training mode
parser.add_argument("--train_mode", default='search', type=str, choices=['search_v1', 'search', 'retrain']) parser.add_argument("--train_mode", default='search', type=str, choices=['search', 'retrain'])
# configurations for search # configurations for search
parser.add_argument("--checkpoint_path", default='./search_mobile_net.pt', type=str) parser.add_argument("--checkpoint_path", default='./search_mobile_net.pt', type=str)
parser.add_argument("--arch_path", default='./arch_path.pt', type=str) parser.add_argument("--arch_path", default='./arch_path.pt', type=str)
...@@ -53,12 +52,23 @@ if __name__ == "__main__": ...@@ -53,12 +52,23 @@ if __name__ == "__main__":
logger.error('When --train_mode is retrain, --exported_arch_path must be specified.') logger.error('When --train_mode is retrain, --exported_arch_path must be specified.')
sys.exit(-1) sys.exit(-1)
model = SearchMobileNet(width_stages=[int(i) for i in args.width_stages.split(',')], if args.train_mode == 'retrain':
n_cell_stages=[int(i) for i in args.n_cell_stages.split(',')], assert os.path.isfile(args.exported_arch_path), \
stride_stages=[int(i) for i in args.stride_stages.split(',')], "exported_arch_path {} should be a file.".format(args.exported_arch_path)
n_classes=1000, with fixed_arch(args.exported_arch_path):
dropout_rate=args.dropout_rate, model = SearchMobileNet(width_stages=[int(i) for i in args.width_stages.split(',')],
bn_param=(args.bn_momentum, args.bn_eps)) n_cell_stages=[int(i) for i in args.n_cell_stages.split(',')],
stride_stages=[int(i) for i in args.stride_stages.split(',')],
n_classes=1000,
dropout_rate=args.dropout_rate,
bn_param=(args.bn_momentum, args.bn_eps))
else:
model = SearchMobileNet(width_stages=[int(i) for i in args.width_stages.split(',')],
n_cell_stages=[int(i) for i in args.n_cell_stages.split(',')],
stride_stages=[int(i) for i in args.stride_stages.split(',')],
n_classes=1000,
dropout_rate=args.dropout_rate,
bn_param=(args.bn_momentum, args.bn_eps))
logger.info('SearchMobileNet model create done') logger.info('SearchMobileNet model create done')
model.init_model() model.init_model()
logger.info('SearchMobileNet model init done') logger.info('SearchMobileNet model init done')
...@@ -125,28 +135,7 @@ if __name__ == "__main__": ...@@ -125,28 +135,7 @@ if __name__ == "__main__":
trainer.fit() trainer.fit()
print('Final architecture:', trainer.export()) print('Final architecture:', trainer.export())
json.dump(trainer.export(), open('checkpoint.json', 'w')) json.dump(trainer.export(), open('checkpoint.json', 'w'))
elif args.train_mode == 'search_v1':
# this is architecture search
logger.info('Creating ProxylessNasTrainer...')
trainer = ProxylessNasTrainer(model,
model_optim=optimizer,
train_loader=data_provider.train,
valid_loader=data_provider.valid,
device=device,
warmup=args.warmup,
ckpt_path=args.checkpoint_path,
arch_path=args.arch_path)
logger.info('Start to train with ProxylessNasTrainer...')
trainer.train()
logger.info('Training done')
trainer.export(args.arch_path)
logger.info('Best architecture exported in %s', args.arch_path)
elif args.train_mode == 'retrain': elif args.train_mode == 'retrain':
# this is retrain # this is retrain
from nni.nas.pytorch.fixed import apply_fixed_architecture
assert os.path.isfile(args.exported_arch_path), \
"exported_arch_path {} should be a file.".format(args.exported_arch_path)
apply_fixed_architecture(model, args.exported_arch_path)
trainer = Retrain(model, optimizer, device, data_provider, n_epochs=300) trainer = Retrain(model, optimizer, device, data_provider, n_epochs=300)
trainer.run() trainer.run()
...@@ -50,6 +50,9 @@ def _main(port): ...@@ -50,6 +50,9 @@ def _main(port):
transforms.ToTensor(), transforms.ToTensor(),
transforms.Normalize([0.49139968, 0.48215827, 0.44653124], [0.24703233, 0.24348505, 0.26158768]) transforms.Normalize([0.49139968, 0.48215827, 0.44653124], [0.24703233, 0.24348505, 0.26158768])
] ]
# FIXME
# CIFAR10 is used here temporarily.
# Actually we should load weight from supernet and evaluate on imagenet.
train_dataset = serialize(CIFAR10, 'data', train=True, download=True, transform=transforms.Compose(transf + normalize)) train_dataset = serialize(CIFAR10, 'data', train=True, download=True, transform=transforms.Compose(transf + normalize))
test_dataset = serialize(CIFAR10, 'data', train=False, transform=transforms.Compose(normalize)) test_dataset = serialize(CIFAR10, 'data', train=False, transform=transforms.Compose(normalize))
...@@ -57,7 +60,8 @@ def _main(port): ...@@ -57,7 +60,8 @@ def _main(port):
val_dataloaders=pl.DataLoader(test_dataset, batch_size=64), val_dataloaders=pl.DataLoader(test_dataset, batch_size=64),
max_epochs=2, gpus=1) max_epochs=2, gpus=1)
simple_strategy = strategy.RegularizedEvolution(model_filter=LatencyFilter(threshold=100, predictor=base_predictor), population_size=2, cycles=2) simple_strategy = strategy.RegularizedEvolution(model_filter=LatencyFilter(threshold=100, predictor=base_predictor),
sample_size=1, population_size=2, cycles=2)
exp = RetiariiExperiment(base_model, trainer, strategy=simple_strategy) exp = RetiariiExperiment(base_model, trainer, strategy=simple_strategy)
exp_config = RetiariiExeConfig('local') exp_config = RetiariiExeConfig('local')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment