Unverified Commit d2605ca2 authored by colorjam's avatar colorjam Committed by GitHub
Browse files

Update NAS examples & doc (#3074)



* update title level

* Update examples & reproduction results of darts
Co-authored-by: default avatarcolorjam <im.cqyan@google.com>
parent a9b87c9a
......@@ -15,7 +15,7 @@ The above-mentioned example is meant to reproduce the results in the paper, we d
| | In paper | Reproduction |
| ---------------------- | ------------- | ------------ |
| First order (CIFAR10) | 3.00 +/- 0.14 | 2.78 |
| Second order (CIFAR10) | 2.76 +/- 0.09 | 2.89 |
| Second order (CIFAR10) | 2.76 +/- 0.09 | 2.80 |
## Examples
......
......@@ -14,7 +14,7 @@ import utils
from config import SearchConfig
from datasets.cifar import get_search_datasets
from model import Model
from nni.nas.pytorch.cdarts import CdartsTrainer
from nni.algorithms.nas.pytorch.cdarts import CdartsTrainer
if __name__ == "__main__":
config = SearchConfig()
......
......@@ -11,7 +11,7 @@ import torch.nn as nn
import datasets
from model import CNN
from nni.nas.pytorch.callbacks import ArchitectureCheckpoint, LRSchedulerCallback
from nni.nas.pytorch.darts import DartsTrainer
from nni.algorithms.nas.pytorch.darts import DartsTrainer
from utils import accuracy
logger = logging.getLogger('nni')
......
......@@ -11,7 +11,7 @@ import torch.nn as nn
import datasets
from macro import GeneralNetwork
from micro import MicroNetwork
from nni.nas.pytorch import enas
from nni.algorithms.nas.pytorch import enas
from nni.nas.pytorch.callbacks import (ArchitectureCheckpoint,
LRSchedulerCallback)
from utils import accuracy, reward_accuracy
......
......@@ -10,7 +10,7 @@ import torch
import torch.nn as nn
from nni.nas.pytorch.callbacks import ArchitectureCheckpoint
from nni.nas.pytorch.pdarts import PdartsTrainer
from nni.algorithms.nas.pytorch.pdarts import PdartsTrainer
# prevent it to be reordered.
if True:
......
......@@ -7,7 +7,7 @@ import datasets
from putils import get_parameters
from model import SearchMobileNet
from nni.nas.pytorch.proxylessnas import ProxylessNasTrainer
from nni.algorithms.nas.pytorch.proxylessnas import ProxylessNasTrainer
from retrain import Retrain
logger = logging.getLogger('nni_proxylessnas')
......
......@@ -10,7 +10,7 @@ import torch
import torch.nn as nn
from nni.nas.pytorch.callbacks import LRSchedulerCallback
from nni.nas.pytorch.callbacks import ModelCheckpoint
from nni.nas.pytorch.spos import SPOSSupernetTrainingMutator, SPOSSupernetTrainer
from nni.algorithms.nas.pytorch.spos import SPOSSupernetTrainingMutator, SPOSSupernetTrainer
from dataloader import get_imagenet_iter_dali
from network import ShuffleNetV2OneShot, load_and_parse_state_dict
......
......@@ -11,7 +11,7 @@ import nni
import numpy as np
import torch
import torch.nn as nn
from nni.nas.pytorch.classic_nas import get_and_apply_next_architecture
from nni.algorithms.nas.pytorch.classic_nas import get_and_apply_next_architecture
from nni.nas.pytorch.utils import AverageMeterGroup
from dataloader import get_imagenet_iter_dali
......
......@@ -11,7 +11,7 @@ import numpy as np
import torch
import torch.nn as nn
from nni.nas.pytorch.enas import EnasMutator, EnasTrainer
from nni.algorithms.nas.pytorch.enas import EnasMutator, EnasTrainer
from nni.nas.pytorch.callbacks import LRSchedulerCallback
from dataloader import read_data_sst
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment