Unverified Commit d2605ca2 authored by colorjam's avatar colorjam Committed by GitHub
Browse files

Update NAS examples & doc (#3074)



* update title level

* Update examples & reproduction results of darts
Co-authored-by: default avatarcolorjam <im.cqyan@google.com>
parent a9b87c9a
...@@ -15,7 +15,7 @@ The above-mentioned example is meant to reproduce the results in the paper, we d ...@@ -15,7 +15,7 @@ The above-mentioned example is meant to reproduce the results in the paper, we d
| | In paper | Reproduction | | | In paper | Reproduction |
| ---------------------- | ------------- | ------------ | | ---------------------- | ------------- | ------------ |
| First order (CIFAR10) | 3.00 +/- 0.14 | 2.78 | | First order (CIFAR10) | 3.00 +/- 0.14 | 2.78 |
| Second order (CIFAR10) | 2.76 +/- 0.09 | 2.89 | | Second order (CIFAR10) | 2.76 +/- 0.09 | 2.80 |
## Examples ## Examples
......
...@@ -14,7 +14,7 @@ import utils ...@@ -14,7 +14,7 @@ import utils
from config import SearchConfig from config import SearchConfig
from datasets.cifar import get_search_datasets from datasets.cifar import get_search_datasets
from model import Model from model import Model
from nni.nas.pytorch.cdarts import CdartsTrainer from nni.algorithms.nas.pytorch.cdarts import CdartsTrainer
if __name__ == "__main__": if __name__ == "__main__":
config = SearchConfig() config = SearchConfig()
......
...@@ -11,7 +11,7 @@ import torch.nn as nn ...@@ -11,7 +11,7 @@ import torch.nn as nn
import datasets import datasets
from model import CNN from model import CNN
from nni.nas.pytorch.callbacks import ArchitectureCheckpoint, LRSchedulerCallback from nni.nas.pytorch.callbacks import ArchitectureCheckpoint, LRSchedulerCallback
from nni.nas.pytorch.darts import DartsTrainer from nni.algorithms.nas.pytorch.darts import DartsTrainer
from utils import accuracy from utils import accuracy
logger = logging.getLogger('nni') logger = logging.getLogger('nni')
......
...@@ -11,7 +11,7 @@ import torch.nn as nn ...@@ -11,7 +11,7 @@ import torch.nn as nn
import datasets import datasets
from macro import GeneralNetwork from macro import GeneralNetwork
from micro import MicroNetwork from micro import MicroNetwork
from nni.nas.pytorch import enas from nni.algorithms.nas.pytorch import enas
from nni.nas.pytorch.callbacks import (ArchitectureCheckpoint, from nni.nas.pytorch.callbacks import (ArchitectureCheckpoint,
LRSchedulerCallback) LRSchedulerCallback)
from utils import accuracy, reward_accuracy from utils import accuracy, reward_accuracy
......
...@@ -10,7 +10,7 @@ import torch ...@@ -10,7 +10,7 @@ import torch
import torch.nn as nn import torch.nn as nn
from nni.nas.pytorch.callbacks import ArchitectureCheckpoint from nni.nas.pytorch.callbacks import ArchitectureCheckpoint
from nni.nas.pytorch.pdarts import PdartsTrainer from nni.algorithms.nas.pytorch.pdarts import PdartsTrainer
# prevent it to be reordered. # prevent it to be reordered.
if True: if True:
......
...@@ -7,7 +7,7 @@ import datasets ...@@ -7,7 +7,7 @@ import datasets
from putils import get_parameters from putils import get_parameters
from model import SearchMobileNet from model import SearchMobileNet
from nni.nas.pytorch.proxylessnas import ProxylessNasTrainer from nni.algorithms.nas.pytorch.proxylessnas import ProxylessNasTrainer
from retrain import Retrain from retrain import Retrain
logger = logging.getLogger('nni_proxylessnas') logger = logging.getLogger('nni_proxylessnas')
......
...@@ -10,7 +10,7 @@ import torch ...@@ -10,7 +10,7 @@ import torch
import torch.nn as nn import torch.nn as nn
from nni.nas.pytorch.callbacks import LRSchedulerCallback from nni.nas.pytorch.callbacks import LRSchedulerCallback
from nni.nas.pytorch.callbacks import ModelCheckpoint from nni.nas.pytorch.callbacks import ModelCheckpoint
from nni.nas.pytorch.spos import SPOSSupernetTrainingMutator, SPOSSupernetTrainer from nni.algorithms.nas.pytorch.spos import SPOSSupernetTrainingMutator, SPOSSupernetTrainer
from dataloader import get_imagenet_iter_dali from dataloader import get_imagenet_iter_dali
from network import ShuffleNetV2OneShot, load_and_parse_state_dict from network import ShuffleNetV2OneShot, load_and_parse_state_dict
......
...@@ -11,7 +11,7 @@ import nni ...@@ -11,7 +11,7 @@ import nni
import numpy as np import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
from nni.nas.pytorch.classic_nas import get_and_apply_next_architecture from nni.algorithms.nas.pytorch.classic_nas import get_and_apply_next_architecture
from nni.nas.pytorch.utils import AverageMeterGroup from nni.nas.pytorch.utils import AverageMeterGroup
from dataloader import get_imagenet_iter_dali from dataloader import get_imagenet_iter_dali
......
...@@ -11,7 +11,7 @@ import numpy as np ...@@ -11,7 +11,7 @@ import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
from nni.nas.pytorch.enas import EnasMutator, EnasTrainer from nni.algorithms.nas.pytorch.enas import EnasMutator, EnasTrainer
from nni.nas.pytorch.callbacks import LRSchedulerCallback from nni.nas.pytorch.callbacks import LRSchedulerCallback
from dataloader import read_data_sst from dataloader import read_data_sst
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment