"docs/en_US/vscode:/vscode.git/clone" did not exist on "0717988f2f45baeb8483536e877defd92c82cfae"
Commit 77e91e8b authored by Yuge Zhang's avatar Yuge Zhang Committed by Chi Song
Browse files

Extract controller from mutator to make offline decisions (#1758)

parent 9dda5370
...@@ -35,8 +35,7 @@ class PdartsTrainer(Trainer): ...@@ -35,8 +35,7 @@ class PdartsTrainer(Trainer):
layers = self.layers+self.pdarts_num_layers[epoch] layers = self.layers+self.pdarts_num_layers[epoch]
model, loss, model_optim, _ = self.model_creator( model, loss, model_optim, _ = self.model_creator(
layers, n_nodes) layers, n_nodes)
mutator = PdartsMutator( mutator = PdartsMutator(model, epoch, self.pdarts_num_to_drop, switches) # pylint: disable=too-many-function-args
model, epoch, self.pdarts_num_to_drop, switches)
self.trainer = DartsTrainer(model, loss=loss, optimizer=model_optim, self.trainer = DartsTrainer(model, loss=loss, optimizer=model_optim,
mutator=mutator, **self.darts_parameters) mutator=mutator, **self.darts_parameters)
......
import json
import logging
from abc import abstractmethod from abc import abstractmethod
import torch import torch
from .base_trainer import BaseTrainer from .base_trainer import BaseTrainer
_logger = logging.getLogger(__name__)
class TorchTensorEncoder(json.JSONEncoder):
    """
    JSON encoder that knows how to serialize ``torch.Tensor`` objects.

    A tensor is converted with ``Tensor.tolist()``, i.e. to a (nested) list, or to a plain
    Python number when the tensor is 0-dimensional. Any other unsupported object is
    delegated to ``json.JSONEncoder.default``, which raises ``TypeError``.
    """

    def default(self, o):  # pylint: disable=method-hidden
        """
        Return a JSON-serializable representation of ``o``.

        Parameters
        ----------
        o : object
            The object the standard encoder could not serialize.

        Returns
        -------
        list or number
            ``o.tolist()`` when ``o`` is a ``torch.Tensor``.

        Raises
        ------
        TypeError
            Raised by the base class when ``o`` is not a tensor.
        """
        if isinstance(o, torch.Tensor):
            olist = o.tolist()
            # ``tolist()`` yields a bare Python number for 0-dim tensors; wrap it so the
            # 0/1 check below never tries to iterate over a scalar.
            elements = olist if isinstance(olist, list) else [olist]
            # Skip empty tensors: "every element is 0 or 1" would only be vacuously true.
            if "bool" not in o.type().lower() and elements and all(d == 0 or d == 1 for d in elements):
                _logger.warning("Every element in %s is either 0 or 1. "
                                "You might consider convert it into bool.", olist)
            return olist
        return super().default(o)
class Trainer(BaseTrainer): class Trainer(BaseTrainer):
def __init__(self, model, loss, metrics, optimizer, num_epochs, def __init__(self, model, mutator, loss, metrics, optimizer, num_epochs,
dataset_train, dataset_valid, batch_size, workers, device, log_frequency, dataset_train, dataset_valid, batch_size, workers, device, log_frequency, callbacks):
mutator, callbacks):
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if device is None else device self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if device is None else device
self.model = model self.model = model
self.mutator = mutator
self.loss = loss self.loss = loss
self.metrics = metrics self.metrics = metrics
self.optimizer = optimizer self.optimizer = optimizer
self.mutator = mutator
self.model.to(self.device) self.model.to(self.device)
self.loss.to(self.device)
self.mutator.to(self.device) self.mutator.to(self.device)
self.loss.to(self.device)
self.num_epochs = num_epochs self.num_epochs = num_epochs
self.dataset_train = dataset_train self.dataset_train = dataset_train
...@@ -38,7 +53,7 @@ class Trainer(BaseTrainer): ...@@ -38,7 +53,7 @@ class Trainer(BaseTrainer):
def validate_one_epoch(self, epoch): def validate_one_epoch(self, epoch):
pass pass
def _train(self, validate): def train(self, validate=True):
for epoch in range(self.num_epochs): for epoch in range(self.num_epochs):
for callback in self.callbacks: for callback in self.callbacks:
callback.on_epoch_begin(epoch) callback.on_epoch_begin(epoch)
...@@ -55,11 +70,13 @@ class Trainer(BaseTrainer): ...@@ -55,11 +70,13 @@ class Trainer(BaseTrainer):
for callback in self.callbacks: for callback in self.callbacks:
callback.on_epoch_end(epoch) callback.on_epoch_end(epoch)
def train_and_validate(self):
self._train(True)
def train(self):
self._train(False)
def validate(self): def validate(self):
self.validate_one_epoch(-1) self.validate_one_epoch(-1)
def export(self, file):
    """
    Write the result of ``self.mutator.export()`` to a JSON file.

    Parameters
    ----------
    file : str
        Path handed to ``open(file, "w")``; an existing file is overwritten.
    """
    # Ask the mutator first, so the target file is only opened once there is something to write.
    exported = self.mutator.export()
    with open(file, "w") as out_file:
        json.dump(exported, out_file, indent=2, sort_keys=True, cls=TorchTensorEncoder)
def checkpoint(self):
    """
    Placeholder for checkpointing.

    Raises
    ------
    NotImplementedError
        Always; no checkpoint logic exists in this method.
    """
    message = "Not implemented yet"
    raise NotImplementedError(message)
from collections import OrderedDict
# Module-wide counter; advanced only by ``global_mutable_counting``.
_counter = 0


def global_mutable_counting():
    """
    Advance the module-level counter by one and return its new value.

    Returns
    -------
    int
        1 on the first call, 2 on the second, and so on.
    """
    global _counter
    _counter = _counter + 1
    return _counter
class AverageMeterGroup:
    """
    A collection of ``AverageMeter`` objects addressed by name.

    Meters are stored in ``self.meters`` (an ``OrderedDict``) and are created the first
    time a name is seen in ``update``.
    """

    def __init__(self):
        self.meters = OrderedDict()

    def update(self, data):
        """
        Feed one value per name into the corresponding meter.

        Parameters
        ----------
        data : dict
            Maps a meter name to the value passed to that meter's ``update``.
        """
        for name, value in data.items():
            try:
                meter = self.meters[name]
            except KeyError:
                # First time this name shows up: register a fresh meter for it.
                meter = self.meters[name] = AverageMeter(name, ":4f")
            meter.update(value)

    def __str__(self):
        # Meters appear in the order they were first registered.
        return " ".join(map(str, self.meters.values()))
class AverageMeter:
    """
    Computes and stores the average and current value.

    Attributes written by this class are ``val`` (latest value), ``sum`` (weighted sum),
    ``count`` (total weight) and ``avg`` (``sum / count``).
    """

    def __init__(self, name, fmt=':f'):
        """
        Parameters
        ----------
        name : str
            Label printed in front of the values.
        fmt : str
            Format spec (including the leading colon) applied to ``val`` and ``avg``
            when the meter is converted to a string.
        """
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        """Set ``val``, ``avg``, ``sum`` and ``count`` back to zero."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """
        Record ``val`` with weight ``n`` and refresh the running average.

        Parameters
        ----------
        val : number
            The latest value.
        n : int
            Weight of ``val``; it is added to ``count`` and multiplies ``val`` in ``sum``.
        """
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        # Assemble a template such as "{name} {val:f} ({avg:f})" and fill it from the
        # instance attributes.
        template = "".join(['{name} {val', self.fmt, '} ({avg', self.fmt, '})'])
        return template.format(**vars(self))
class StructuredMutableTreeNode:
    """
    A structured representation of a search space.
    A search space comes with a root (with `None` stored in its `mutable`), and a bunch of children in its `children`.
    This tree can be seen as a "flattened" version of the module tree. Since nested mutable entity is not supported yet,
    the following must be true: each subtree corresponds to a ``MutableScope`` and each leaf corresponds to a
    ``Mutable`` (other than ``MutableScope``).
    """

    def __init__(self, mutable):
        self.mutable = mutable
        self.children = []

    def add_child(self, mutable):
        """Wrap ``mutable`` in a new node, append it to ``children`` and return that node."""
        node = StructuredMutableTreeNode(mutable)
        self.children.append(node)
        return node

    def type(self):
        """Return the class of the mutable stored in this node."""
        return type(self.mutable)

    def __iter__(self):
        # Iterating a node is a traversal with the default arguments.
        return self.traverse()

    def _own_mutable(self, deduplicate, memo):
        """Yield this node's mutable unless it is ``None`` or suppressed by deduplication."""
        current = self.mutable
        if current is None:
            return
        if deduplicate and current.key in memo:
            return
        # The key is recorded even when deduplication is off.
        memo.add(current.key)
        yield current

    def traverse(self, order="pre", deduplicate=True, memo=None):
        """
        Return a generator that walks the mutables stored in this tree.

        Parameters
        ----------
        order: str
            pre or post. If pre, the current mutable is yielded before its children. Otherwise after.
        deduplicate: bool
            If true, a mutable whose key has already been seen is not yielded again.
            Its children are still visited.
        memo: set
            Keys seen so far; shared with the recursive calls. A new set is created when omitted.

        Returns
        -------
        generator of Mutable
        """
        if memo is None:
            memo = set()
        assert order in ["pre", "post"]
        parent_first = order == "pre"
        if parent_first:
            yield from self._own_mutable(deduplicate, memo)
        for child in self.children:
            yield from child.traverse(order=order, deduplicate=deduplicate, memo=memo)
        if not parent_first:
            yield from self._own_mutable(deduplicate, memo)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment