Commit 73b2221b authored by Yuge Zhang, committed by QuanluZhang

Update DARTS trainer and fix docstring issues (#1772)

parent 6d6f9524
......@@ -48,7 +48,7 @@ class Node(nn.Module):
ops.SepConv(channels, channels, 3, stride, 1, affine=False),
ops.SepConv(channels, channels, 5, stride, 2, affine=False),
ops.DilConv(channels, channels, 3, stride, 2, 2, affine=False),
ops.DilConv(channels, channels, 5, stride, 4, 2, affine=False),
ops.DilConv(channels, channels, 5, stride, 4, 2, affine=False)
],
key=choice_keys[-1]))
self.drop_path = ops.DropPath_()
......@@ -57,6 +57,7 @@ class Node(nn.Module):
def forward(self, prev_nodes):
assert len(self.ops) == len(prev_nodes)
out = [op(node) for op, node in zip(self.ops, prev_nodes)]
out = [self.drop_path(o) if o is not None else None for o in out]
return self.input_switch(out)
......
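For context on the `drop_path` call applied to every edge output above: drop path stochastically zeroes an entire path per sample during training and rescales the survivors. Below is a minimal standalone sketch of that behavior, not the NNI `DropPath_` implementation; the per-sample mask shape and the 1/(1-p) rescaling convention are assumptions.
import torch
import torch.nn as nn

class DropPathSketch(nn.Module):
    """Zero a whole path per sample with probability p; rescale survivors by 1/(1-p)."""
    def __init__(self, p=0.2):
        super().__init__()
        self.p = p

    def forward(self, x):
        if self.training and self.p > 0.:
            keep = 1. - self.p
            # one Bernoulli draw per sample, broadcast over C, H, W
            mask = torch.empty(x.size(0), 1, 1, 1, device=x.device).bernoulli_(keep)
            x = x * mask / keep
        return x

x = torch.randn(4, 16, 8, 8)
print(DropPathSketch(p=0.5).train()(x).abs().sum(dim=(1, 2, 3)))  # some rows are exactly zero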
......@@ -4,9 +4,13 @@ import torch.nn as nn
class DropPath_(nn.Module):
def __init__(self, p=0.):
""" [!] DropPath is inplace module
Args:
p: probability of an path to be zeroed.
"""
DropPath is an in-place module.
Parameters
----------
p : float
Probability of a path being zeroed.
"""
super().__init__()
self.p = p
......@@ -26,13 +30,9 @@ class DropPath_(nn.Module):
class PoolBN(nn.Module):
"""
AvgPool or MaxPool - BN
AvgPool or MaxPool with BN. `pool_type` must be `max` or `avg`.
"""
def __init__(self, pool_type, C, kernel_size, stride, padding, affine=True):
"""
Args:
pool_type: 'max' or 'avg'
"""
super().__init__()
if pool_type.lower() == 'max':
self.pool = nn.MaxPool2d(kernel_size, stride, padding)
......@@ -50,8 +50,8 @@ class PoolBN(nn.Module):
class StdConv(nn.Module):
""" Standard conv
ReLU - Conv - BN
"""
Standard conv: ReLU - Conv - BN
"""
def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True):
super().__init__()
......@@ -66,8 +66,8 @@ class StdConv(nn.Module):
class FacConv(nn.Module):
""" Factorized conv
ReLU - Conv(Kx1) - Conv(1xK) - BN
"""
Factorized conv: ReLU - Conv(Kx1) - Conv(1xK) - BN
"""
def __init__(self, C_in, C_out, kernel_length, stride, padding, affine=True):
super().__init__()
......@@ -83,10 +83,10 @@ class FacConv(nn.Module):
class DilConv(nn.Module):
""" (Dilated) depthwise separable conv
ReLU - (Dilated) depthwise separable - Pointwise - BN
If dilation == 2, 3x3 conv => 5x5 receptive field
5x5 conv => 9x9 receptive field
"""
(Dilated) depthwise separable conv.
ReLU - (Dilated) depthwise separable - Pointwise - BN.
If dilation == 2, 3x3 conv => 5x5 receptive field, 5x5 conv => 9x9 receptive field.
"""
def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine=True):
super().__init__()
......@@ -103,8 +103,9 @@ class DilConv(nn.Module):
class SepConv(nn.Module):
""" Depthwise separable conv
DilConv(dilation=1) * 2
"""
Depthwise separable conv.
DilConv(dilation=1) * 2.
"""
def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True):
super().__init__()
......@@ -119,7 +120,7 @@ class SepConv(nn.Module):
class FactorizedReduce(nn.Module):
"""
Reduce feature map size by factorized pointwise(stride=2).
Reduce feature map size by factorized pointwise (stride=2).
"""
def __init__(self, C_in, C_out, affine=True):
super().__init__()
......
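The DilConv docstring above uses the identity effective_kernel = dilation * (kernel_size - 1) + 1: with dilation 2, a 3x3 kernel covers a 5x5 field and a 5x5 kernel covers 9x9, which is also why the cell ops above use paddings 2 and 4. A small plain-PyTorch check of that arithmetic (not the NNI classes themselves):
import torch
import torch.nn as nn

def effective_kernel(kernel_size, dilation):
    # receptive field of a single dilated convolution
    return dilation * (kernel_size - 1) + 1

for k, pad in [(3, 2), (5, 4)]:
    print(k, "->", effective_kernel(k, dilation=2))                         # 3 -> 5, 5 -> 9
    conv = nn.Conv2d(8, 8, k, stride=1, padding=pad, dilation=2, groups=8)  # depthwise
    print(conv(torch.randn(1, 8, 32, 32)).shape)                            # spatial size preserved: 32x32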
......@@ -4,12 +4,13 @@ from argparse import ArgumentParser
import torch
import torch.nn as nn
from nni.nas.pytorch.fixed import apply_fixed_architecture
from nni.nas.pytorch.utils import AverageMeter
from torch.utils.tensorboard import SummaryWriter
import datasets
import utils
from model import CNN
from nni.nas.pytorch.fixed import apply_fixed_architecture
from nni.nas.pytorch.utils import AverageMeter
logger = logging.getLogger()
......@@ -23,6 +24,7 @@ logger.setLevel(logging.INFO)
logger.addHandler(std_out_info)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
writer = SummaryWriter()
def train(config, train_loader, model, optimizer, criterion, epoch):
......@@ -33,6 +35,7 @@ def train(config, train_loader, model, optimizer, criterion, epoch):
cur_step = epoch * len(train_loader)
cur_lr = optimizer.param_groups[0]['lr']
logger.info("Epoch %d LR %.6f", epoch, cur_lr)
writer.add_scalar("lr", cur_lr, global_step=cur_step)
model.train()
......@@ -54,6 +57,9 @@ def train(config, train_loader, model, optimizer, criterion, epoch):
losses.update(loss.item(), bs)
top1.update(accuracy["acc1"], bs)
top5.update(accuracy["acc5"], bs)
writer.add_scalar("loss/train", loss.item(), global_step=cur_step)
writer.add_scalar("acc1/train", accuracy["acc1"], global_step=cur_step)
writer.add_scalar("acc5/train", accuracy["acc5"], global_step=cur_step)
if step % config.log_frequency == 0 or step == len(train_loader) - 1:
logger.info(
......@@ -77,15 +83,15 @@ def validate(config, valid_loader, model, criterion, epoch, cur_step):
with torch.no_grad():
for step, (X, y) in enumerate(valid_loader):
X, y = X.to(device, non_blocking=True), y.to(device, non_blocking=True)
N = X.size(0)
bs = X.size(0)
logits = model(X)
loss = criterion(logits, y)
accuracy = utils.accuracy(logits, y, topk=(1, 5))
losses.update(loss.item(), N)
top1.update(accuracy["acc1"], N)
top5.update(accuracy["acc5"], N)
losses.update(loss.item(), bs)
top1.update(accuracy["acc1"], bs)
top5.update(accuracy["acc5"], bs)
if step % config.log_frequency == 0 or step == len(valid_loader) - 1:
logger.info(
......@@ -94,6 +100,10 @@ def validate(config, valid_loader, model, criterion, epoch, cur_step):
epoch + 1, config.epochs, step, len(valid_loader) - 1, losses=losses,
top1=top1, top5=top5))
writer.add_scalar("loss/test", losses.avg, global_step=cur_step)
writer.add_scalar("acc1/test", top1.avg, global_step=cur_step)
writer.add_scalar("acc5/test", top5.avg, global_step=cur_step)
logger.info("Valid: [{:3d}/{}] Final Prec@1 {:.4%}".format(epoch + 1, config.epochs, top1.avg))
return top1.avg
......
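The retraining script now mirrors its logger output to TensorBoard through `SummaryWriter`, keyed by a global step that grows monotonically across epochs (the script's `cur_step`). A minimal standalone sketch of the same pattern; the loop bounds and scalar values are placeholders:
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter()                 # event files go under ./runs by default
steps_per_epoch = 100                    # placeholder
for epoch in range(2):
    for step in range(steps_per_epoch):
        global_step = epoch * steps_per_epoch + step
        writer.add_scalar("loss/train", 1.0 / (global_step + 1), global_step=global_step)
writer.close()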
......@@ -7,8 +7,7 @@ import torch.nn as nn
import datasets
from model import CNN
from nni.nas.pytorch.callbacks import (ArchitectureCheckpoint,
LearningRateScheduler)
from nni.nas.pytorch.callbacks import ArchitectureCheckpoint, LRSchedulerCallback
from nni.nas.pytorch.darts import DartsTrainer
from utils import accuracy
......@@ -29,6 +28,7 @@ if __name__ == "__main__":
parser.add_argument("--batch-size", default=64, type=int)
parser.add_argument("--log-frequency", default=10, type=int)
parser.add_argument("--epochs", default=50, type=int)
parser.add_argument("--unrolled", default=False, action="store_true")
args = parser.parse_args()
dataset_train, dataset_valid = datasets.get_dataset("cifar10")
......@@ -48,5 +48,6 @@ if __name__ == "__main__":
dataset_valid=dataset_valid,
batch_size=args.batch_size,
log_frequency=args.log_frequency,
callbacks=[LearningRateScheduler(lr_scheduler), ArchitectureCheckpoint("./checkpoints")])
unrolled=args.unrolled,
callbacks=[LRSchedulerCallback(lr_scheduler), ArchitectureCheckpoint("./checkpoints")])
trainer.train()
......@@ -19,7 +19,7 @@ class ENASLayer(mutables.MutableScope):
PoolBranch('max', in_filters, out_filters, 3, 1, 1)
])
if len(prev_labels) > 0:
self.skipconnect = mutables.InputChoice(choose_from=prev_labels, n_chosen=None, reduction="sum")
self.skipconnect = mutables.InputChoice(choose_from=prev_labels, n_chosen=None)
else:
self.skipconnect = None
self.batch_norm = nn.BatchNorm2d(out_filters, affine=False)
......
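Dropping the explicit `reduction="sum"` here is safe because this commit changes the `InputChoice` default reduction from "mean" to "sum" (see mutables.py further down). Conceptually, the chosen skip connections are combined as sketched below; this is an illustration of the reduction semantics, not the mutator's actual code path:
import torch

tensor_list = [torch.randn(2, 8, 4, 4) for _ in range(3)]       # outputs of previous layers
mask = [True, False, True]                                       # decided by the mutator
chosen = [t for t, m in zip(tensor_list, mask) if m]
out = torch.sum(torch.stack(chosen), dim=0) if chosen else None  # "sum" reduction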
......@@ -9,7 +9,7 @@ import datasets
from macro import GeneralNetwork
from micro import MicroNetwork
from nni.nas.pytorch import enas
from nni.nas.pytorch.callbacks import LearningRateScheduler, ArchitectureCheckpoint
from nni.nas.pytorch.callbacks import LRSchedulerCallback, ArchitectureCheckpoint
from utils import accuracy, reward_accuracy
logger = logging.getLogger()
......@@ -51,7 +51,7 @@ if __name__ == "__main__":
metrics=accuracy,
reward_function=reward_accuracy,
optimizer=optimizer,
callbacks=[LearningRateScheduler(lr_scheduler), ArchitectureCheckpoint("./checkpoints")],
callbacks=[LRSchedulerCallback(lr_scheduler), ArchitectureCheckpoint("./checkpoints")],
batch_size=args.batch_size,
num_epochs=num_epochs,
dataset_train=dataset_train,
......
......@@ -51,21 +51,22 @@ class BaseMutator(nn.Module):
def mutables(self):
return self._structured_mutables
@property
def forward(self, *inputs):
raise RuntimeError("Forward is undefined for mutators.")
def __setattr__(self, name, value):
if name == "model":
raise AttributeError("Attribute `model` can be set at most once, and you shouldn't use `self.model = model` to "
"include you network, as it will include all parameters in model into the mutator.")
return super().__setattr__(name, value)
def enter_mutable_scope(self, mutable_scope):
"""
Callback when forward of a MutableScope is entered.
Parameters
----------
mutable_scope: MutableScope
Returns
-------
None
mutable_scope : MutableScope
"""
pass
......@@ -75,11 +76,7 @@ class BaseMutator(nn.Module):
Parameters
----------
mutable_scope: MutableScope
Returns
-------
None
mutable_scope : MutableScope
"""
pass
......@@ -89,8 +86,8 @@ class BaseMutator(nn.Module):
Parameters
----------
mutable: LayerChoice
inputs: list of torch.Tensor
mutable : LayerChoice
inputs : list of torch.Tensor
Returns
-------
......@@ -105,8 +102,8 @@ class BaseMutator(nn.Module):
Parameters
----------
mutable: InputChoice
tensor_list: list of torch.Tensor
mutable : InputChoice
tensor_list : list of torch.Tensor
Returns
-------
......
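The two hooks documented above are where a concrete mutator decides what a `LayerChoice` or `InputChoice` evaluates to on each forward pass. For a layer choice, the net effect can be pictured as reducing the candidate ops' outputs under a sampled mask; the snippet below only illustrates that idea and is not the `Mutator` implementation:
import torch
import torch.nn as nn

candidate_ops = nn.ModuleList([nn.Conv2d(8, 8, 3, padding=1),
                               nn.MaxPool2d(3, stride=1, padding=1)])
mask = torch.tensor([1., 0.])            # one-hot sample for this forward pass
x = torch.randn(1, 8, 16, 16)
out = sum(w * op(x) for w, op in zip(mask, candidate_ops) if w)
print(out.shape)                         # torch.Size([1, 8, 16, 16])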
......@@ -29,7 +29,7 @@ class Callback:
pass
class LearningRateScheduler(Callback):
class LRSchedulerCallback(Callback):
def __init__(self, scheduler, mode="epoch"):
super().__init__()
assert mode == "epoch"
......
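The renamed `LRSchedulerCallback` wraps a standard torch scheduler and, per the assertion above, steps it once per epoch. A hedged usage sketch; the stand-in model and the SGD/CosineAnnealingLR hyper-parameters are illustrative, not taken from the repo:
import torch
from nni.nas.pytorch.callbacks import LRSchedulerCallback, ArchitectureCheckpoint

model = torch.nn.Linear(10, 2)           # stand-in for a mutable network
optimizer = torch.optim.SGD(model.parameters(), lr=0.025, momentum=0.9)
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=50)
callbacks = [LRSchedulerCallback(lr_scheduler), ArchitectureCheckpoint("./checkpoints")]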
import torch
from torch import nn as nn
from torch.nn import functional as F
import torch.nn as nn
import torch.nn.functional as F
from nni.nas.pytorch.mutator import Mutator
from nni.nas.pytorch.mutables import LayerChoice, InputChoice
......
......@@ -2,27 +2,27 @@ import copy
import logging
import torch
from torch import nn as nn
import torch.nn as nn
from nni.nas.pytorch.trainer import Trainer
from nni.nas.pytorch.utils import AverageMeterGroup
from .mutator import DartsMutator
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
class DartsTrainer(Trainer):
def __init__(self, model, loss, metrics,
optimizer, num_epochs, dataset_train, dataset_valid,
mutator=None, batch_size=64, workers=4, device=None, log_frequency=None,
callbacks=None):
callbacks=None, arc_learning_rate=3.0E-4, unrolled=True):
super().__init__(model, mutator if mutator is not None else DartsMutator(model),
loss, metrics, optimizer, num_epochs, dataset_train, dataset_valid,
batch_size, workers, device, log_frequency, callbacks)
self.ctrl_optim = torch.optim.Adam(self.mutator.parameters(), 3.0E-4, betas=(0.5, 0.999),
self.ctrl_optim = torch.optim.Adam(self.mutator.parameters(), arc_learning_rate, betas=(0.5, 0.999),
weight_decay=1.0E-3)
self.unrolled = unrolled
n_train = len(self.dataset_train)
split = n_train // 2
indices = list(range(n_train))
......@@ -43,42 +43,32 @@ class DartsTrainer(Trainer):
def train_one_epoch(self, epoch):
self.model.train()
self.mutator.train()
lr = self.optimizer.param_groups[0]["lr"]
meters = AverageMeterGroup()
for step, ((trn_X, trn_y), (val_X, val_y)) in enumerate(zip(self.train_loader, self.valid_loader)):
trn_X, trn_y = trn_X.to(self.device), trn_y.to(self.device)
val_X, val_y = val_X.to(self.device), val_y.to(self.device)
# backup model for hessian
backup_model = copy.deepcopy(self.model.state_dict())
# cannot deepcopy model because it will break the reference
# phase 1. architecture step
self.ctrl_optim.zero_grad()
if self.unrolled:
self._unrolled_backward(trn_X, trn_y, val_X, val_y)
else:
self._backward(val_X, val_y)
self.ctrl_optim.step()
# phase 1. child network step
# phase 2: child network step
self.optimizer.zero_grad()
self.mutator.reset()
logits = self.model(trn_X)
loss = self.loss(logits, trn_y)
logits, loss = self._logits_and_loss(trn_X, trn_y)
loss.backward()
# gradient clipping
nn.utils.clip_grad_norm_(self.model.parameters(), 5.)
nn.utils.clip_grad_norm_(self.model.parameters(), 5.) # gradient clipping
self.optimizer.step()
new_model = copy.deepcopy(self.model.state_dict())
# phase 2. architect step (alpha)
self.ctrl_optim.zero_grad()
# compute unrolled loss
self._unrolled_backward(trn_X, trn_y, val_X, val_y, backup_model, lr)
self.ctrl_optim.step()
self.model.load_state_dict(new_model)
metrics = self.metrics(logits, trn_y)
metrics["loss"] = loss.item()
meters.update(metrics)
if self.log_frequency is not None and step % self.log_frequency == 0:
logger.info("Epoch [%s/%s] Step [%s/%s] %s", epoch+1,
self.num_epochs, step+1, len(self.train_loader), meters)
logger.info("Epoch [%s/%s] Step [%s/%s] %s", epoch + 1,
self.num_epochs, step + 1, len(self.train_loader), meters)
def validate_one_epoch(self, epoch):
self.model.eval()
......@@ -92,31 +82,69 @@ class DartsTrainer(Trainer):
metrics = self.metrics(logits, y)
meters.update(metrics)
if self.log_frequency is not None and step % self.log_frequency == 0:
logger.info("Epoch [%s/%s] Step [%s/%s] %s", epoch+1,
self.num_epochs, step+1, len(self.test_loader), meters)
logger.info("Epoch [%s/%s] Step [%s/%s] %s", epoch + 1,
self.num_epochs, step + 1, len(self.test_loader), meters)
def _logits_and_loss(self, X, y):
self.mutator.reset()
logits = self.model(X)
loss = self.loss(logits, y)
return logits, loss
def _backward(self, val_X, val_y):
"""
Simple backward pass on the validation loss (first-order approximation).
"""
_, loss = self._logits_and_loss(val_X, val_y)
loss.backward()
def _unrolled_backward(self, trn_X, trn_y, val_X, val_y, backup_model, lr):
def _unrolled_backward(self, trn_X, trn_y, val_X, val_y):
"""
Compute the unrolled loss and backpropagate its gradients.
Parameters
----------
v_model: backup model before this step
lr: learning rate for virtual gradient step (same as net lr)
"""
self.mutator.reset()
loss = self.loss(self.model(val_X), val_y)
w_model = tuple(self.model.parameters())
w_ctrl = tuple(self.mutator.parameters())
backup_params = copy.deepcopy(tuple(self.model.parameters()))
# do virtual step on training data
lr = self.optimizer.param_groups[0]["lr"]
momentum = self.optimizer.param_groups[0]["momentum"]
weight_decay = self.optimizer.param_groups[0]["weight_decay"]
self._compute_virtual_model(trn_X, trn_y, lr, momentum, weight_decay)
# calculate unrolled loss on validation data
# keep the model gradients here for computing the hessian later
_, loss = self._logits_and_loss(val_X, val_y)
w_model, w_ctrl = tuple(self.model.parameters()), tuple(self.mutator.parameters())
w_grads = torch.autograd.grad(loss, w_model + w_ctrl)
d_model = w_grads[:len(w_model)]
d_ctrl = w_grads[len(w_model):]
d_model, d_ctrl = w_grads[:len(w_model)], w_grads[len(w_model):]
hessian = self._compute_hessian(backup_model, d_model, trn_X, trn_y)
# compute hessian and final gradients
hessian = self._compute_hessian(backup_params, d_model, trn_X, trn_y)
with torch.no_grad():
for param, d, h in zip(w_ctrl, d_ctrl, hessian):
# gradient = dalpha - lr * hessian
param.grad = d - lr * h
def _compute_hessian(self, model, dw, trn_X, trn_y):
# restore weights
self._restore_weights(backup_params)
def _compute_virtual_model(self, X, y, lr, momentum, weight_decay):
"""
Compute unrolled weights w`
"""
# don't need zero_grad, using autograd to calculate gradients
_, loss = self._logits_and_loss(X, y)
gradients = torch.autograd.grad(loss, self.model.parameters())
with torch.no_grad():
for w, g in zip(self.model.parameters(), gradients):
m = self.optimizer.state[w].get("momentum_buffer", 0.)
w -= lr * (momentum * m + g + weight_decay * w)  # in-place, so the virtual SGD step actually updates the parameter
def _restore_weights(self, backup_params):
with torch.no_grad():
for param, backup in zip(self.model.parameters(), backup_params):
param.copy_(backup)
def _compute_hessian(self, backup_params, dw, trn_X, trn_y):
"""
dw = dw` { L_val(w`, alpha) }
w+ = w + eps * dw
......@@ -124,23 +152,22 @@ class DartsTrainer(Trainer):
hessian = (dalpha { L_trn(w+, alpha) } - dalpha { L_trn(w-, alpha) }) / (2*eps)
eps = 0.01 / ||dw||
"""
self.model.load_state_dict(model)
self._restore_weights(backup_params)
norm = torch.cat([w.view(-1) for w in dw]).norm()
eps = 0.01 / norm
if norm < 1E-8:
logger.warning("In computing hessian, norm is smaller than 1E-8, cause eps to be %.6f.", norm.item())
dalphas = []
for e in [eps, -2. * eps]:
# w+ = w + eps*dw`, w- = w - eps*dw`
with torch.no_grad():
for p, d in zip(self.model.parameters(), dw):
p += eps * d
p += e * d
self.mutator.reset()
loss = self.loss(self.model(trn_X), trn_y)
if e > 0:
dalpha_pos = torch.autograd.grad(loss, self.mutator.parameters()) # dalpha { L_trn(w+) }
elif e < 0:
dalpha_neg = torch.autograd.grad(loss, self.mutator.parameters()) # dalpha { L_trn(w-) }
_, loss = self._logits_and_loss(trn_X, trn_y)
dalphas.append(torch.autograd.grad(loss, self.mutator.parameters()))
dalpha_pos, dalpha_neg = dalphas  # dalpha { L_trn(w+) }, dalpha { L_trn(w-) }
hessian = [(p - n) / (2. * eps) for p, n in zip(dalpha_pos, dalpha_neg)]  # (dalpha+ - dalpha-) / (2*eps), as in the docstring
return hessian
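The docstring's central-difference rule, hessian ~ (dalpha L(w+) - dalpha L(w-)) / (2*eps), is the standard finite-difference approximation of a Hessian-vector product. A tiny self-contained check of that rule (not NNI code; a cubic loss is used because its Hessian is diagonal and easy to compare against):
import torch

w = torch.randn(3, requires_grad=True)
v = torch.randn(3)                                      # plays the role of dw
loss = lambda p: (p ** 3).sum()                         # gradient 3p^2, Hessian diag(6p)

eps = 0.01 / v.norm()
g_pos = torch.autograd.grad(loss(w + eps * v), w)[0]    # gradient evaluated at w + eps*v
g_neg = torch.autograd.grad(loss(w - eps * v), w)[0]    # gradient evaluated at w - eps*v
approx_hv = (g_pos - g_neg) / (2 * eps)
exact_hv = 6 * w.detach() * v                           # H @ v for H = diag(6w)
print(torch.allclose(approx_hv, exact_hv, atol=1e-4))   # True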
......@@ -25,6 +25,7 @@ class StackedLSTMCell(nn.Module):
class EnasMutator(Mutator):
def __init__(self, model, lstm_size=64, lstm_num_layers=1, tanh_constant=1.5, cell_exit_extra_step=False,
skip_target=0.4, branch_bias=0.25):
super().__init__(model)
......@@ -51,7 +52,7 @@ class EnasMutator(Mutator):
self.max_layer_choice = mutable.length
assert self.max_layer_choice == mutable.length, \
"ENAS mutator requires all layer choice have the same number of candidates."
# NOTE(yuge): We might implement an interface later. Judging by key now.
# Biases for layer choices are currently decided by mutable keys and module types; this needs refactoring.
if "reduce" in mutable.key:
def is_conv(choice):
return "conv" in str(type(choice)).lower()
......
......@@ -6,9 +6,7 @@ from nni.nas.pytorch.trainer import Trainer
from nni.nas.pytorch.utils import AverageMeterGroup
from .mutator import EnasMutator
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
class EnasTrainer(Trainer):
......@@ -75,8 +73,8 @@ class EnasTrainer(Trainer):
meters.update(metrics)
if self.log_frequency is not None and step % self.log_frequency == 0:
logger.info("Model Epoch [%s/%s] Step [%s/%s] %s", epoch,
self.num_epochs, step, len(self.train_loader), meters)
logger.info("Model Epoch [%s/%s] Step [%s/%s] %s", epoch + 1,
self.num_epochs, step + 1, len(self.train_loader), meters)
# Train sampler (mutator)
self.model.eval()
......@@ -114,8 +112,8 @@ class EnasTrainer(Trainer):
self.mutator_optim.zero_grad()
if self.log_frequency is not None and step % self.log_frequency == 0:
logger.info("RL Epoch [%s/%s] Step [%s/%s] %s", epoch, self.num_epochs,
mutator_step // self.mutator_steps_aggregate, self.mutator_steps, meters)
logger.info("RL Epoch [%s/%s] Step [%s/%s] %s", epoch + 1, self.num_epochs,
mutator_step // self.mutator_steps_aggregate + 1, self.mutator_steps, meters)
mutator_step += 1
if mutator_step >= total_mutator_steps:
break
......
......@@ -14,11 +14,11 @@ class FixedArchitecture(Mutator):
Parameters
----------
model: nn.Module
model : nn.Module
A mutable network.
fixed_arc: str or dict
fixed_arc : str or dict
Path to the architecture checkpoint (a string), or preloaded architecture object (a dict).
strict: bool
strict : bool
Force everything that appears in `fixed_arc` to be used at least once.
"""
super().__init__(model)
......@@ -55,11 +55,11 @@ def apply_fixed_architecture(model, fixed_arc_path, device=None):
Parameters
----------
model: torch.nn.Module
model : torch.nn.Module
Model with mutables.
fixed_arc_path: str
fixed_arc_path : str
Path to the JSON that stores the architecture.
device: torch.device
device : torch.device
Architecture weights will be transferred to `device`.
Returns
......@@ -76,3 +76,4 @@ def apply_fixed_architecture(model, fixed_arc_path, device=None):
architecture = FixedArchitecture(model, fixed_arc)
architecture.to(device)
architecture.reset()
return architecture
......@@ -39,6 +39,9 @@ class Mutable(nn.Module):
return super().__call__(*args, **kwargs)
def set_mutator(self, mutator):
if "mutator" in self.__dict__:
raise RuntimeError("`set_mutator` is called more than once. Did you parse the search space multiple times? "
"Or did you apply multiple fixed architectures?")
self.__dict__["mutator"] = mutator
def forward(self, *inputs):
......@@ -68,9 +71,10 @@ class Mutable(nn.Module):
class MutableScope(Mutable):
"""
Mutable scope labels a subgraph/submodule to help mutators make better decisions.
Mutable scope marks a subgraph/submodule to help mutators make better decisions.
Mutators get notified when a mutable scope is entered and exited. Mutators can override ``enter_mutable_scope``
and ``exit_mutable_scope`` to catch the corresponding events, and dump or update status accordingly.
A ``MutableScope`` is itself a mutable and is therefore listed among the mutables (the search space).
"""
def __init__(self, key):
......@@ -86,7 +90,7 @@ class MutableScope(Mutable):
class LayerChoice(Mutable):
def __init__(self, op_candidates, reduction="mean", return_mask=False, key=None):
def __init__(self, op_candidates, reduction="sum", return_mask=False, key=None):
super().__init__(key=key)
self.length = len(op_candidates)
self.choices = nn.ModuleList(op_candidates)
......@@ -117,25 +121,25 @@ class InputChoice(Mutable):
NO_KEY = ""
def __init__(self, n_candidates=None, choose_from=None, n_chosen=None,
reduction="mean", return_mask=False, key=None):
reduction="sum", return_mask=False, key=None):
"""
Initialization.
Parameters
----------
n_candidates: int
n_candidates : int
Number of inputs to choose from.
choose_from: list of str
choose_from : list of str
List of source keys to choose from. At least one of `choose_from` and `n_candidates` must be provided.
If `n_candidates` is given but `choose_from` is None, it is automatically treated as `n_candidates`
empty strings.
n_chosen: int
n_chosen : int
Recommended number of inputs to choose. If None, the mutator is instructed to select any number.
reduction: str
reduction : str
`mean`, `concat`, `sum` or `none`.
return_mask: bool
return_mask : bool
If `return_mask`, return output tensor and a mask. Otherwise return tensor only.
key: str
key : str
Key of the input choice.
"""
super().__init__(key=key)
......@@ -163,7 +167,7 @@ class InputChoice(Mutable):
Parameters
----------
optional_inputs: list or dict
optional_inputs : list or dict
Recommended to be a dict. If a dict, the inputs are converted to a list following the order of
`choose_from` given at initialization. If a list, the inputs must already follow the same semantic
order as `choose_from`.
......
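A declaration-only usage sketch of the mutables documented above; a mutator (e.g. `DartsMutator`) or `apply_fixed_architecture` must be attached before `forward` can run, and with this commit both `LayerChoice` and `InputChoice` reduce by "sum" by default:
import torch.nn as nn
from nni.nas.pytorch.mutables import LayerChoice, InputChoice

class SearchBlock(nn.Module):
    def __init__(self, channels=16):
        super().__init__()
        self.op = LayerChoice([
            nn.Conv2d(channels, channels, 3, padding=1),
            nn.MaxPool2d(3, stride=1, padding=1),
        ])
        self.skip = InputChoice(n_candidates=2, n_chosen=1)

    def forward(self, x, prev):
        out = self.op(x)
        return self.skip([out, prev])    # list order follows the declared candidates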
import logging
import torch
from nni.nas.pytorch.base_mutator import BaseMutator
logger = logging.getLogger(__name__)
class Mutator(BaseMutator):
......@@ -60,8 +64,8 @@ class Mutator(BaseMutator):
Parameters
----------
mutable: LayerChoice
inputs: list of torch.Tensor
mutable : LayerChoice
inputs : list of torch.Tensor
Returns
-------
......@@ -85,9 +89,9 @@ class Mutator(BaseMutator):
Parameters
----------
mutable: InputChoice
tensor_list: list of torch.Tensor
tags: list of string
mutable : InputChoice
tensor_list : list of torch.Tensor
tags : list of string
Returns
-------
......@@ -108,7 +112,7 @@ class Mutator(BaseMutator):
return out
def _tensor_reduction(self, reduction_type, tensor_list):
if tensor_list == "none":
if reduction_type == "none":
return tensor_list
if not tensor_list:
return None # empty. return None for now
......@@ -129,12 +133,14 @@ class Mutator(BaseMutator):
Parameters
----------
mutable: Mutable
mutable : Mutable
Returns
-------
any
object
"""
if mutable.key not in self._cache:
raise ValueError("\"{}\" not found in decision cache.".format(mutable.key))
return self._cache[mutable.key]
result = self._cache[mutable.key]
logger.debug("Decision %s: %s", mutable.key, result)
return result
......@@ -4,7 +4,7 @@
import copy
import numpy as np
from torch.nn import functional as F
import torch.nn.functional as F
from nni.nas.pytorch.darts import DartsMutator
from nni.nas.pytorch.mutables import LayerChoice
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import logging
from nni.nas.pytorch.callbacks import LearningRateScheduler
from nni.nas.pytorch.callbacks import LRSchedulerCallback
from nni.nas.pytorch.darts import DartsTrainer
from nni.nas.pytorch.trainer import BaseTrainer
......@@ -50,7 +50,7 @@ class PdartsTrainer(BaseTrainer):
darts_callbacks = []
if lr_scheduler is not None:
darts_callbacks.append(LearningRateScheduler(lr_scheduler))
darts_callbacks.append(LRSchedulerCallback(lr_scheduler))
self.trainer = DartsTrainer(model, mutator=self.mutator, loss=criterion, optimizer=optim,
callbacks=darts_callbacks, **self.darts_parameters)
......
......@@ -24,6 +24,40 @@ class TorchTensorEncoder(json.JSONEncoder):
class Trainer(BaseTrainer):
def __init__(self, model, mutator, loss, metrics, optimizer, num_epochs,
dataset_train, dataset_valid, batch_size, workers, device, log_frequency, callbacks):
"""
Trainer initialization.
Parameters
----------
model : nn.Module
Model with mutables.
mutator : BaseMutator
A mutator object that has been initialized with the model.
loss : callable
Called with logits and targets. Returns a loss tensor.
metrics : callable
Returns a dict that maps metrics keys to metrics data.
optimizer : Optimizer
Optimizer that optimizes the model.
num_epochs : int
Number of epochs of training.
dataset_train : torch.utils.data.Dataset
Dataset for training.
dataset_valid : torch.utils.data.Dataset
Dataset for validation/testing.
batch_size : int
Batch size.
workers : int
Number of workers used in data preprocessing.
device : torch.device
Device object. Either `torch.device("cuda")` or `torch.device("cpu")`. When `None`, the trainer
automatically detects an available GPU and prefers it over CPU.
log_frequency : int
Interval (in mini-batches) at which metrics are logged.
callbacks : list of Callback
Callbacks to plug into the trainer. See Callbacks.
"""
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if device is None else device
self.model = model
self.mutator = mutator
......
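The `metrics` argument is documented above as a callable returning a dict of metric values; the examples in this diff pass an `accuracy` helper producing keys such as "acc1"/"acc5". A minimal callable matching that contract (an illustration, not the repo's `utils.accuracy`):
import torch

def top1_metric(logits, target):
    # return a dict mapping metric keys to values, as the Trainer docstring requires
    acc1 = (logits.argmax(dim=1) == target).float().mean().item()
    return {"acc1": acc1}

print(top1_metric(torch.randn(8, 10), torch.randint(0, 10, (8,))))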
......@@ -28,6 +28,16 @@ class AverageMeter:
"""Computes and stores the average and current value"""
def __init__(self, name, fmt=':f'):
"""
Initialization of AverageMeter
Parameters
----------
name : str
Name to display.
fmt : str
Format string to print the values.
"""
self.name = name
self.fmt = fmt
self.reset()
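A short usage sketch of the meter documented above, matching how the retrain script calls it (`update(value, batch_size)` and the `.avg` attribute); the numbers are illustrative:
from nni.nas.pytorch.utils import AverageMeter

losses = AverageMeter("loss", ":.4f")    # name to display, format string
losses.update(0.70, 32)                  # value, batch size
losses.update(0.50, 32)
print(losses.avg)                        # count-weighted running average -> 0.6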
......@@ -78,12 +88,12 @@ class StructuredMutableTreeNode:
Parameters
----------
order: str
order : str
pre or post. If pre, the current mutable is yielded before its children; otherwise after.
deduplicate: bool
deduplicate : bool
If true, mutables with the same key will not appear after the first appearance.
memo: dict
An auxiliary variable to make deduplicate happen.
memo : dict
An auxiliary dict that memorizes keys seen before, so that deduplication is possible.
Returns
-------
......
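To make the traversal semantics above concrete, here is a standalone illustration (not the NNI class) of pre/post order and key-based deduplication over a small tree of (key, children) pairs:
def traverse(node, order="pre", memo=None):
    memo = set() if memo is None else memo
    key, children = node
    if order == "pre" and key not in memo:
        memo.add(key)
        yield key
    for child in children:
        yield from traverse(child, order, memo)
    if order == "post" and key not in memo:
        memo.add(key)
        yield key

tree = ("root", [("a", [("shared", [])]), ("b", [("shared", [])])])
print(list(traverse(tree)))              # ['root', 'a', 'shared', 'b']  ('shared' appears once)
print(list(traverse(tree, "post")))      # ['shared', 'a', 'b', 'root']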