Unverified commit 7b653a92, authored by Chi Song, committed by GitHub
Browse files

[NAS] simplify log, and fix a bug on pdarts exporting (#1777)

parent 1398540e
......@@ -51,8 +51,8 @@ cd examples/nas/pdarts
python3 search.py
# train the best architecture, it's the same progress as darts.
cd examples/nas/darts
python3 retrain.py --arc-checkpoint ./checkpoints/epoch_2.json
cd ../darts
python3 retrain.py --arc-checkpoint ../pdarts/checkpoints/epoch_2.json
```
## Use NNI API
......
......@@ -4,24 +4,16 @@ from argparse import ArgumentParser
import torch
import torch.nn as nn
from nni.nas.pytorch.fixed import apply_fixed_architecture
from nni.nas.pytorch.utils import AverageMeter
from torch.utils.tensorboard import SummaryWriter
import datasets
import utils
from model import CNN
from nni.nas.pytorch.fixed import apply_fixed_architecture
from nni.nas.pytorch.utils import AverageMeter
logger = logging.getLogger()
fmt = '[%(asctime)s] %(levelname)s (%(name)s/%(threadName)s) %(message)s'
logging.Formatter.converter = time.localtime
formatter = logging.Formatter(fmt, '%m/%d/%Y, %I:%M:%S %p')
logger = logging.getLogger('nni')
std_out_info = logging.StreamHandler()
std_out_info.setFormatter(formatter)
logger.setLevel(logging.INFO)
logger.addHandler(std_out_info)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
writer = SummaryWriter()
......
......@@ -11,16 +11,7 @@ from nni.nas.pytorch.callbacks import ArchitectureCheckpoint, LRSchedulerCallbac
from nni.nas.pytorch.darts import DartsTrainer
from utils import accuracy
logger = logging.getLogger()
fmt = '[%(asctime)s] %(levelname)s (%(name)s/%(threadName)s) %(message)s'
logging.Formatter.converter = time.localtime
formatter = logging.Formatter(fmt, '%m/%d/%Y, %I:%M:%S %p')
std_out_info = logging.StreamHandler()
std_out_info.setFormatter(formatter)
logger.setLevel(logging.INFO)
logger.addHandler(std_out_info)
logger = logging.getLogger('nni')
if __name__ == "__main__":
parser = ArgumentParser("darts")
......
......@@ -9,19 +9,12 @@ import datasets
from macro import GeneralNetwork
from micro import MicroNetwork
from nni.nas.pytorch import enas
from nni.nas.pytorch.callbacks import LRSchedulerCallback, ArchitectureCheckpoint
from nni.nas.pytorch.callbacks import (ArchitectureCheckpoint,
LRSchedulerCallback)
from utils import accuracy, reward_accuracy
logger = logging.getLogger()
logger = logging.getLogger('nni')
fmt = '[%(asctime)s] %(levelname)s (%(name)s/%(threadName)s) %(message)s'
logging.Formatter.converter = time.localtime
formatter = logging.Formatter(fmt, '%m/%d/%Y, %I:%M:%S %p')
std_out_info = logging.StreamHandler()
std_out_info.setFormatter(formatter)
logger.setLevel(logging.INFO)
logger.addHandler(std_out_info)
if __name__ == "__main__":
parser = ArgumentParser("enas")
......
......@@ -19,16 +19,9 @@ if True:
from model import CNN
import datasets
logger = logging.getLogger()
fmt = '[%(asctime)s] %(levelname)s (%(name)s/%(threadName)s) %(message)s'
logging.Formatter.converter = time.localtime
formatter = logging.Formatter(fmt, '%m/%d/%Y, %I:%M:%S %p')
logger = logging.getLogger('nni')
std_out_info = logging.StreamHandler()
std_out_info.setFormatter(formatter)
logger.setLevel(logging.INFO)
logger.addHandler(std_out_info)
if __name__ == "__main__":
parser = ArgumentParser("pdarts")
......
......@@ -5,7 +5,6 @@ import torch.nn as nn
from nni.nas.pytorch.utils import global_mutable_counting
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
class Mutable(nn.Module):
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import json
import logging
from nni.nas.pytorch.callbacks import LRSchedulerCallback
from nni.nas.pytorch.darts import DartsTrainer
from nni.nas.pytorch.trainer import BaseTrainer
from nni.nas.pytorch.trainer import BaseTrainer, TorchTensorEncoder
from .mutator import PdartsMutator
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
class PdartsTrainer(BaseTrainer):
......@@ -55,7 +55,7 @@ class PdartsTrainer(BaseTrainer):
self.trainer = DartsTrainer(model, mutator=self.mutator, loss=criterion, optimizer=optim,
callbacks=darts_callbacks, **self.darts_parameters)
logger.info("start pdarts training %s...", epoch)
logger.info("start pdarts training epoch %s...", epoch)
self.trainer.train()
......@@ -67,5 +67,10 @@ class PdartsTrainer(BaseTrainer):
def validate(self):
    """Run validation by delegating to the wrapped model's own ``validate`` routine.

    NOTE(review): assumes ``self.model`` exposes a ``validate()`` method — the
    attribute is assigned outside this visible fragment; confirm against the
    full class definition.
    """
    self.model.validate()
def export(self, file):
    """Serialize the architecture currently selected by the mutator to *file*.

    The export is written as pretty-printed JSON with sorted keys;
    ``TorchTensorEncoder`` handles any torch tensors in the exported dict.
    """
    exported_arch = self.mutator.export()
    with open(file, "w") as out_fp:
        json.dump(exported_arch, out_fp,
                  indent=2, sort_keys=True, cls=TorchTensorEncoder)
def checkpoint(self):
    """Checkpointing is not supported by this trainer yet; always raises.

    Raises:
        NotImplementedError: unconditionally.
    """
    raise NotImplementedError("Not implemented yet")
......@@ -7,7 +7,6 @@ import torch
from .base_trainer import BaseTrainer
_logger = logging.getLogger(__name__)
_logger.setLevel(logging.INFO)
class TorchTensorEncoder(json.JSONEncoder):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment