Unverified Commit e408e146 authored by Tab Zhang's avatar Tab Zhang Committed by GitHub
Browse files

Search space zoo example fix (#2801)

parent 593d2d20
...@@ -14,7 +14,7 @@ from nni.nas.pytorch.darts import DartsTrainer ...@@ -14,7 +14,7 @@ from nni.nas.pytorch.darts import DartsTrainer
from utils import accuracy from utils import accuracy
from nni.nas.pytorch.search_space_zoo import DartsCell from nni.nas.pytorch.search_space_zoo import DartsCell
from darts_search_space import DartsStackedCells from darts_stack_cells import DartsStackedCells
logger = logging.getLogger('nni') logger = logging.getLogger('nni')
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
# Licensed under the MIT license. # Licensed under the MIT license.
import torch.nn as nn import torch.nn as nn
import ops from nni.nas.pytorch.search_space_zoo.darts_ops import DropPath
class DartsStackedCells(nn.Module): class DartsStackedCells(nn.Module):
...@@ -79,5 +79,5 @@ class DartsStackedCells(nn.Module): ...@@ -79,5 +79,5 @@ class DartsStackedCells(nn.Module):
def drop_path_prob(self, p): def drop_path_prob(self, p):
for module in self.modules(): for module in self.modules():
if isinstance(module, ops.DropPath): if isinstance(module, DropPath):
module.p = p module.p = p
...@@ -58,7 +58,6 @@ if __name__ == "__main__": ...@@ -58,7 +58,6 @@ if __name__ == "__main__":
parser = ArgumentParser("enas") parser = ArgumentParser("enas")
parser.add_argument("--batch-size", default=128, type=int) parser.add_argument("--batch-size", default=128, type=int)
parser.add_argument("--log-frequency", default=10, type=int) parser.add_argument("--log-frequency", default=10, type=int)
# parser.add_argument("--search-for", choices=["macro", "micro"], default="macro")
parser.add_argument("--epochs", default=None, type=int, help="Number of epochs (default: macro 310, micro 150)") parser.add_argument("--epochs", default=None, type=int, help="Number of epochs (default: macro 310, micro 150)")
parser.add_argument("--visualization", default=False, action="store_true") parser.add_argument("--visualization", default=False, action="store_true")
args = parser.parse_args() args = parser.parse_args()
...@@ -71,7 +70,6 @@ if __name__ == "__main__": ...@@ -71,7 +70,6 @@ if __name__ == "__main__":
criterion = nn.CrossEntropyLoss() criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), 0.05, momentum=0.9, weight_decay=1.0E-4) optimizer = torch.optim.SGD(model.parameters(), 0.05, momentum=0.9, weight_decay=1.0E-4)
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs, eta_min=0.001) lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs, eta_min=0.001)
trainer = enas.EnasTrainer(model, trainer = enas.EnasTrainer(model,
loss=criterion, loss=criterion,
metrics=accuracy, metrics=accuracy,
......
...@@ -62,7 +62,7 @@ class MicroNetwork(nn.Module): ...@@ -62,7 +62,7 @@ class MicroNetwork(nn.Module):
reduction = False reduction = False
if layer_id in pool_layers: if layer_id in pool_layers:
c_cur, reduction = c_p * 2, True c_cur, reduction = c_p * 2, True
self.layers.append(ENASMicroLayer(self.layers, num_nodes, c_pp, c_p, c_cur, reduction)) self.layers.append(ENASMicroLayer(num_nodes, c_pp, c_p, c_cur, reduction))
if reduction: if reduction:
c_pp = c_p = c_cur c_pp = c_p = c_cur
c_pp, c_p = c_p, c_cur c_pp, c_p = c_p, c_cur
...@@ -98,7 +98,6 @@ if __name__ == "__main__": ...@@ -98,7 +98,6 @@ if __name__ == "__main__":
parser = ArgumentParser("enas") parser = ArgumentParser("enas")
parser.add_argument("--batch-size", default=128, type=int) parser.add_argument("--batch-size", default=128, type=int)
parser.add_argument("--log-frequency", default=10, type=int) parser.add_argument("--log-frequency", default=10, type=int)
# parser.add_argument("--search-for", choices=["macro", "micro"], default="macro")
parser.add_argument("--epochs", default=None, type=int, help="Number of epochs (default: macro 310, micro 150)") parser.add_argument("--epochs", default=None, type=int, help="Number of epochs (default: macro 310, micro 150)")
parser.add_argument("--visualization", default=False, action="store_true") parser.add_argument("--visualization", default=False, action="store_true")
args = parser.parse_args() args = parser.parse_args()
......
...@@ -79,7 +79,6 @@ class ENASMicroLayer(nn.Module): ...@@ -79,7 +79,6 @@ class ENASMicroLayer(nn.Module):
""" """
def __init__(self, num_nodes, in_channels_pp, in_channels_p, out_channels, reduction): def __init__(self, num_nodes, in_channels_pp, in_channels_p, out_channels, reduction):
super().__init__() super().__init__()
print(in_channels_pp, in_channels_p, out_channels, reduction)
self.reduction = reduction self.reduction = reduction
if self.reduction: if self.reduction:
self.reduce0 = FactorizedReduce(in_channels_pp, out_channels, affine=False) self.reduce0 = FactorizedReduce(in_channels_pp, out_channels, affine=False)
...@@ -160,7 +159,7 @@ class ENASMacroLayer(mutables.MutableScope): ...@@ -160,7 +159,7 @@ class ENASMacroLayer(mutables.MutableScope):
PoolBranch('avg', in_filters, out_filters, 3, 1, 1), PoolBranch('avg', in_filters, out_filters, 3, 1, 1),
PoolBranch('max', in_filters, out_filters, 3, 1, 1) PoolBranch('max', in_filters, out_filters, 3, 1, 1)
]) ])
if prev_labels > 0: if prev_labels:
self.skipconnect = mutables.InputChoice(choose_from=prev_labels, n_chosen=None) self.skipconnect = mutables.InputChoice(choose_from=prev_labels, n_chosen=None)
else: else:
self.skipconnect = None self.skipconnect = None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment