Commit 5c96b82c authored by Chi Song, committed by chicm-ms

[NAS] fix bug on pdarts (#1797)

parent e9cba778
@@ -27,11 +27,14 @@ if __name__ == "__main__":
     parser = ArgumentParser("pdarts")
     parser.add_argument('--add_layers', action='append',
                         default=[0, 6, 12], help='add layers')
+    parser.add_argument('--dropped_ops', action='append',
+                        default=[3, 2, 1], help='drop ops')
     parser.add_argument("--nodes", default=4, type=int)
-    parser.add_argument("--layers", default=5, type=int)
+    parser.add_argument("--init_layers", default=5, type=int)
     parser.add_argument("--batch-size", default=64, type=int)
     parser.add_argument("--log-frequency", default=1, type=int)
     parser.add_argument("--epochs", default=50, type=int)
+    parser.add_argument("--unrolled", default=False, action="store_true")
     args = parser.parse_args()

     logger.info("loading data")
@@ -48,15 +51,16 @@ if __name__ == "__main__":
     logger.info("initializing trainer")
     trainer = PdartsTrainer(model_creator,
-                            layers=args.layers,
+                            init_layers=args.init_layers,
                             metrics=lambda output, target: accuracy(output, target, topk=(1,)),
-                            pdarts_num_layers=[0, 6, 12],
-                            pdarts_num_to_drop=[3, 2, 2],
+                            pdarts_num_layers=args.add_layers,
+                            pdarts_num_to_drop=args.dropped_ops,
                             num_epochs=args.epochs,
                             dataset_train=dataset_train,
                             dataset_valid=dataset_valid,
                             batch_size=args.batch_size,
                             log_frequency=args.log_frequency,
+                            unrolled=args.unrolled,
                             callbacks=[ArchitectureCheckpoint("./checkpoints")])
     logger.info("training")
     trainer.train()
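As a side note on the new flags (a minimal sketch, not part of the commit): with action='append' and a list default, argparse appends command-line values after the default entries rather than replacing them, and the appended values arrive as strings unless a type is given.

    from argparse import ArgumentParser

    # Minimal sketch of the append-with-default behaviour used by the new flags.
    parser = ArgumentParser("pdarts")
    parser.add_argument('--add_layers', action='append', default=[0, 6, 12], help='add layers')
    parser.add_argument('--dropped_ops', action='append', default=[3, 2, 1], help='drop ops')

    print(parser.parse_args([]).add_layers)                     # [0, 6, 12]
    print(parser.parse_args(['--add_layers', '3']).add_layers)  # [0, 6, 12, '3'] (appended, as a string)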
@@ -18,10 +18,11 @@ class DartsTrainer(Trainer):
     def __init__(self, model, loss, metrics,
                  optimizer, num_epochs, dataset_train, dataset_valid,
                  mutator=None, batch_size=64, workers=4, device=None, log_frequency=None,
-                 callbacks=None, arc_learning_rate=3.0E-4, unrolled=True):
+                 callbacks=None, arc_learning_rate=3.0E-4, unrolled=False):
         super().__init__(model, mutator if mutator is not None else DartsMutator(model),
                          loss, metrics, optimizer, num_epochs, dataset_train, dataset_valid,
                          batch_size, workers, device, log_frequency, callbacks)
         self.ctrl_optim = torch.optim.Adam(self.mutator.parameters(), arc_learning_rate, betas=(0.5, 0.999),
                                            weight_decay=1.0E-3)
         self.unrolled = unrolled
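For context, unrolled selects between the first-order architecture update and the second-order (unrolled) update in DARTS; this change flips the default to the cheaper first-order variant. A minimal sketch of opting back in (only the keyword names from the signature above are real; the surrounding objects are assumed):

    # Assumed setup: model, criterion, metrics_fn, optim, dataset_train and dataset_valid
    # are defined elsewhere; only the keyword names mirror the signature in this hunk.
    trainer = DartsTrainer(model, loss=criterion, metrics=metrics_fn,
                           optimizer=optim, num_epochs=50,
                           dataset_train=dataset_train, dataset_valid=dataset_valid,
                           unrolled=True)  # opt back in to the second-order (unrolled) update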
@@ -111,7 +111,7 @@ class Mutator(BaseMutator):
         if "BoolTensor" in mask.type():
             out = [map_fn(*cand) for cand, m in zip(candidates, mask) if m]
         elif "FloatTensor" in mask.type():
-            out = [map_fn(*cand) * m for cand, m in zip(candidates, mask)]
+            out = [map_fn(*cand) * m for cand, m in zip(candidates, mask) if m]
         else:
             raise ValueError("Unrecognized mask")
         return out
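The effect of the one-line change above: when the mask is a FloatTensor, candidates whose weight is exactly zero are now skipped instead of being evaluated and multiplied by zero. A minimal sketch with made-up candidates (the names and values are illustrative, not the real NNI objects):

    import torch

    # Illustrative stand-ins; map_fn and candidates are not the real NNI internals.
    candidates = [(torch.tensor([1.0, 2.0]),),
                  (torch.tensor([3.0, 4.0]),),
                  (torch.tensor([5.0, 6.0]),)]
    map_fn = lambda x: x * 10  # stand-in for running a candidate operation

    mask = torch.tensor([0.7, 0.0, 0.3])
    out = [map_fn(*cand) * m for cand, m in zip(candidates, mask) if m]
    # Only the first and third candidates are computed; the zero-weighted one is skipped.
    print(len(out))  # 2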
@@ -4,13 +4,18 @@
 import copy

 import numpy as np
-import torch.nn.functional as F
+import torch
+from torch import nn

 from nni.nas.pytorch.darts import DartsMutator
 from nni.nas.pytorch.mutables import LayerChoice


 class PdartsMutator(DartsMutator):
+    """
+    It works with PdartsTrainer to calculate the op weights,
+    and to drop weights in different PDARTS epochs.
+    """

     def __init__(self, model, pdarts_epoch_index, pdarts_num_to_drop, switches={}):
         self.pdarts_epoch_index = pdarts_epoch_index
@@ -22,60 +27,66 @@ class PdartsMutator(DartsMutator):
         super(PdartsMutator, self).__init__(model)

+        # this loop goes through mutables with different keys;
+        # it is mainly to update the length of choices.
         for mutable in self.mutables:
             if isinstance(mutable, LayerChoice):
                 switches = self.switches.get(mutable.key, [True for j in range(mutable.length)])
-                choices = self.choices[mutable.key]
-                for index in range(len(switches)-1, -1, -1):
-                    if switches[index] == False:
-                        del(mutable.choices[index])
-                        mutable.length -= 1
+                operations_count = np.sum(switches)
+                # +1 and -1 are caused by the zero operation in the DARTS network:
+                # the zero operation is not in the network's choices list, but its weight is,
+                # so one more weight and switch are needed for zero.
+                self.choices[mutable.key] = nn.Parameter(1.0E-3 * torch.randn(operations_count + 1))
+
                 self.switches[mutable.key] = switches

-    def drop_paths(self):
-        for key in self.switches:
-            prob = F.softmax(self.choices[key], dim=-1).data.cpu().numpy()
-            switches = self.switches[key]
+        # update LayerChoice instances in the model;
+        # this physically removes the dropped choice operations.
+        for module in self.model.modules():
+            if isinstance(module, LayerChoice):
+                switches = self.switches.get(module.key)
+                choices = self.choices[module.key]
+                if len(module.choices) > len(choices):
+                    # go from last to first, so removing one entry does not affect earlier indexes.
+                    for index in range(len(switches)-1, -1, -1):
+                        if switches[index] == False:
+                            del(module.choices[index])
+                            module.length -= 1
+
+    def sample_final(self):
+        results = super().sample_final()
+        for mutable in self.mutables:
+            if isinstance(mutable, LayerChoice):
+                # As some operations are dropped physically,
+                # False needs to be filled back in to track the dropped operations.
+                trained_result = results[mutable.key]
+                trained_index = 0
+                switches = self.switches[mutable.key]
+                result = torch.Tensor(switches).bool()
+                for index in range(len(result)):
+                    if result[index]:
+                        result[index] = trained_result[trained_index]
+                        trained_index += 1
+                results[mutable.key] = result
+        return results
+
+    def drop_paths(self):
+        """
+        This method is called when a PDARTS epoch is finished.
+        It prepares the switches for the next epoch.
+        Candidate operations whose switch is False will be dropped in the next epoch.
+        """
+        all_switches = copy.deepcopy(self.switches)
+        for key in all_switches:
+            switches = all_switches[key]
             idxs = []
             for j in range(len(switches)):
                 if switches[j]:
                     idxs.append(j)
-            if self.pdarts_epoch_index == len(self.pdarts_num_to_drop) - 1:
-                # for the last stage, drop all Zero operations
-                drop = self.get_min_k_no_zero(prob, idxs, self.pdarts_num_to_drop[self.pdarts_epoch_index])
-            else:
-                drop = self.get_min_k(prob, self.pdarts_num_to_drop[self.pdarts_epoch_index])
+            sorted_weights = self.choices[key].data.cpu().numpy()[:-1]
+            drop = np.argsort(sorted_weights)[:self.pdarts_num_to_drop[self.pdarts_epoch_index]]
             for idx in drop:
                 switches[idxs[idx]] = False
-        return self.switches
-
-    def get_min_k(self, input_in, k):
-        index = []
-        for _ in range(k):
-            idx = np.argmin(input)
-            index.append(idx)
-        return index
-
-    def get_min_k_no_zero(self, w_in, idxs, k):
-        w = copy.deepcopy(w_in)
-        index = []
-        if 0 in idxs:
-            zf = True
-        else:
-            zf = False
-        if zf:
-            w = w[1:]
-            index.append(0)
-            k = k - 1
-        for _ in range(k):
-            idx = np.argmin(w)
-            w[idx] = 1
-            if zf:
-                idx = idx + 1
-            index.append(idx)
-        return index
+        return all_switches
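To illustrate the new drop_paths logic above, here is a minimal sketch with toy values (the switches, weights, and num_to_drop below are assumptions, not taken from the commit): it ranks the weights of the still-active candidates, ignores the trailing weight reserved for the zero operation, and switches off the lowest-ranked ones.

    import copy
    import numpy as np

    switches = [True, False, True, True, True]          # candidate 1 was dropped in an earlier PDARTS epoch
    weights = np.array([0.30, 0.05, 0.25, 0.40, 0.10])  # one weight per active candidate, plus the zero-op weight last
    num_to_drop = 2

    new_switches = copy.deepcopy(switches)
    idxs = [j for j, s in enumerate(switches) if s]      # positions of the still-active candidates: [0, 2, 3, 4]
    sorted_weights = weights[:-1]                        # ignore the trailing zero-operation weight
    drop = np.argsort(sorted_weights)[:num_to_drop]      # indexes of the lowest-weight active candidates
    for idx in drop:
        new_switches[idxs[idx]] = False
    print(new_switches)  # [True, False, False, False, True]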
@@ -14,14 +14,22 @@ logger = logging.getLogger(__name__)

 class PdartsTrainer(BaseTrainer):
+    """
+    This trainer implements the PDARTS algorithm.
+    PDARTS is based on the DARTS algorithm, and provides a network-growth approach to find deeper and better networks.
+    This class relies on the pdarts_num_layers and pdarts_num_to_drop parameters to control how the network grows.
+    pdarts_num_layers means how many layers are added compared with the first epoch.
+    pdarts_num_to_drop means how many candidate operations should be dropped in each epoch,
+    so that the grown network stays at a similar size.
+    """

-    def __init__(self, model_creator, layers, metrics,
+    def __init__(self, model_creator, init_layers, metrics,
                  num_epochs, dataset_train, dataset_valid,
-                 pdarts_num_layers=[0, 6, 12], pdarts_num_to_drop=[3, 2, 2],
-                 mutator=None, batch_size=64, workers=4, device=None, log_frequency=None, callbacks=None):
+                 pdarts_num_layers=[0, 6, 12], pdarts_num_to_drop=[3, 2, 1],
+                 mutator=None, batch_size=64, workers=4, device=None, log_frequency=None, callbacks=None, unrolled=False):
         super(PdartsTrainer, self).__init__()
         self.model_creator = model_creator
-        self.layers = layers
+        self.init_layers = init_layers
         self.pdarts_num_layers = pdarts_num_layers
         self.pdarts_num_to_drop = pdarts_num_to_drop
         self.pdarts_epoch = len(pdarts_num_to_drop)
@@ -33,16 +41,17 @@ class PdartsTrainer(BaseTrainer):
             "batch_size": batch_size,
             "workers": workers,
             "device": device,
-            "log_frequency": log_frequency
+            "log_frequency": log_frequency,
+            "unrolled": unrolled
         }
         self.callbacks = callbacks if callbacks is not None else []

     def train(self):
-        layers = self.layers
         switches = None
         for epoch in range(self.pdarts_epoch):
-            layers = self.layers+self.pdarts_num_layers[epoch]
+            layers = self.init_layers+self.pdarts_num_layers[epoch]
             model, criterion, optim, lr_scheduler = self.model_creator(layers)
             self.mutator = PdartsMutator(model, epoch, self.pdarts_num_to_drop, switches)
@@ -66,7 +75,7 @@ class PdartsTrainer(BaseTrainer):
                 callback.on_epoch_end(epoch)

     def validate(self):
-        self.model.validate()
+        self.trainer.validate()

     def export(self, file):
         mutator_export = self.mutator.export()
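To make the growth schedule concrete, here is a minimal sketch using the default parameters from this diff (the number of candidate operations per edge is an assumption for illustration): each PDARTS epoch rebuilds the network with init_layers + pdarts_num_layers[epoch] layers, and pdarts_num_to_drop[epoch] candidate operations are dropped at the end of that epoch.

    # Toy walk-through of the schedule; 8 starting candidates per edge is an assumption.
    init_layers = 5
    pdarts_num_layers = [0, 6, 12]
    pdarts_num_to_drop = [3, 2, 1]
    candidates_per_edge = 8

    remaining = candidates_per_edge
    for epoch in range(len(pdarts_num_to_drop)):
        layers = init_layers + pdarts_num_layers[epoch]
        print(f"PDARTS epoch {epoch}: {layers} layers, {remaining} candidate ops per edge")
        remaining -= pdarts_num_to_drop[epoch]
    # PDARTS epoch 0:  5 layers, 8 candidate ops per edge
    # PDARTS epoch 1: 11 layers, 5 candidate ops per edge
    # PDARTS epoch 2: 17 layers, 3 candidate ops per edge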