Commit 73b2221b authored by Yuge Zhang, committed by QuanluZhang

Update DARTS trainer and fix docstring issues (#1772)

parent 6d6f9524
 data
 checkpoints
+runs
@@ -48,7 +48,7 @@ class Node(nn.Module):
                 ops.SepConv(channels, channels, 3, stride, 1, affine=False),
                 ops.SepConv(channels, channels, 5, stride, 2, affine=False),
                 ops.DilConv(channels, channels, 3, stride, 2, 2, affine=False),
-                ops.DilConv(channels, channels, 5, stride, 4, 2, affine=False),
+                ops.DilConv(channels, channels, 5, stride, 4, 2, affine=False)
             ],
             key=choice_keys[-1]))
         self.drop_path = ops.DropPath_()
@@ -57,6 +57,7 @@ class Node(nn.Module):
     def forward(self, prev_nodes):
         assert len(self.ops) == len(prev_nodes)
         out = [op(node) for op, node in zip(self.ops, prev_nodes)]
+        out = [self.drop_path(o) if o is not None else None for o in out]
         return self.input_switch(out)
@@ -4,9 +4,13 @@ import torch.nn as nn
 class DropPath_(nn.Module):
     def __init__(self, p=0.):
-        """ [!] DropPath is inplace module
-        Args:
-            p: probability of an path to be zeroed.
+        """
+        DropPath is an inplace module.
+
+        Parameters
+        ----------
+        p : float
+            Probability of a path being zeroed.
         """
         super().__init__()
         self.p = p
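
A minimal sketch of how such an in-place drop-path module is commonly implemented, inferred from the docstring rather than copied from ops.py: during training, each sample's path is zeroed with probability `p`, and the survivors are rescaled so the expected value is unchanged. The Node.forward hunk above applies it to every candidate op output.

import torch
import torch.nn as nn

class DropPathSketch(nn.Module):
    """Hypothetical reference implementation of in-place drop path."""

    def __init__(self, p=0.):
        super().__init__()
        self.p = p

    def forward(self, x):
        if self.training and self.p > 0.:
            keep_prob = 1. - self.p
            # one Bernoulli draw per sample, broadcast over (C, H, W)
            mask = torch.zeros((x.size(0), 1, 1, 1), device=x.device).bernoulli_(keep_prob)
            x.div_(keep_prob).mul_(mask)  # in-place, hence the trailing underscore convention
        return x
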
@@ -26,13 +30,9 @@ class DropPath_(nn.Module):
 class PoolBN(nn.Module):
     """
-    AvgPool or MaxPool - BN
+    AvgPool or MaxPool with BN. `pool_type` must be `max` or `avg`.
     """
     def __init__(self, pool_type, C, kernel_size, stride, padding, affine=True):
-        """
-        Args:
-            pool_type: 'max' or 'avg'
-        """
         super().__init__()
         if pool_type.lower() == 'max':
             self.pool = nn.MaxPool2d(kernel_size, stride, padding)
@@ -50,8 +50,8 @@ class PoolBN(nn.Module):
 class StdConv(nn.Module):
-    """ Standard conv
-    ReLU - Conv - BN
+    """
+    Standard conv: ReLU - Conv - BN
     """
     def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True):
         super().__init__()
@@ -66,8 +66,8 @@ class StdConv(nn.Module):
 class FacConv(nn.Module):
-    """ Factorized conv
-    ReLU - Conv(Kx1) - Conv(1xK) - BN
+    """
+    Factorized conv: ReLU - Conv(Kx1) - Conv(1xK) - BN
     """
     def __init__(self, C_in, C_out, kernel_length, stride, padding, affine=True):
         super().__init__()
@@ -83,10 +83,10 @@ class FacConv(nn.Module):
 class DilConv(nn.Module):
-    """ (Dilated) depthwise separable conv
-    ReLU - (Dilated) depthwise separable - Pointwise - BN
-    If dilation == 2, 3x3 conv => 5x5 receptive field
-    5x5 conv => 9x9 receptive field
+    """
+    (Dilated) depthwise separable conv.
+    ReLU - (Dilated) depthwise separable - Pointwise - BN.
+    If dilation == 2, 3x3 conv => 5x5 receptive field, 5x5 conv => 9x9 receptive field.
     """
     def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine=True):
         super().__init__()
@@ -103,8 +103,9 @@ class DilConv(nn.Module):
 class SepConv(nn.Module):
-    """ Depthwise separable conv
-    DilConv(dilation=1) * 2
+    """
+    Depthwise separable conv.
+    DilConv(dilation=1) * 2.
     """
     def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True):
         super().__init__()
@@ -119,7 +120,7 @@ class SepConv(nn.Module):
 class FactorizedReduce(nn.Module):
     """
-    Reduce feature map size by factorized pointwise(stride=2).
+    Reduce feature map size by factorized pointwise (stride=2).
     """
     def __init__(self, C_in, C_out, affine=True):
         super().__init__()
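
The body of FactorizedReduce is elided here; for context, the standard DARTS formulation (assumed, not shown in this diff) halves the spatial size without discarding positions by running two stride-2 pointwise convolutions on one-pixel-offset views and concatenating the results:

import torch
import torch.nn as nn

class FactorizedReduceSketch(nn.Module):
    """Hypothetical sketch of the usual DARTS factorized reduce."""

    def __init__(self, C_in, C_out, affine=True):
        super().__init__()
        self.relu = nn.ReLU()
        self.conv1 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False)
        self.conv2 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False)
        self.bn = nn.BatchNorm2d(C_out, affine=affine)

    def forward(self, x):
        x = self.relu(x)
        # the second branch sees a one-pixel shift, so stride-2 sampling of the
        # two branches together covers all spatial positions
        out = torch.cat([self.conv1(x), self.conv2(x[:, :, 1:, 1:])], dim=1)
        return self.bn(out)
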
@@ -4,12 +4,13 @@ from argparse import ArgumentParser
 import torch
 import torch.nn as nn
-from nni.nas.pytorch.fixed import apply_fixed_architecture
-from nni.nas.pytorch.utils import AverageMeter
+from torch.utils.tensorboard import SummaryWriter
 
 import datasets
 import utils
 from model import CNN
+from nni.nas.pytorch.fixed import apply_fixed_architecture
+from nni.nas.pytorch.utils import AverageMeter
 
 logger = logging.getLogger()
@@ -23,6 +24,7 @@ logger.setLevel(logging.INFO)
 logger.addHandler(std_out_info)
 
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+writer = SummaryWriter()
 
 
 def train(config, train_loader, model, optimizer, criterion, epoch):
@@ -33,6 +35,7 @@ def train(config, train_loader, model, optimizer, criterion, epoch):
     cur_step = epoch * len(train_loader)
     cur_lr = optimizer.param_groups[0]['lr']
     logger.info("Epoch %d LR %.6f", epoch, cur_lr)
+    writer.add_scalar("lr", cur_lr, global_step=cur_step)
 
     model.train()
@@ -54,6 +57,9 @@ def train(config, train_loader, model, optimizer, criterion, epoch):
         losses.update(loss.item(), bs)
         top1.update(accuracy["acc1"], bs)
         top5.update(accuracy["acc5"], bs)
+        writer.add_scalar("loss/train", loss.item(), global_step=cur_step)
+        writer.add_scalar("acc1/train", accuracy["acc1"], global_step=cur_step)
+        writer.add_scalar("acc5/train", accuracy["acc5"], global_step=cur_step)
 
         if step % config.log_frequency == 0 or step == len(train_loader) - 1:
             logger.info(
@@ -77,15 +83,15 @@ def validate(config, valid_loader, model, criterion, epoch, cur_step):
     with torch.no_grad():
         for step, (X, y) in enumerate(valid_loader):
             X, y = X.to(device, non_blocking=True), y.to(device, non_blocking=True)
-            N = X.size(0)
+            bs = X.size(0)
 
             logits = model(X)
             loss = criterion(logits, y)
 
             accuracy = utils.accuracy(logits, y, topk=(1, 5))
-            losses.update(loss.item(), N)
-            top1.update(accuracy["acc1"], N)
-            top5.update(accuracy["acc5"], N)
+            losses.update(loss.item(), bs)
+            top1.update(accuracy["acc1"], bs)
+            top5.update(accuracy["acc5"], bs)
 
             if step % config.log_frequency == 0 or step == len(valid_loader) - 1:
                 logger.info(
@@ -94,6 +100,10 @@ def validate(config, valid_loader, model, criterion, epoch, cur_step):
                     epoch + 1, config.epochs, step, len(valid_loader) - 1, losses=losses,
                     top1=top1, top5=top5))
 
+    writer.add_scalar("loss/test", losses.avg, global_step=cur_step)
+    writer.add_scalar("acc1/test", top1.avg, global_step=cur_step)
+    writer.add_scalar("acc5/test", top5.avg, global_step=cur_step)
+
     logger.info("Valid: [{:3d}/{}] Final Prec@1 {:.4%}".format(epoch + 1, config.epochs, top1.avg))
 
     return top1.avg
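
The `writer.add_scalar` calls added above write event files under `./runs` by default, which is why `runs` is ignored in the first hunk of this commit; the curves can be inspected with `tensorboard --logdir runs`. A minimal self-contained example of the same logging pattern:

from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter()  # defaults to ./runs/<timestamp>
for step in range(100):
    writer.add_scalar("loss/train", 1.0 / (step + 1), global_step=step)
writer.close()
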
@@ -7,8 +7,7 @@ import torch.nn as nn
 
 import datasets
 from model import CNN
-from nni.nas.pytorch.callbacks import (ArchitectureCheckpoint,
-                                       LearningRateScheduler)
+from nni.nas.pytorch.callbacks import ArchitectureCheckpoint, LRSchedulerCallback
 from nni.nas.pytorch.darts import DartsTrainer
 from utils import accuracy
@@ -29,6 +28,7 @@ if __name__ == "__main__":
     parser.add_argument("--batch-size", default=64, type=int)
     parser.add_argument("--log-frequency", default=10, type=int)
     parser.add_argument("--epochs", default=50, type=int)
+    parser.add_argument("--unrolled", default=False, action="store_true")
     args = parser.parse_args()
 
     dataset_train, dataset_valid = datasets.get_dataset("cifar10")
@@ -48,5 +48,6 @@ if __name__ == "__main__":
                            dataset_valid=dataset_valid,
                            batch_size=args.batch_size,
                            log_frequency=args.log_frequency,
-                           callbacks=[LearningRateScheduler(lr_scheduler), ArchitectureCheckpoint("./checkpoints")])
+                           unrolled=args.unrolled,
+                           callbacks=[LRSchedulerCallback(lr_scheduler), ArchitectureCheckpoint("./checkpoints")])
     trainer.train()
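
The new `--unrolled` flag is forwarded to `DartsTrainer` and selects the second-order bi-level update (`_unrolled_backward` in the trainer diff below) instead of the cheaper first-order approximation (`_backward`). Note that the script defaults to first-order, while the trainer constructor defaults to `unrolled=True`, so run `python search.py --unrolled` to get the second-order behavior from this script.
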
@@ -19,7 +19,7 @@ class ENASLayer(mutables.MutableScope):
             PoolBranch('max', in_filters, out_filters, 3, 1, 1)
         ])
         if len(prev_labels) > 0:
-            self.skipconnect = mutables.InputChoice(choose_from=prev_labels, n_chosen=None, reduction="sum")
+            self.skipconnect = mutables.InputChoice(choose_from=prev_labels, n_chosen=None)
         else:
             self.skipconnect = None
         self.batch_norm = nn.BatchNorm2d(out_filters, affine=False)
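
Dropping the explicit `reduction="sum"` here is behavior-preserving: the same commit changes `InputChoice`'s default reduction from "mean" to "sum" (see the mutables diff further down), so the selected skip connections are still summed.
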
@@ -9,7 +9,7 @@ import datasets
 from macro import GeneralNetwork
 from micro import MicroNetwork
 from nni.nas.pytorch import enas
-from nni.nas.pytorch.callbacks import LearningRateScheduler, ArchitectureCheckpoint
+from nni.nas.pytorch.callbacks import LRSchedulerCallback, ArchitectureCheckpoint
 from utils import accuracy, reward_accuracy
 
 logger = logging.getLogger()
@@ -51,7 +51,7 @@ if __name__ == "__main__":
                                metrics=accuracy,
                                reward_function=reward_accuracy,
                                optimizer=optimizer,
-                               callbacks=[LearningRateScheduler(lr_scheduler), ArchitectureCheckpoint("./checkpoints")],
+                               callbacks=[LRSchedulerCallback(lr_scheduler), ArchitectureCheckpoint("./checkpoints")],
                                batch_size=args.batch_size,
                                num_epochs=num_epochs,
                                dataset_train=dataset_train,
@@ -51,21 +51,22 @@ class BaseMutator(nn.Module):
     def mutables(self):
         return self._structured_mutables
 
-    @property
     def forward(self, *inputs):
         raise RuntimeError("Forward is undefined for mutators.")
 
+    def __setattr__(self, name, value):
+        if name == "model":
+            raise AttributeError("Attribute `model` can be set at most once, and you shouldn't use `self.model = model` to "
+                                 "include your network, as it will include all parameters in model into the mutator.")
+        return super().__setattr__(name, value)
+
     def enter_mutable_scope(self, mutable_scope):
         """
         Callback when forward of a MutableScope is entered.
 
         Parameters
         ----------
-        mutable_scope: MutableScope
-
-        Returns
-        -------
-        None
+        mutable_scope : MutableScope
         """
         pass
 
@@ -75,11 +76,7 @@ class BaseMutator(nn.Module):
 
         Parameters
         ----------
-        mutable_scope: MutableScope
-
-        Returns
-        -------
-        None
+        mutable_scope : MutableScope
         """
         pass
 
@@ -89,8 +86,8 @@ class BaseMutator(nn.Module):
 
         Parameters
         ----------
-        mutable: LayerChoice
-        inputs: list of torch.Tensor
+        mutable : LayerChoice
+        inputs : list of torch.Tensor
 
         Returns
         -------
@@ -105,8 +102,8 @@ class BaseMutator(nn.Module):
 
         Parameters
         ----------
-        mutable: InputChoice
-        tensor_list: list of torch.Tensor
+        mutable : InputChoice
+        tensor_list : list of torch.Tensor
 
         Returns
         -------
@@ -29,7 +29,7 @@ class Callback:
         pass
 
 
-class LearningRateScheduler(Callback):
+class LRSchedulerCallback(Callback):
     def __init__(self, scheduler, mode="epoch"):
         super().__init__()
         assert mode == "epoch"
 import torch
-from torch import nn as nn
-from torch.nn import functional as F
+import torch.nn as nn
+import torch.nn.functional as F
 
 from nni.nas.pytorch.mutator import Mutator
 from nni.nas.pytorch.mutables import LayerChoice, InputChoice
@@ -2,27 +2,27 @@ import copy
 import logging
 
 import torch
-from torch import nn as nn
+import torch.nn as nn
 
 from nni.nas.pytorch.trainer import Trainer
 from nni.nas.pytorch.utils import AverageMeterGroup
 from .mutator import DartsMutator
 
 logger = logging.getLogger(__name__)
-logger.setLevel(logging.INFO)
 
 
 class DartsTrainer(Trainer):
     def __init__(self, model, loss, metrics,
                  optimizer, num_epochs, dataset_train, dataset_valid,
                  mutator=None, batch_size=64, workers=4, device=None, log_frequency=None,
-                 callbacks=None):
+                 callbacks=None, arc_learning_rate=3.0E-4, unrolled=True):
         super().__init__(model, mutator if mutator is not None else DartsMutator(model),
                          loss, metrics, optimizer, num_epochs, dataset_train, dataset_valid,
                          batch_size, workers, device, log_frequency, callbacks)
-        self.ctrl_optim = torch.optim.Adam(self.mutator.parameters(), 3.0E-4, betas=(0.5, 0.999),
+        self.ctrl_optim = torch.optim.Adam(self.mutator.parameters(), arc_learning_rate, betas=(0.5, 0.999),
                                            weight_decay=1.0E-3)
+        self.unrolled = unrolled
 
         n_train = len(self.dataset_train)
         split = n_train // 2
         indices = list(range(n_train))
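
The constructor halves `dataset_train`: the first half feeds the weight (child) steps and the second half feeds the architecture (alpha) steps. The loader construction is elided from this hunk; a plausible sketch using `SubsetRandomSampler` (names and the stand-in dataset are assumptions for illustration):

import torch
from torch.utils.data import DataLoader, SubsetRandomSampler, TensorDataset

# stand-in dataset just to make the sketch runnable
dataset_train = TensorDataset(torch.randn(128, 3, 32, 32), torch.randint(0, 10, (128,)))

n_train = len(dataset_train)
split = n_train // 2
indices = list(range(n_train))

# one loader per half: weights train on the first half, alpha on the second
train_loader = DataLoader(dataset_train, batch_size=64,
                          sampler=SubsetRandomSampler(indices[:split]), num_workers=4)
valid_loader = DataLoader(dataset_train, batch_size=64,
                          sampler=SubsetRandomSampler(indices[split:]), num_workers=4)
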
@@ -43,42 +43,32 @@ class DartsTrainer(Trainer):
     def train_one_epoch(self, epoch):
         self.model.train()
         self.mutator.train()
-        lr = self.optimizer.param_groups[0]["lr"]
         meters = AverageMeterGroup()
         for step, ((trn_X, trn_y), (val_X, val_y)) in enumerate(zip(self.train_loader, self.valid_loader)):
             trn_X, trn_y = trn_X.to(self.device), trn_y.to(self.device)
             val_X, val_y = val_X.to(self.device), val_y.to(self.device)
 
-            # backup model for hessian
-            backup_model = copy.deepcopy(self.model.state_dict())
-            # cannot deepcopy model because it will break the reference
-
-            # phase 1. child network step
-            self.optimizer.zero_grad()
-            self.mutator.reset()
-            logits = self.model(trn_X)
-            loss = self.loss(logits, trn_y)
-            loss.backward()
-            # gradient clipping
-            nn.utils.clip_grad_norm_(self.model.parameters(), 5.)
-            self.optimizer.step()
-            new_model = copy.deepcopy(self.model.state_dict())
-
-            # phase 2. architect step (alpha)
-            self.ctrl_optim.zero_grad()
-            # compute unrolled loss
-            self._unrolled_backward(trn_X, trn_y, val_X, val_y, backup_model, lr)
-            self.ctrl_optim.step()
-            self.model.load_state_dict(new_model)
+            # phase 1. architecture step
+            self.ctrl_optim.zero_grad()
+            if self.unrolled:
+                self._unrolled_backward(trn_X, trn_y, val_X, val_y)
+            else:
+                self._backward(val_X, val_y)
+            self.ctrl_optim.step()
+
+            # phase 2: child network step
+            self.optimizer.zero_grad()
+            logits, loss = self._logits_and_loss(trn_X, trn_y)
+            loss.backward()
+            nn.utils.clip_grad_norm_(self.model.parameters(), 5.)  # gradient clipping
+            self.optimizer.step()
 
             metrics = self.metrics(logits, trn_y)
             metrics["loss"] = loss.item()
             meters.update(metrics)
             if self.log_frequency is not None and step % self.log_frequency == 0:
-                logger.info("Epoch [%s/%s] Step [%s/%s] %s", epoch+1,
-                            self.num_epochs, step+1, len(self.train_loader), meters)
+                logger.info("Epoch [%s/%s] Step [%s/%s] %s", epoch + 1,
+                            self.num_epochs, step + 1, len(self.train_loader), meters)
 
     def validate_one_epoch(self, epoch):
         self.model.eval()
@@ -92,55 +82,92 @@ class DartsTrainer(Trainer):
                 metrics = self.metrics(logits, y)
                 meters.update(metrics)
                 if self.log_frequency is not None and step % self.log_frequency == 0:
-                    logger.info("Epoch [%s/%s] Step [%s/%s] %s", epoch+1,
-                                self.num_epochs, step+1, len(self.test_loader), meters)
+                    logger.info("Epoch [%s/%s] Step [%s/%s] %s", epoch + 1,
+                                self.num_epochs, step + 1, len(self.test_loader), meters)
 
+    def _logits_and_loss(self, X, y):
+        self.mutator.reset()
+        logits = self.model(X)
+        loss = self.loss(logits, y)
+        return logits, loss
-    def _unrolled_backward(self, trn_X, trn_y, val_X, val_y, backup_model, lr):
-        """
-        Compute unrolled loss and backward its gradients
-
-        Parameters
-        ----------
-        v_model: backup model before this step
-        lr: learning rate for virtual gradient step (same as net lr)
-        """
-        self.mutator.reset()
-        loss = self.loss(self.model(val_X), val_y)
-
-        w_model = tuple(self.model.parameters())
-        w_ctrl = tuple(self.mutator.parameters())
-        w_grads = torch.autograd.grad(loss, w_model + w_ctrl)
-        d_model = w_grads[:len(w_model)]
-        d_ctrl = w_grads[len(w_model):]
-
-        hessian = self._compute_hessian(backup_model, d_model, trn_X, trn_y)
-        with torch.no_grad():
-            for param, d, h in zip(w_ctrl, d_ctrl, hessian):
-                param.grad = d - lr * h
+    def _backward(self, val_X, val_y):
+        """
+        Simple backward with gradient descent
+        """
+        _, loss = self._logits_and_loss(val_X, val_y)
+        loss.backward()
+
+    def _unrolled_backward(self, trn_X, trn_y, val_X, val_y):
+        """
+        Compute unrolled loss and backward its gradients
+        """
+        backup_params = copy.deepcopy(tuple(self.model.parameters()))
+
+        # do virtual step on training data
+        lr = self.optimizer.param_groups[0]["lr"]
+        momentum = self.optimizer.param_groups[0]["momentum"]
+        weight_decay = self.optimizer.param_groups[0]["weight_decay"]
+        self._compute_virtual_model(trn_X, trn_y, lr, momentum, weight_decay)
+
+        # calculate unrolled loss on validation data
+        # keep gradients for model here for computing hessian
+        _, loss = self._logits_and_loss(val_X, val_y)
+        w_model, w_ctrl = tuple(self.model.parameters()), tuple(self.mutator.parameters())
+        w_grads = torch.autograd.grad(loss, w_model + w_ctrl)
+        d_model, d_ctrl = w_grads[:len(w_model)], w_grads[len(w_model):]
+
+        # compute hessian and final gradients
+        hessian = self._compute_hessian(backup_params, d_model, trn_X, trn_y)
+        with torch.no_grad():
+            for param, d, h in zip(w_ctrl, d_ctrl, hessian):
+                # gradient = dalpha - lr * hessian
+                param.grad = d - lr * h
+
+        # restore weights
+        self._restore_weights(backup_params)
-    def _compute_hessian(self, model, dw, trn_X, trn_y):
-        """
-        dw = dw` { L_val(w`, alpha) }
-        w+ = w + eps * dw
-        w- = w - eps * dw
-        hessian = (dalpha { L_trn(w+, alpha) } - dalpha { L_trn(w-, alpha) }) / (2*eps)
-        eps = 0.01 / ||dw||
-        """
-        self.model.load_state_dict(model)
+    def _compute_virtual_model(self, X, y, lr, momentum, weight_decay):
+        """
+        Compute unrolled weights w`
+        """
+        # don't need zero_grad, using autograd to calculate gradients
+        _, loss = self._logits_and_loss(X, y)
+        gradients = torch.autograd.grad(loss, self.model.parameters())
+        with torch.no_grad():
+            for w, g in zip(self.model.parameters(), gradients):
+                m = self.optimizer.state[w].get("momentum_buffer", 0.)
+                # in-place update; rebinding with `w = w - ...` would leave the parameters untouched
+                w -= lr * (momentum * m + g + weight_decay * w)
+
+    def _restore_weights(self, backup_params):
+        with torch.no_grad():
+            for param, backup in zip(self.model.parameters(), backup_params):
+                param.copy_(backup)
+    def _compute_hessian(self, backup_params, dw, trn_X, trn_y):
+        """
+        dw = dw` { L_val(w`, alpha) }
+        w+ = w + eps * dw
+        w- = w - eps * dw
+        hessian = (dalpha { L_trn(w+, alpha) } - dalpha { L_trn(w-, alpha) }) / (2*eps)
+        eps = 0.01 / ||dw||
+        """
+        self._restore_weights(backup_params)
         norm = torch.cat([w.view(-1) for w in dw]).norm()
         eps = 0.01 / norm
+        if norm < 1E-8:
+            logger.warning("In computing hessian, norm is smaller than 1E-8, causing eps to be %.6f.", norm.item())
+
+        dalphas = []
         for e in [eps, -2. * eps]:
             # w+ = w + eps*dw`, w- = w - eps*dw`
             with torch.no_grad():
                 for p, d in zip(self.model.parameters(), dw):
-                    p += eps * d
-            self.mutator.reset()
-            loss = self.loss(self.model(trn_X), trn_y)
-            if e > 0:
-                dalpha_pos = torch.autograd.grad(loss, self.mutator.parameters())  # dalpha { L_trn(w+) }
-            elif e < 0:
-                dalpha_neg = torch.autograd.grad(loss, self.mutator.parameters())  # dalpha { L_trn(w-) }
+                    p += e * d
+            _, loss = self._logits_and_loss(trn_X, trn_y)
+            dalphas.append(torch.autograd.grad(loss, self.mutator.parameters()))
+
+        dalpha_pos, dalpha_neg = dalphas  # dalpha { L_trn(w+) }, dalpha { L_trn(w-) }
         hessian = [(p - n) / (2. * eps) for p, n in zip(dalpha_pos, dalpha_neg)]
         return hessian
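
For reference, the unrolled step implements the DARTS second-order gradient: after the virtual step w` = w - lr * dw L_trn(w, alpha), the architecture gradient is dalpha L_val(w`, alpha) - lr * hessian, where the hessian-vector term is approximated by the central difference spelled out in the `_compute_hessian` docstring; note the division must be by (2*eps), as in the fixed list comprehension above. A tiny self-contained check of that identity on a toy function (not repository code):

import torch

# L(w, a) = (a * w.sum()) ** 2; check that
#   (dalpha L(w + eps*dw) - dalpha L(w - eps*dw)) / (2*eps)
# recovers the mixed second derivative applied to dw.
w = torch.tensor([1.0, 2.0])
a = torch.tensor(0.5, requires_grad=True)

def loss_fn(w, a):
    return (a * w.sum()) ** 2

dw = torch.tensor([0.3, -0.1])
eps = 0.01 / dw.norm()

g_pos = torch.autograd.grad(loss_fn(w + eps * dw, a), a)[0]
g_neg = torch.autograd.grad(loss_fn(w - eps * dw, a), a)[0]
print((g_pos - g_neg) / (2 * eps))  # analytic value: 4 * a * w.sum() * dw.sum() = 1.2
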
@@ -25,6 +25,7 @@ class StackedLSTMCell(nn.Module):
 class EnasMutator(Mutator):
+
     def __init__(self, model, lstm_size=64, lstm_num_layers=1, tanh_constant=1.5, cell_exit_extra_step=False,
                  skip_target=0.4, branch_bias=0.25):
         super().__init__(model)
@@ -51,7 +52,7 @@ class EnasMutator(Mutator):
             self.max_layer_choice = mutable.length
             assert self.max_layer_choice == mutable.length, \
                 "ENAS mutator requires all layer choice have the same number of candidates."
-            # NOTE(yuge): We might implement an interface later. Judging by key now.
+            # We are judging by keys and module types to add biases to layer choices. Needs refactor.
             if "reduce" in mutable.key:
                 def is_conv(choice):
                     return "conv" in str(type(choice)).lower()
@@ -6,9 +6,7 @@ from nni.nas.pytorch.trainer import Trainer
 from nni.nas.pytorch.utils import AverageMeterGroup
 from .mutator import EnasMutator
 
 logger = logging.getLogger(__name__)
-logger.setLevel(logging.INFO)
 
 
 class EnasTrainer(Trainer):
@@ -75,8 +73,8 @@ class EnasTrainer(Trainer):
             meters.update(metrics)
 
             if self.log_frequency is not None and step % self.log_frequency == 0:
-                logger.info("Model Epoch [%s/%s] Step [%s/%s] %s", epoch,
-                            self.num_epochs, step, len(self.train_loader), meters)
+                logger.info("Model Epoch [%s/%s] Step [%s/%s] %s", epoch + 1,
+                            self.num_epochs, step + 1, len(self.train_loader), meters)
 
         # Train sampler (mutator)
         self.model.eval()
@@ -114,8 +112,8 @@ class EnasTrainer(Trainer):
                 self.mutator_optim.zero_grad()
 
             if self.log_frequency is not None and step % self.log_frequency == 0:
-                logger.info("RL Epoch [%s/%s] Step [%s/%s] %s", epoch, self.num_epochs,
-                            mutator_step // self.mutator_steps_aggregate, self.mutator_steps, meters)
+                logger.info("RL Epoch [%s/%s] Step [%s/%s] %s", epoch + 1, self.num_epochs,
+                            mutator_step // self.mutator_steps_aggregate + 1, self.mutator_steps, meters)
             mutator_step += 1
             if mutator_step >= total_mutator_steps:
                 break
@@ -14,11 +14,11 @@ class FixedArchitecture(Mutator):
 
         Parameters
         ----------
-        model: nn.Module
+        model : nn.Module
             A mutable network.
-        fixed_arc: str or dict
+        fixed_arc : str or dict
             Path to the architecture checkpoint (a string), or preloaded architecture object (a dict).
-        strict: bool
+        strict : bool
             Force everything that appears in `fixed_arc` to be used at least once.
         """
         super().__init__(model)
@@ -55,11 +55,11 @@ def apply_fixed_architecture(model, fixed_arc_path, device=None):
 
     Parameters
     ----------
-    model: torch.nn.Module
+    model : torch.nn.Module
        Model with mutables.
-    fixed_arc_path: str
+    fixed_arc_path : str
        Path to the JSON that stores the architecture.
-    device: torch.device
+    device : torch.device
        Architecture weights will be transferred to `device`.
 
    Returns
@@ -76,3 +76,4 @@ def apply_fixed_architecture(model, fixed_arc_path, device=None):
     architecture = FixedArchitecture(model, fixed_arc)
     architecture.to(device)
     architecture.reset()
+    return architecture
@@ -39,6 +39,9 @@ class Mutable(nn.Module):
         return super().__call__(*args, **kwargs)
 
     def set_mutator(self, mutator):
+        if "mutator" in self.__dict__:
+            raise RuntimeError("`set_mutator` is called more than once. Did you parse the search space multiple times? "
+                               "Or did you apply multiple fixed architectures?")
         self.__dict__["mutator"] = mutator
 
     def forward(self, *inputs):
@@ -68,9 +71,10 @@ class Mutable(nn.Module):
 class MutableScope(Mutable):
     """
-    Mutable scope labels a subgraph/submodule to help mutators make better decisions.
+    Mutable scope marks a subgraph/submodule to help mutators make better decisions.
     Mutators get notified when a mutable scope is entered and exited. Mutators can override ``enter_mutable_scope``
     and ``exit_mutable_scope`` to catch corresponding events, and do status dump or update.
+    MutableScope is also a mutable, and is listed in the mutables (search space).
     """
 
     def __init__(self, key):
@@ -86,7 +90,7 @@ class MutableScope(Mutable):
 class LayerChoice(Mutable):
-    def __init__(self, op_candidates, reduction="mean", return_mask=False, key=None):
+    def __init__(self, op_candidates, reduction="sum", return_mask=False, key=None):
         super().__init__(key=key)
         self.length = len(op_candidates)
         self.choices = nn.ModuleList(op_candidates)
@@ -117,25 +121,25 @@ class InputChoice(Mutable):
     NO_KEY = ""
 
     def __init__(self, n_candidates=None, choose_from=None, n_chosen=None,
-                 reduction="mean", return_mask=False, key=None):
+                 reduction="sum", return_mask=False, key=None):
         """
         Initialization.
 
         Parameters
         ----------
-        n_candidates: int
+        n_candidates : int
             Number of inputs to choose from.
-        choose_from: list of str
+        choose_from : list of str
             List of source keys to choose from. At least one of `choose_from` and `n_candidates` must be fulfilled.
             If `n_candidates` has a value but `choose_from` is None, it will be automatically treated as `n_candidates`
             number of empty strings.
-        n_chosen: int
+        n_chosen : int
             Recommended inputs to choose. If None, mutator is instructed to select any.
-        reduction: str
+        reduction : str
             `mean`, `concat`, `sum` or `none`.
-        return_mask: bool
+        return_mask : bool
             If `return_mask`, return output tensor and a mask. Otherwise return tensor only.
-        key: str
+        key : str
             Key of the input choice.
         """
         super().__init__(key=key)
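
For orientation, a sketch of how an `InputChoice` typically sits inside a mutable model; the surrounding module here is invented for illustration, and the snippet only executes under a mutator or a fixed architecture, which supplies the actual choice:

import torch.nn as nn
from nni.nas.pytorch import mutables

class ExampleNode(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(16, 16, 3, padding=1)
        # choose up to 2 of the 3 candidate inputs; the chosen tensors are summed
        self.input_switch = mutables.InputChoice(n_candidates=3, n_chosen=2, reduction="sum")

    def forward(self, a, b, c):
        return self.conv(self.input_switch([a, b, c]))
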
@@ -163,7 +167,7 @@ class InputChoice(Mutable):
 
         Parameters
         ----------
-        optional_inputs: list or dict
+        optional_inputs : list or dict
             Recommended to be a dict. As a dict, inputs will be converted to a list that follows the order of
             `choose_from` in initialization. As a list, inputs must follow the semantic order that is the same as
             `choose_from`.
+import logging
+
 import torch
 
 from nni.nas.pytorch.base_mutator import BaseMutator
 
+logger = logging.getLogger(__name__)
+
 
 class Mutator(BaseMutator):
@@ -60,8 +64,8 @@ class Mutator(BaseMutator):
 
         Parameters
         ----------
-        mutable: LayerChoice
-        inputs: list of torch.Tensor
+        mutable : LayerChoice
+        inputs : list of torch.Tensor
 
         Returns
         -------
@@ -85,9 +89,9 @@ class Mutator(BaseMutator):
 
         Parameters
         ----------
-        mutable: InputChoice
-        tensor_list: list of torch.Tensor
-        tags: list of string
+        mutable : InputChoice
+        tensor_list : list of torch.Tensor
+        tags : list of string
 
         Returns
         -------
@@ -108,7 +112,7 @@ class Mutator(BaseMutator):
         return out
 
     def _tensor_reduction(self, reduction_type, tensor_list):
-        if tensor_list == "none":
+        if reduction_type == "none":
             return tensor_list
         if not tensor_list:
             return None  # empty. return None for now
@@ -129,12 +133,14 @@ class Mutator(BaseMutator):
 
         Parameters
         ----------
-        mutable: Mutable
+        mutable : Mutable
 
         Returns
         -------
+        any object
         """
         if mutable.key not in self._cache:
             raise ValueError("\"{}\" not found in decision cache.".format(mutable.key))
-        return self._cache[mutable.key]
+        result = self._cache[mutable.key]
+        logger.debug("Decision %s: %s", mutable.key, result)
+        return result
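
The `reduction_type == "none"` fix above makes the pass-through branch reachable; the old code compared the tensor list itself against the string, so "none" never triggered. Assuming the rest of `_tensor_reduction` follows the policies named in the `InputChoice` docstring, its semantics amount to:

import torch

def tensor_reduction(reduction_type, tensor_list):
    # sketch (assumed) of Mutator._tensor_reduction's behavior
    if reduction_type == "none":
        return tensor_list                    # pass the raw list through
    if not tensor_list:
        return None                           # nothing was chosen
    if reduction_type == "sum":
        return sum(tensor_list)
    if reduction_type == "mean":
        return sum(tensor_list) / len(tensor_list)
    if reduction_type == "concat":
        return torch.cat(tensor_list, dim=1)  # concatenate along channel dim
    raise ValueError("Unrecognized reduction policy: \"{}\"".format(reduction_type))
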
@@ -4,7 +4,7 @@
 import copy
 
 import numpy as np
-from torch.nn import functional as F
+import torch.nn.functional as F
 
 from nni.nas.pytorch.darts import DartsMutator
 from nni.nas.pytorch.mutables import LayerChoice
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.
 
 import logging
 
-from nni.nas.pytorch.callbacks import LearningRateScheduler
+from nni.nas.pytorch.callbacks import LRSchedulerCallback
 from nni.nas.pytorch.darts import DartsTrainer
 from nni.nas.pytorch.trainer import BaseTrainer
@@ -50,7 +50,7 @@ class PdartsTrainer(BaseTrainer):
             darts_callbacks = []
             if lr_scheduler is not None:
-                darts_callbacks.append(LearningRateScheduler(lr_scheduler))
+                darts_callbacks.append(LRSchedulerCallback(lr_scheduler))
 
             self.trainer = DartsTrainer(model, mutator=self.mutator, loss=criterion, optimizer=optim,
                                         callbacks=darts_callbacks, **self.darts_parameters)
@@ -24,6 +24,40 @@ class TorchTensorEncoder(json.JSONEncoder):
 class Trainer(BaseTrainer):
     def __init__(self, model, mutator, loss, metrics, optimizer, num_epochs,
                  dataset_train, dataset_valid, batch_size, workers, device, log_frequency, callbacks):
+        """
+        Trainer initialization.
+
+        Parameters
+        ----------
+        model : nn.Module
+            Model with mutables.
+        mutator : BaseMutator
+            A mutator object that has been initialized with the model.
+        loss : callable
+            Called with logits and targets. Returns a loss tensor.
+        metrics : callable
+            Returns a dict that maps metrics keys to metrics data.
+        optimizer : Optimizer
+            Optimizer that optimizes the model.
+        num_epochs : int
+            Number of epochs of training.
+        dataset_train : torch.utils.data.Dataset
+            Dataset of training.
+        dataset_valid : torch.utils.data.Dataset
+            Dataset of validation/testing.
+        batch_size : int
+            Batch size.
+        workers : int
+            Number of workers used in data preprocessing.
+        device : torch.device
+            Device object. Either `torch.device("cuda")` or `torch.device("cpu")`. When `None`, trainer will
+            automatically detect GPUs and use the GPU first.
+        log_frequency : int
+            Number of mini-batches between logging of metrics.
+        callbacks : list of Callback
+            Callbacks to plug into the trainer. See Callbacks.
+        """
         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if device is None else device
         self.model = model
         self.mutator = mutator
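
Since `metrics` is just a callable from `(logits, target)` to a dict, a minimal stand-in looks like this (hypothetical helper; the example scripts pass `utils.accuracy`, which returns `acc1`/`acc5` in the same shape):

def metrics_fn(logits, target):
    # one dict entry per metric; the trainer folds this into its meters
    pred = logits.argmax(dim=1)
    return {"acc1": (pred == target).float().mean().item()}
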
@@ -28,6 +28,16 @@ class AverageMeter:
     """Computes and stores the average and current value"""
 
     def __init__(self, name, fmt=':f'):
+        """
+        Initialization of AverageMeter.
+
+        Parameters
+        ----------
+        name : str
+            Name to display.
+        fmt : str
+            Format string to print the values.
+        """
         self.name = name
         self.fmt = fmt
         self.reset()
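
Typical usage, matching the `update(value, n)` / `.avg` calls in retrain.py above:

from nni.nas.pytorch.utils import AverageMeter

losses = AverageMeter("loss", fmt=":.4f")
losses.update(0.9, 64)  # mean loss 0.9 over a batch of 64 samples
losses.update(0.7, 64)
print(losses.avg)       # 0.8, weighted by sample counts
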
@@ -78,12 +88,12 @@ class StructuredMutableTreeNode:
 
         Parameters
         ----------
-        order: str
+        order : str
             pre or post. If pre, current mutable is yielded before children. Otherwise after.
-        deduplicate: bool
+        deduplicate : bool
             If true, mutables with the same key will not appear after the first appearance.
-        memo: dict
-            An auxiliary variable to make deduplicate happen.
+        memo : dict
+            An auxiliary dict that memorizes keys seen before, so that deduplication is possible.
 
         Returns
         -------