Unverified Commit e408e146 authored by Tab Zhang's avatar Tab Zhang Committed by GitHub
Browse files

Search space zoo example fix (#2801)

parent 593d2d20
......@@ -14,7 +14,7 @@ from nni.nas.pytorch.darts import DartsTrainer
from utils import accuracy
from nni.nas.pytorch.search_space_zoo import DartsCell
from darts_search_space import DartsStackedCells
from darts_stack_cells import DartsStackedCells
logger = logging.getLogger('nni')
......
......@@ -2,7 +2,7 @@
# Licensed under the MIT license.
import torch.nn as nn
import ops
from nni.nas.pytorch.search_space_zoo.darts_ops import DropPath
class DartsStackedCells(nn.Module):
......@@ -79,5 +79,5 @@ class DartsStackedCells(nn.Module):
def drop_path_prob(self, p):
for module in self.modules():
if isinstance(module, ops.DropPath):
if isinstance(module, DropPath):
module.p = p
......@@ -58,7 +58,6 @@ if __name__ == "__main__":
parser = ArgumentParser("enas")
parser.add_argument("--batch-size", default=128, type=int)
parser.add_argument("--log-frequency", default=10, type=int)
# parser.add_argument("--search-for", choices=["macro", "micro"], default="macro")
parser.add_argument("--epochs", default=None, type=int, help="Number of epochs (default: macro 310, micro 150)")
parser.add_argument("--visualization", default=False, action="store_true")
args = parser.parse_args()
......@@ -71,7 +70,6 @@ if __name__ == "__main__":
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), 0.05, momentum=0.9, weight_decay=1.0E-4)
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs, eta_min=0.001)
trainer = enas.EnasTrainer(model,
loss=criterion,
metrics=accuracy,
......
......@@ -62,7 +62,7 @@ class MicroNetwork(nn.Module):
reduction = False
if layer_id in pool_layers:
c_cur, reduction = c_p * 2, True
self.layers.append(ENASMicroLayer(self.layers, num_nodes, c_pp, c_p, c_cur, reduction))
self.layers.append(ENASMicroLayer(num_nodes, c_pp, c_p, c_cur, reduction))
if reduction:
c_pp = c_p = c_cur
c_pp, c_p = c_p, c_cur
......@@ -98,7 +98,6 @@ if __name__ == "__main__":
parser = ArgumentParser("enas")
parser.add_argument("--batch-size", default=128, type=int)
parser.add_argument("--log-frequency", default=10, type=int)
# parser.add_argument("--search-for", choices=["macro", "micro"], default="macro")
parser.add_argument("--epochs", default=None, type=int, help="Number of epochs (default: macro 310, micro 150)")
parser.add_argument("--visualization", default=False, action="store_true")
args = parser.parse_args()
......
......@@ -79,7 +79,6 @@ class ENASMicroLayer(nn.Module):
"""
def __init__(self, num_nodes, in_channels_pp, in_channels_p, out_channels, reduction):
super().__init__()
print(in_channels_pp, in_channels_p, out_channels, reduction)
self.reduction = reduction
if self.reduction:
self.reduce0 = FactorizedReduce(in_channels_pp, out_channels, affine=False)
......@@ -160,7 +159,7 @@ class ENASMacroLayer(mutables.MutableScope):
PoolBranch('avg', in_filters, out_filters, 3, 1, 1),
PoolBranch('max', in_filters, out_filters, 3, 1, 1)
])
if prev_labels > 0:
if prev_labels:
self.skipconnect = mutables.InputChoice(choose_from=prev_labels, n_chosen=None)
else:
self.skipconnect = None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment