Unverified Commit 7b653a92 authored by Chi Song's avatar Chi Song Committed by GitHub
Browse files

[NAS] simplify log, and fix a bug on pdarts exporting (#1777)

parent 1398540e
...@@ -51,8 +51,8 @@ cd examples/nas/pdarts ...@@ -51,8 +51,8 @@ cd examples/nas/pdarts
python3 search.py python3 search.py
# train the best architecture, it's the same progress as darts. # train the best architecture, it's the same progress as darts.
cd examples/nas/darts cd ../darts
python3 retrain.py --arc-checkpoint ./checkpoints/epoch_2.json python3 retrain.py --arc-checkpoint ../pdarts/checkpoints/epoch_2.json
``` ```
## Use NNI API ## Use NNI API
......
...@@ -4,24 +4,16 @@ from argparse import ArgumentParser ...@@ -4,24 +4,16 @@ from argparse import ArgumentParser
import torch import torch
import torch.nn as nn import torch.nn as nn
from nni.nas.pytorch.fixed import apply_fixed_architecture
from nni.nas.pytorch.utils import AverageMeter
from torch.utils.tensorboard import SummaryWriter from torch.utils.tensorboard import SummaryWriter
import datasets import datasets
import utils import utils
from model import CNN from model import CNN
from nni.nas.pytorch.fixed import apply_fixed_architecture
from nni.nas.pytorch.utils import AverageMeter
logger = logging.getLogger() logger = logging.getLogger('nni')
fmt = '[%(asctime)s] %(levelname)s (%(name)s/%(threadName)s) %(message)s'
logging.Formatter.converter = time.localtime
formatter = logging.Formatter(fmt, '%m/%d/%Y, %I:%M:%S %p')
std_out_info = logging.StreamHandler()
std_out_info.setFormatter(formatter)
logger.setLevel(logging.INFO)
logger.addHandler(std_out_info)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
writer = SummaryWriter() writer = SummaryWriter()
......
...@@ -11,16 +11,7 @@ from nni.nas.pytorch.callbacks import ArchitectureCheckpoint, LRSchedulerCallbac ...@@ -11,16 +11,7 @@ from nni.nas.pytorch.callbacks import ArchitectureCheckpoint, LRSchedulerCallbac
from nni.nas.pytorch.darts import DartsTrainer from nni.nas.pytorch.darts import DartsTrainer
from utils import accuracy from utils import accuracy
logger = logging.getLogger() logger = logging.getLogger('nni')
fmt = '[%(asctime)s] %(levelname)s (%(name)s/%(threadName)s) %(message)s'
logging.Formatter.converter = time.localtime
formatter = logging.Formatter(fmt, '%m/%d/%Y, %I:%M:%S %p')
std_out_info = logging.StreamHandler()
std_out_info.setFormatter(formatter)
logger.setLevel(logging.INFO)
logger.addHandler(std_out_info)
if __name__ == "__main__": if __name__ == "__main__":
parser = ArgumentParser("darts") parser = ArgumentParser("darts")
......
...@@ -9,19 +9,12 @@ import datasets ...@@ -9,19 +9,12 @@ import datasets
from macro import GeneralNetwork from macro import GeneralNetwork
from micro import MicroNetwork from micro import MicroNetwork
from nni.nas.pytorch import enas from nni.nas.pytorch import enas
from nni.nas.pytorch.callbacks import LRSchedulerCallback, ArchitectureCheckpoint from nni.nas.pytorch.callbacks import (ArchitectureCheckpoint,
LRSchedulerCallback)
from utils import accuracy, reward_accuracy from utils import accuracy, reward_accuracy
logger = logging.getLogger() logger = logging.getLogger('nni')
fmt = '[%(asctime)s] %(levelname)s (%(name)s/%(threadName)s) %(message)s'
logging.Formatter.converter = time.localtime
formatter = logging.Formatter(fmt, '%m/%d/%Y, %I:%M:%S %p')
std_out_info = logging.StreamHandler()
std_out_info.setFormatter(formatter)
logger.setLevel(logging.INFO)
logger.addHandler(std_out_info)
if __name__ == "__main__": if __name__ == "__main__":
parser = ArgumentParser("enas") parser = ArgumentParser("enas")
......
...@@ -19,16 +19,9 @@ if True: ...@@ -19,16 +19,9 @@ if True:
from model import CNN from model import CNN
import datasets import datasets
logger = logging.getLogger()
fmt = '[%(asctime)s] %(levelname)s (%(name)s/%(threadName)s) %(message)s' logger = logging.getLogger('nni')
logging.Formatter.converter = time.localtime
formatter = logging.Formatter(fmt, '%m/%d/%Y, %I:%M:%S %p')
std_out_info = logging.StreamHandler()
std_out_info.setFormatter(formatter)
logger.setLevel(logging.INFO)
logger.addHandler(std_out_info)
if __name__ == "__main__": if __name__ == "__main__":
parser = ArgumentParser("pdarts") parser = ArgumentParser("pdarts")
......
...@@ -5,7 +5,6 @@ import torch.nn as nn ...@@ -5,7 +5,6 @@ import torch.nn as nn
from nni.nas.pytorch.utils import global_mutable_counting from nni.nas.pytorch.utils import global_mutable_counting
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
class Mutable(nn.Module): class Mutable(nn.Module):
......
# Copyright (c) Microsoft Corporation. # Copyright (c) Microsoft Corporation.
# Licensed under the MIT license. # Licensed under the MIT license.
import json
import logging import logging
from nni.nas.pytorch.callbacks import LRSchedulerCallback from nni.nas.pytorch.callbacks import LRSchedulerCallback
from nni.nas.pytorch.darts import DartsTrainer from nni.nas.pytorch.darts import DartsTrainer
from nni.nas.pytorch.trainer import BaseTrainer from nni.nas.pytorch.trainer import BaseTrainer, TorchTensorEncoder
from .mutator import PdartsMutator from .mutator import PdartsMutator
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
class PdartsTrainer(BaseTrainer): class PdartsTrainer(BaseTrainer):
...@@ -55,7 +55,7 @@ class PdartsTrainer(BaseTrainer): ...@@ -55,7 +55,7 @@ class PdartsTrainer(BaseTrainer):
self.trainer = DartsTrainer(model, mutator=self.mutator, loss=criterion, optimizer=optim, self.trainer = DartsTrainer(model, mutator=self.mutator, loss=criterion, optimizer=optim,
callbacks=darts_callbacks, **self.darts_parameters) callbacks=darts_callbacks, **self.darts_parameters)
logger.info("start pdarts training %s...", epoch) logger.info("start pdarts training epoch %s...", epoch)
self.trainer.train() self.trainer.train()
...@@ -67,5 +67,10 @@ class PdartsTrainer(BaseTrainer): ...@@ -67,5 +67,10 @@ class PdartsTrainer(BaseTrainer):
def validate(self): def validate(self):
self.model.validate() self.model.validate()
def export(self, file):
mutator_export = self.mutator.export()
with open(file, "w") as f:
json.dump(mutator_export, f, indent=2, sort_keys=True, cls=TorchTensorEncoder)
def checkpoint(self): def checkpoint(self):
raise NotImplementedError("Not implemented yet") raise NotImplementedError("Not implemented yet")
...@@ -7,7 +7,6 @@ import torch ...@@ -7,7 +7,6 @@ import torch
from .base_trainer import BaseTrainer from .base_trainer import BaseTrainer
_logger = logging.getLogger(__name__) _logger = logging.getLogger(__name__)
_logger.setLevel(logging.INFO)
class TorchTensorEncoder(json.JSONEncoder): class TorchTensorEncoder(json.JSONEncoder):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment