Unverified Commit bcda469f authored by Yuge Zhang's avatar Yuge Zhang Committed by GitHub
Browse files

Update on NAS examples (#3240)

parent 3423117d
...@@ -14,7 +14,7 @@ import torch.nn as nn ...@@ -14,7 +14,7 @@ import torch.nn as nn
from genotypes import Genotype from genotypes import Genotype
from ops import PRIMITIVES from ops import PRIMITIVES
from nni.nas.pytorch.cdarts.utils import * from nni.algorithms.nas.pytorch.cdarts.utils import *
def get_logger(file_path): def get_logger(file_path):
......
...@@ -7,7 +7,7 @@ from tensorflow.keras.optimizers import SGD ...@@ -7,7 +7,7 @@ from tensorflow.keras.optimizers import SGD
import nni import nni
from nni.nas.tensorflow.mutables import LayerChoice, InputChoice from nni.nas.tensorflow.mutables import LayerChoice, InputChoice
from nni.nas.tensorflow.classic_nas import get_and_apply_next_architecture from nni.algorithms.nas.tensorflow.classic_nas import get_and_apply_next_architecture
tf.get_logger().setLevel('ERROR') tf.get_logger().setLevel('ERROR')
......
...@@ -5,7 +5,7 @@ ...@@ -5,7 +5,7 @@
from tensorflow.keras.losses import Reduction, SparseCategoricalCrossentropy from tensorflow.keras.losses import Reduction, SparseCategoricalCrossentropy
from tensorflow.keras.optimizers import SGD from tensorflow.keras.optimizers import SGD
from nni.nas.tensorflow import enas from nni.algorithms.nas.tensorflow import enas
import datasets import datasets
from macro import GeneralNetwork from macro import GeneralNetwork
......
...@@ -8,7 +8,7 @@ from tensorflow.keras.losses import Reduction, SparseCategoricalCrossentropy ...@@ -8,7 +8,7 @@ from tensorflow.keras.losses import Reduction, SparseCategoricalCrossentropy
from tensorflow.keras.optimizers import SGD from tensorflow.keras.optimizers import SGD
from nni.nas.tensorflow.mutables import LayerChoice, InputChoice from nni.nas.tensorflow.mutables import LayerChoice, InputChoice
from nni.nas.tensorflow.enas import EnasTrainer from nni.algorithms.nas.tensorflow.enas import EnasTrainer
class Net(Model): class Net(Model):
...@@ -55,7 +55,7 @@ class Net(Model): ...@@ -55,7 +55,7 @@ class Net(Model):
def accuracy(truth, logits): def accuracy(truth, logits):
truth = tf.reshape(truth, -1) truth = tf.reshape(truth, (-1, ))
predicted = tf.cast(tf.math.argmax(logits, axis=1), truth.dtype) predicted = tf.cast(tf.math.argmax(logits, axis=1), truth.dtype)
equal = tf.cast(predicted == truth, tf.int32) equal = tf.cast(predicted == truth, tf.int32)
return tf.math.reduce_sum(equal).numpy() / equal.shape[0] return tf.math.reduce_sum(equal).numpy() / equal.shape[0]
......
import json
import logging import logging
import os import os
import sys import sys
...@@ -102,6 +103,7 @@ if __name__ == "__main__": ...@@ -102,6 +103,7 @@ if __name__ == "__main__":
log_frequency=10) log_frequency=10)
trainer.fit() trainer.fit()
print('Final architecture:', trainer.export()) print('Final architecture:', trainer.export())
json.dump(trainer.export(), open('checkpoint.json', 'w'))
elif args.train_mode == 'search_v1': elif args.train_mode == 'search_v1':
# this is architecture search # this is architecture search
logger.info('Creating ProxylessNasTrainer...') logger.info('Creating ProxylessNasTrainer...')
......
...@@ -85,7 +85,7 @@ def accuracy(output, target, topk=(1,)): ...@@ -85,7 +85,7 @@ def accuracy(output, target, topk=(1,)):
res = dict() res = dict()
for k in topk: for k in topk:
correct_k = correct[:k].view(-1).float().sum(0) correct_k = correct[:k].reshape(-1).float().sum(0)
res["acc{}".format(k)] = correct_k.mul_(1.0 / batch_size).item() res["acc{}".format(k)] = correct_k.mul_(1.0 / batch_size).item()
return res return res
......
...@@ -10,7 +10,7 @@ import torch.nn as nn ...@@ -10,7 +10,7 @@ import torch.nn as nn
import datasets import datasets
from nni.nas.pytorch.callbacks import ArchitectureCheckpoint, LRSchedulerCallback from nni.nas.pytorch.callbacks import ArchitectureCheckpoint, LRSchedulerCallback
from nni.nas.pytorch.darts import DartsTrainer from nni.algorithms.nas.pytorch.darts import DartsTrainer
from utils import accuracy from utils import accuracy
from nni.nas.pytorch.search_space_zoo import DartsCell from nni.nas.pytorch.search_space_zoo import DartsCell
......
...@@ -8,7 +8,7 @@ from torchvision import transforms ...@@ -8,7 +8,7 @@ from torchvision import transforms
from torchvision.datasets import CIFAR10 from torchvision.datasets import CIFAR10
from nni.nas.pytorch import mutables from nni.nas.pytorch import mutables
from nni.nas.pytorch import enas from nni.algorithms.nas.pytorch import enas
from utils import accuracy, reward_accuracy from utils import accuracy, reward_accuracy
from nni.nas.pytorch.callbacks import (ArchitectureCheckpoint, from nni.nas.pytorch.callbacks import (ArchitectureCheckpoint,
LRSchedulerCallback) LRSchedulerCallback)
......
...@@ -7,7 +7,7 @@ from argparse import ArgumentParser ...@@ -7,7 +7,7 @@ from argparse import ArgumentParser
from torchvision import transforms from torchvision import transforms
from torchvision.datasets import CIFAR10 from torchvision.datasets import CIFAR10
from nni.nas.pytorch import enas from nni.algorithms.nas.pytorch import enas
from utils import accuracy, reward_accuracy from utils import accuracy, reward_accuracy
from nni.nas.pytorch.callbacks import (ArchitectureCheckpoint, from nni.nas.pytorch.callbacks import (ArchitectureCheckpoint,
LRSchedulerCallback) LRSchedulerCallback)
......
...@@ -10,13 +10,13 @@ import torch ...@@ -10,13 +10,13 @@ import torch
import torch.nn as nn import torch.nn as nn
import torch.optim as optim import torch.optim as optim
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from nni.nas.pytorch import enas from nni.algorithms.nas.pytorch.darts import DartsTrainer
from nni.algorithms.nas.pytorch import enas
from nni.nas.pytorch.utils import AverageMeterGroup from nni.nas.pytorch.utils import AverageMeterGroup
from nni.nas.pytorch.nasbench201 import NASBench201Cell from nni.nas.pytorch.nasbench201 import NASBench201Cell
from nni.nas.pytorch.fixed import apply_fixed_architecture from nni.nas.pytorch.fixed import apply_fixed_architecture
from nni.nas.benchmarks.nasbench201 import query_nb201_trial_stats from nni.nas.benchmarks.nasbench201 import query_nb201_trial_stats
from nni.nas.pytorch.callbacks import ArchitectureCheckpoint, LRSchedulerCallback from nni.nas.pytorch.callbacks import ArchitectureCheckpoint, LRSchedulerCallback
from nni.nas.pytorch.darts import DartsTrainer
from utils import accuracy, reward_accuracy from utils import accuracy, reward_accuracy
import datasets import datasets
......
...@@ -36,6 +36,6 @@ def accuracy(output, target, topk=(1, 5)): ...@@ -36,6 +36,6 @@ def accuracy(output, target, topk=(1, 5)):
res = dict() res = dict()
for k in topk: for k in topk:
correct_k = correct[:k].view(-1).float().sum(0) correct_k = correct[:k].reshape(-1).float().sum(0)
res["acc{}".format(k)] = correct_k.mul_(1.0 / batch_size).item() res["acc{}".format(k)] = correct_k.mul_(1.0 / batch_size).item()
return res return res
...@@ -2,7 +2,6 @@ ...@@ -2,7 +2,6 @@
# Licensed under the MIT license. # Licensed under the MIT license.
export PYTHONPATH="$(pwd)" export PYTHONPATH="$(pwd)"
export CUDA_VISIBLE_DEVICES=0
python3 -u retrain.py \ python3 -u retrain.py \
--train_ratio=1.0 \ --train_ratio=1.0 \
......
...@@ -14,7 +14,7 @@ logger = logging.getLogger("nni.textnas") ...@@ -14,7 +14,7 @@ logger = logging.getLogger("nni.textnas")
def get_length(mask): def get_length(mask):
length = torch.sum(mask, 1) length = torch.sum(mask, 1)
length = length.long() length = length.long().cpu()
return length return length
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment