Unverified Commit bcda469f authored by Yuge Zhang's avatar Yuge Zhang Committed by GitHub
Browse files

Update on NAS examples (#3240)

parent 3423117d
......@@ -14,7 +14,7 @@ import torch.nn as nn
from genotypes import Genotype
from ops import PRIMITIVES
from nni.nas.pytorch.cdarts.utils import *
from nni.algorithms.nas.pytorch.cdarts.utils import *
def get_logger(file_path):
......
......@@ -7,7 +7,7 @@ from tensorflow.keras.optimizers import SGD
import nni
from nni.nas.tensorflow.mutables import LayerChoice, InputChoice
from nni.nas.tensorflow.classic_nas import get_and_apply_next_architecture
from nni.algorithms.nas.tensorflow.classic_nas import get_and_apply_next_architecture
tf.get_logger().setLevel('ERROR')
......
......@@ -5,7 +5,7 @@
from tensorflow.keras.losses import Reduction, SparseCategoricalCrossentropy
from tensorflow.keras.optimizers import SGD
from nni.nas.tensorflow import enas
from nni.algorithms.nas.tensorflow import enas
import datasets
from macro import GeneralNetwork
......
......@@ -8,7 +8,7 @@ from tensorflow.keras.losses import Reduction, SparseCategoricalCrossentropy
from tensorflow.keras.optimizers import SGD
from nni.nas.tensorflow.mutables import LayerChoice, InputChoice
from nni.nas.tensorflow.enas import EnasTrainer
from nni.algorithms.nas.tensorflow.enas import EnasTrainer
class Net(Model):
......@@ -55,7 +55,7 @@ class Net(Model):
def accuracy(truth, logits):
truth = tf.reshape(truth, -1)
truth = tf.reshape(truth, (-1, ))
predicted = tf.cast(tf.math.argmax(logits, axis=1), truth.dtype)
equal = tf.cast(predicted == truth, tf.int32)
return tf.math.reduce_sum(equal).numpy() / equal.shape[0]
......
import json
import logging
import os
import sys
......@@ -102,6 +103,7 @@ if __name__ == "__main__":
log_frequency=10)
trainer.fit()
print('Final architecture:', trainer.export())
json.dump(trainer.export(), open('checkpoint.json', 'w'))
elif args.train_mode == 'search_v1':
# this is architecture search
logger.info('Creating ProxylessNasTrainer...')
......
......@@ -85,7 +85,7 @@ def accuracy(output, target, topk=(1,)):
res = dict()
for k in topk:
correct_k = correct[:k].view(-1).float().sum(0)
correct_k = correct[:k].reshape(-1).float().sum(0)
res["acc{}".format(k)] = correct_k.mul_(1.0 / batch_size).item()
return res
......
......@@ -10,7 +10,7 @@ import torch.nn as nn
import datasets
from nni.nas.pytorch.callbacks import ArchitectureCheckpoint, LRSchedulerCallback
from nni.nas.pytorch.darts import DartsTrainer
from nni.algorithms.nas.pytorch.darts import DartsTrainer
from utils import accuracy
from nni.nas.pytorch.search_space_zoo import DartsCell
......
......@@ -8,7 +8,7 @@ from torchvision import transforms
from torchvision.datasets import CIFAR10
from nni.nas.pytorch import mutables
from nni.nas.pytorch import enas
from nni.algorithms.nas.pytorch import enas
from utils import accuracy, reward_accuracy
from nni.nas.pytorch.callbacks import (ArchitectureCheckpoint,
LRSchedulerCallback)
......
......@@ -7,7 +7,7 @@ from argparse import ArgumentParser
from torchvision import transforms
from torchvision.datasets import CIFAR10
from nni.nas.pytorch import enas
from nni.algorithms.nas.pytorch import enas
from utils import accuracy, reward_accuracy
from nni.nas.pytorch.callbacks import (ArchitectureCheckpoint,
LRSchedulerCallback)
......
......@@ -10,13 +10,13 @@ import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from nni.nas.pytorch import enas
from nni.algorithms.nas.pytorch.darts import DartsTrainer
from nni.algorithms.nas.pytorch import enas
from nni.nas.pytorch.utils import AverageMeterGroup
from nni.nas.pytorch.nasbench201 import NASBench201Cell
from nni.nas.pytorch.fixed import apply_fixed_architecture
from nni.nas.benchmarks.nasbench201 import query_nb201_trial_stats
from nni.nas.pytorch.callbacks import ArchitectureCheckpoint, LRSchedulerCallback
from nni.nas.pytorch.darts import DartsTrainer
from utils import accuracy, reward_accuracy
import datasets
......
......@@ -36,6 +36,6 @@ def accuracy(output, target, topk=(1, 5)):
res = dict()
for k in topk:
correct_k = correct[:k].view(-1).float().sum(0)
correct_k = correct[:k].reshape(-1).float().sum(0)
res["acc{}".format(k)] = correct_k.mul_(1.0 / batch_size).item()
return res
......@@ -2,7 +2,6 @@
# Licensed under the MIT license.
export PYTHONPATH="$(pwd)"
export CUDA_VISIBLE_DEVICES=0
python3 -u retrain.py \
--train_ratio=1.0 \
......
......@@ -14,7 +14,7 @@ logger = logging.getLogger("nni.textnas")
def get_length(mask):
length = torch.sum(mask, 1)
length = length.long()
length = length.long().cpu()
return length
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment