Unverified commit 76c819c0 authored by liuzhe-lz, committed by GitHub

Merge dev-nas-tf to master (#2459)

parent 69cc137a
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import tensorflow as tf


def get_dataset():
    (x_train, y_train), (x_valid, y_valid) = tf.keras.datasets.cifar10.load_data()
    x_train, x_valid = x_train / 255.0, x_valid / 255.0
    train_set = (x_train, y_train)
    valid_set = (x_valid, y_valid)
    return train_set, valid_set
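For orientation, a quick sketch of what get_dataset returns (shapes follow the standard Keras CIFAR-10 layout; printing them is just a sanity check):

train_set, valid_set = get_dataset()
x_train, y_train = train_set
print(x_train.shape, y_train.shape)  # (50000, 32, 32, 3) (50000, 1) for standard CIFAR-10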
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import tensorflow as tf
from tensorflow.keras import Model, Sequential
from tensorflow.keras.layers import (
    AveragePooling2D,
    BatchNormalization,
    Conv2D,
    Dense,
    Dropout,
    GlobalAveragePooling2D,
    MaxPool2D,
    ReLU,
    SeparableConv2D,
)

from nni.nas.tensorflow.mutables import InputChoice, LayerChoice, MutableScope


def build_conv(filters, kernel_size, name=None):
    return Sequential([
        Conv2D(filters, kernel_size=1, use_bias=False),
        BatchNormalization(trainable=False),
        ReLU(),
        Conv2D(filters, kernel_size, padding='same'),
        BatchNormalization(trainable=False),
        ReLU(),
    ], name)


def build_separable_conv(filters, kernel_size, name=None):
    return Sequential([
        Conv2D(filters, kernel_size=1, use_bias=False),
        BatchNormalization(trainable=False),
        ReLU(),
        SeparableConv2D(filters, kernel_size, padding='same', use_bias=False),
        Conv2D(filters, kernel_size=1, use_bias=False),
        BatchNormalization(trainable=False),
        ReLU(),
    ], name)


def build_avg_pool(filters, name=None):
    return Sequential([
        Conv2D(filters, kernel_size=1, use_bias=False),
        BatchNormalization(trainable=False),
        ReLU(),
        AveragePooling2D(pool_size=3, strides=1, padding='same'),
        BatchNormalization(trainable=False),
    ], name)


def build_max_pool(filters, name=None):
    return Sequential([
        Conv2D(filters, kernel_size=1, use_bias=False),
        BatchNormalization(trainable=False),
        ReLU(),
        MaxPool2D(pool_size=3, strides=1, padding='same'),
        BatchNormalization(trainable=False),
    ], name)


class FactorizedReduce(Model):
    def __init__(self, filters):
        super().__init__()
        self.conv1 = Conv2D(filters // 2, kernel_size=1, strides=2, use_bias=False)
        self.conv2 = Conv2D(filters // 2, kernel_size=1, strides=2, use_bias=False)
        self.bn = BatchNormalization(trainable=False)

    def call(self, x):
        out1 = self.conv1(x)
        out2 = self.conv2(x[:, 1:, 1:, :])
        out = tf.concat([out1, out2], axis=3)
        out = self.bn(out)
        return out


class ENASLayer(MutableScope):
    def __init__(self, key, prev_labels, filters):
        super().__init__(key)
        self.mutable = LayerChoice([
            build_conv(filters, 3, 'conv3'),
            build_separable_conv(filters, 3, 'sepconv3'),
            build_conv(filters, 5, 'conv5'),
            build_separable_conv(filters, 5, 'sepconv5'),
            build_avg_pool(filters, 'avgpool'),
            build_max_pool(filters, 'maxpool'),
        ])
        if len(prev_labels) > 0:
            self.skipconnect = InputChoice(choose_from=prev_labels, n_chosen=None)
        else:
            self.skipconnect = None
        self.batch_norm = BatchNormalization(trainable=False)

    def call(self, prev_layers):
        out = self.mutable(prev_layers[-1])
        if self.skipconnect is not None:
            connection = self.skipconnect(prev_layers[:-1])
            if connection is not None:
                out += connection
        return self.batch_norm(out)


class GeneralNetwork(Model):
    def __init__(self, num_layers=12, filters=24, num_classes=10, dropout_rate=0.0):
        super().__init__()
        self.num_layers = num_layers

        self.stem = Sequential([
            Conv2D(filters, kernel_size=3, padding='same', use_bias=False),
            BatchNormalization()
        ])

        labels = ['layer_{}'.format(i) for i in range(num_layers)]
        self.enas_layers = []
        for i in range(num_layers):
            layer = ENASLayer(labels[i], labels[:i], filters)
            self.enas_layers.append(layer)

        pool_num = 2
        self.pool_distance = num_layers // (pool_num + 1)
        self.pool_layers = [FactorizedReduce(filters) for _ in range(pool_num)]

        self.gap = GlobalAveragePooling2D()
        self.dropout = Dropout(dropout_rate)
        self.dense = Dense(num_classes)

    def call(self, x):
        cur = self.stem(x)
        prev_outputs = [cur]

        for i, layer in enumerate(self.enas_layers):
            if i > 0 and i % self.pool_distance == 0:
                pool = self.pool_layers[i // self.pool_distance - 1]
                prev_outputs = [pool(tensor) for tensor in prev_outputs]
                cur = prev_outputs[-1]
            cur = layer(prev_outputs)
            prev_outputs.append(cur)

        cur = self.gap(cur)
        cur = self.dropout(cur)
        logits = self.dense(cur)
        return logits
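A minimal sketch of driving this macro search space by hand, outside EnasTrainer: the LayerChoice/InputChoice mutables must first be bound to a mutator and an architecture sampled, otherwise Mutable._check_built raises. The dummy input is an assumption (one CIFAR-10-sized image):

import tensorflow as tf
from nni.nas.tensorflow.enas import EnasMutator

model = GeneralNetwork()
mutator = EnasMutator(model)  # parses the search space, binds every mutable to this mutator
mutator.reset()               # samples one sub-architecture into the decision cache
logits = model(tf.zeros([1, 32, 32, 3]))  # should yield logits of shape (1, 10)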
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import tensorflow as tf
from tensorflow.keras import Model, Sequential
from tensorflow.keras.layers import (
    AveragePooling2D,
    BatchNormalization,
    Conv2D,
    Dense,
    Dropout,
    GlobalAveragePooling2D,
    MaxPool2D,
    ReLU,
    SeparableConv2D,
)

from nni.nas.tensorflow.mutables import InputChoice, LayerChoice, MutableScope


def build_conv_1x1(filters, name=None):
    return Sequential([
        Conv2D(filters, kernel_size=1, use_bias=False),
        BatchNormalization(trainable=False),
        ReLU(),
    ], name)


def build_sep_conv(filters, kernel_size, name=None):
    return Sequential([
        ReLU(),
        SeparableConv2D(filters, kernel_size, padding='same'),
        BatchNormalization(trainable=True),
    ], name)


class FactorizedReduce(Model):
    def __init__(self, filters):
        super().__init__()
        self.conv1 = Conv2D(filters // 2, kernel_size=1, strides=2, use_bias=False)
        self.conv2 = Conv2D(filters // 2, kernel_size=1, strides=2, use_bias=False)
        self.bn = BatchNormalization(trainable=False)

    def call(self, x):
        out1 = self.conv1(x)
        out2 = self.conv2(x[:, 1:, 1:, :])
        out = tf.concat([out1, out2], axis=3)
        out = self.bn(out)
        return out


class ReductionLayer(Model):
    def __init__(self, filters):
        super().__init__()
        self.reduce0 = FactorizedReduce(filters)
        self.reduce1 = FactorizedReduce(filters)

    def call(self, prevprev, prev):
        return self.reduce0(prevprev), self.reduce1(prev)


class Calibration(Model):
    def __init__(self, filters):
        super().__init__()
        self.filters = filters
        self.process = None

    def build(self, shape):
        assert len(shape) == 4  # batch_size, width, height, filters
        if shape[3] != self.filters:
            self.process = build_conv_1x1(self.filters)

    def call(self, x):
        if self.process is None:
            return x
        return self.process(x)


class Cell(Model):
    def __init__(self, cell_name, prev_labels, filters):
        super().__init__()
        self.input_choice = InputChoice(choose_from=prev_labels, n_chosen=1,
                                        return_mask=True, key=cell_name + '_input')
        self.op_choice = LayerChoice([
            build_sep_conv(filters, 3),
            build_sep_conv(filters, 5),
            AveragePooling2D(pool_size=3, strides=1, padding='same'),
            MaxPool2D(pool_size=3, strides=1, padding='same'),
            Sequential(),  # identity
        ], key=cell_name + '_op')

    def call(self, prev_layers):
        chosen_input, chosen_mask = self.input_choice(prev_layers)
        cell_out = self.op_choice(chosen_input)
        return cell_out, chosen_mask


class Node(MutableScope):
    def __init__(self, node_name, prev_node_names, filters):
        super().__init__(node_name)
        self.cell_x = Cell(node_name + '_x', prev_node_names, filters)
        self.cell_y = Cell(node_name + '_y', prev_node_names, filters)

    def call(self, prev_layers):
        out_x, mask_x = self.cell_x(prev_layers)
        out_y, mask_y = self.cell_y(prev_layers)
        return out_x + out_y, mask_x | mask_y


class ENASLayer(Model):
    def __init__(self, num_nodes, filters, reduction):
        super().__init__()
        self.preproc0 = Calibration(filters)
        self.preproc1 = Calibration(filters)

        self.nodes = []
        node_labels = [InputChoice.NO_KEY, InputChoice.NO_KEY]
        name_prefix = 'reduce' if reduction else 'normal'
        for i in range(num_nodes):
            node_labels.append('{}_node_{}'.format(name_prefix, i))
            self.nodes.append(Node(node_labels[-1], node_labels[:-1], filters))

        self.conv_ops = [Conv2D(filters, kernel_size=1, padding='same', use_bias=False)
                         for _ in range(num_nodes + 2)]
        self.bn = BatchNormalization(trainable=False)

    def call(self, prevprev, prev):
        prev_nodes_out = [self.preproc0(prevprev), self.preproc1(prev)]
        nodes_used_mask = tf.zeros(len(self.nodes) + 2, dtype=tf.bool)
        for node in self.nodes:
            node_out, mask = node(prev_nodes_out)
            nodes_used_mask |= tf.pad(mask, [[0, nodes_used_mask.shape[0] - mask.shape[0]]])
            prev_nodes_out.append(node_out)

        # nodes whose output was never chosen as an input ("loose ends") form the cell output
        outputs = []
        for used, out, conv in zip(nodes_used_mask.numpy(), prev_nodes_out, self.conv_ops):
            if not used:
                outputs.append(conv(out))
        out = tf.add_n(outputs)
        return prev, self.bn(out)


class MicroNetwork(Model):
    def __init__(self, num_layers=6, num_nodes=5, out_channels=20, num_classes=10, dropout_rate=0.1):
        super().__init__()
        self.num_layers = num_layers

        self.stem = Sequential([
            Conv2D(out_channels * 3, kernel_size=3, padding='same', use_bias=False),
            BatchNormalization(),
        ])

        pool_distance = num_layers // 3
        pool_layer_indices = [pool_distance, 2 * pool_distance + 1]

        self.enas_layers = []
        filters = out_channels
        for i in range(num_layers + 2):
            if i in pool_layer_indices:
                reduction = True
                filters *= 2
                self.enas_layers.append(ReductionLayer(filters))
            else:
                reduction = False
                self.enas_layers.append(ENASLayer(num_nodes, filters, reduction))

        self.gap = GlobalAveragePooling2D()
        self.dropout = Dropout(dropout_rate)
        self.dense = Dense(num_classes)

    def call(self, x):
        prev = cur = self.stem(x)
        for layer in self.enas_layers:
            prev, cur = layer(prev, cur)
        cur = tf.keras.activations.relu(cur)
        cur = self.gap(cur)
        cur = self.dropout(cur)
        logits = self.dense(cur)
        return logits
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from tensorflow.keras.losses import Reduction, SparseCategoricalCrossentropy
from tensorflow.keras.optimizers import SGD

from nni.nas.tensorflow import enas

import datasets
from macro import GeneralNetwork
from micro import MicroNetwork
from utils import accuracy, accuracy_metrics

# TODO: argparse

dataset_train, dataset_valid = datasets.get_dataset()

#model = GeneralNetwork()
model = MicroNetwork()

loss = SparseCategoricalCrossentropy(from_logits=True, reduction=Reduction.NONE)
optimizer = SGD(learning_rate=0.05, momentum=0.9)

trainer = enas.EnasTrainer(model,
                           loss=loss,
                           metrics=accuracy_metrics,
                           reward_function=accuracy,
                           optimizer=optimizer,
                           batch_size=64,
                           num_epochs=310,
                           dataset_train=dataset_train,
                           dataset_valid=dataset_valid)
trainer.train()
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import tensorflow as tf


def accuracy_metrics(y_true, logits):
    return {'enas_acc': accuracy(y_true, logits)}


def accuracy(y_true, logits):
    # y_true: shape=(batch_size,) or (batch_size, 1), integer
    # logits: shape=(batch_size, num_of_classes), float
    # returns a Python float
    batch_size = y_true.shape[0]
    y_true = tf.squeeze(y_true)
    y_pred = tf.math.argmax(logits, axis=1)
    y_pred = tf.cast(y_pred, y_true.dtype)
    equal = tf.cast(y_pred == y_true, tf.int32)
    return tf.math.reduce_sum(equal).numpy() / batch_size
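A quick check of the helper on made-up values:

import tensorflow as tf

logits = tf.constant([[2.0, 0.1], [0.3, 1.5]])  # two samples, two classes
y_true = tf.constant([[0], [0]])                # the second sample is misclassified
print(accuracy(y_true, logits))                 # 0.5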
import tensorflow as tf
from tensorflow.keras import Model
from tensorflow.keras.layers import AveragePooling2D, BatchNormalization, Conv2D, Dense, MaxPool2D
from tensorflow.keras.losses import Reduction, SparseCategoricalCrossentropy
from tensorflow.keras.optimizers import SGD

from nni.nas.tensorflow.mutables import LayerChoice, InputChoice
from nni.nas.tensorflow.enas import EnasTrainer

tf.get_logger().setLevel('ERROR')


class Net(Model):
    def __init__(self):
        super().__init__()
        self.conv1 = LayerChoice([
            Conv2D(6, 3, padding='same', activation='relu'),
            Conv2D(6, 5, padding='same', activation='relu'),
        ])
        self.pool = MaxPool2D(2)
        self.conv2 = LayerChoice([
            Conv2D(16, 3, padding='same', activation='relu'),
            Conv2D(16, 5, padding='same', activation='relu'),
        ])
        self.conv3 = Conv2D(16, 1)
        self.skipconnect = InputChoice(n_candidates=1)
        self.bn = BatchNormalization()
        self.gap = AveragePooling2D(2)
        self.fc1 = Dense(120, activation='relu')
        self.fc2 = Dense(84, activation='relu')
        self.fc3 = Dense(10)

    def call(self, x):
        bs = x.shape[0]
        t = self.conv1(x)
        x = self.pool(t)
        x0 = self.conv2(x)
        x1 = self.conv3(x0)
        x0 = self.skipconnect([x0])
        if x0 is not None:
            x1 += x0
        x = self.pool(self.bn(x1))
        x = self.gap(x)
        x = tf.reshape(x, [bs, -1])
        x = self.fc1(x)
        x = self.fc2(x)
        x = self.fc3(x)
        return x


def accuracy(y_true, logits):
    # the trainer calls metrics and reward_function as fn(y_true, logits)
    bs = y_true.shape[0]
    predicted = tf.cast(tf.argmax(logits, 1), y_true.dtype)
    y_true = tf.reshape(y_true, [-1])
    return tf.math.reduce_sum(tf.cast(predicted == y_true, tf.float32)).numpy() / bs


def accuracy_metrics(y_true, logits):
    # the trainer expects `metrics` to return a dict of named values
    return {'accuracy': accuracy(y_true, logits)}


if __name__ == '__main__':
    cifar10 = tf.keras.datasets.cifar10
    (x_train, y_train), (x_test, y_test) = cifar10.load_data()
    x_train, x_test = x_train / 255.0, x_test / 255.0

    net = Net()
    trainer = EnasTrainer(
        net,
        # `from_logits=True` because Net outputs raw logits without a softmax
        loss=SparseCategoricalCrossentropy(from_logits=True, reduction=Reduction.SUM),
        metrics=accuracy_metrics,
        reward_function=accuracy,
        optimizer=SGD(learning_rate=0.001, momentum=0.9),
        batch_size=64,
        num_epochs=2,
        # the trainer takes (x, y) tuples and performs its own train/valid split
        dataset_train=(x_train, y_train),
        dataset_valid=(x_test, y_test),
    )
    trainer.train()
    #trainer.export('checkpoint')
@@ -45,4 +45,6 @@ enable= unused-wildcard-import,
 ignore-patterns=test*
 # List of members which are set dynamically and missed by pylint inference
-generated-members=numpy.*,torch.*
+generated-members=numpy.*,torch.*,tensorflow.*
+ignored-modules=tensorflow
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from tensorflow.keras import Model

from .mutables import Mutable, MutableScope, InputChoice
from .utils import StructuredMutableTreeNode


class BaseMutator(Model):
    def __init__(self, model):
        super().__init__()
        self.__dict__['model'] = model
        self._structured_mutables = self._parse_search_space(self.model)

    def _parse_search_space(self, module, root=None, prefix='', memo=None, nested_detection=None):
        if memo is None:
            memo = set()
        if root is None:
            root = StructuredMutableTreeNode(None)
        if module not in memo:
            memo.add(module)
            if isinstance(module, Mutable):
                if nested_detection is not None:
                    raise RuntimeError('Cannot have nested search space. Error at {} in {}'
                                       .format(module, nested_detection))
                module.name = prefix
                module.set_mutator(self)
                root = root.add_child(module)
                if not isinstance(module, MutableScope):
                    nested_detection = module
                if isinstance(module, InputChoice):
                    for k in module.choose_from:
                        if k != InputChoice.NO_KEY and k not in [m.key for m in memo if isinstance(m, Mutable)]:
                            raise RuntimeError('"{}" required by "{}" not found in keys that appeared before, '
                                               'and is not NO_KEY.'.format(k, module.key))
            for submodule in module.layers:
                if not isinstance(submodule, Model):
                    continue
                submodule_prefix = prefix + ('.' if prefix else '') + submodule.name
                self._parse_search_space(submodule, root, submodule_prefix, memo=memo,
                                         nested_detection=nested_detection)
        return root

    @property
    def mutables(self):
        return self._structured_mutables

    def undedup_mutables(self):
        return self._structured_mutables.traverse(deduplicate=False)

    def call(self, *inputs):
        raise RuntimeError('Call is undefined for mutators.')

    def __setattr__(self, name, value):
        if name == 'model':
            raise AttributeError("Attribute `model` can be set at most once, and you shouldn't use `self.model = model` to "
                                 "include your network, as it will include all parameters in model into the mutator.")
        return super().__setattr__(name, value)

    def enter_mutable_scope(self, mutable_scope):
        pass

    def exit_mutable_scope(self, mutable_scope):
        pass

    def on_forward_layer_choice(self, mutable, *inputs):
        raise NotImplementedError

    def on_forward_input_choice(self, mutable, tensor_list):
        raise NotImplementedError

    def export(self):
        raise NotImplementedError
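To see what the parse step produces, a small sketch (model is assumed to be any network built from the mutables above); EnasMutator relies on exactly this traversal in its constructor:

from nni.nas.tensorflow.enas import EnasMutator

mutator = EnasMutator(model)      # BaseMutator.__init__ parses the search space
for mutable in mutator.mutables:  # pre-order traversal of the StructuredMutableTreeNode tree
    print(mutable.key, type(mutable).__name__)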
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .mutator import EnasMutator
from .trainer import EnasTrainer
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import tensorflow as tf
from tensorflow.keras.layers import Dense, Embedding, LSTMCell, RNN
from tensorflow.keras.losses import SparseCategoricalCrossentropy, Reduction

from nni.nas.tensorflow.mutator import Mutator
from nni.nas.tensorflow.mutables import LayerChoice, InputChoice, MutableScope


class EnasMutator(Mutator):
    def __init__(self, model,
                 lstm_size=64,
                 lstm_num_layers=1,
                 tanh_constant=1.5,
                 cell_exit_extra_step=False,
                 skip_target=0.4,
                 temperature=None,
                 branch_bias=0.25,
                 entropy_reduction='sum'):
        super().__init__(model)
        self.tanh_constant = tanh_constant
        self.temperature = temperature
        self.cell_exit_extra_step = cell_exit_extra_step

        cells = [LSTMCell(units=lstm_size, use_bias=False) for _ in range(lstm_num_layers)]
        self.lstm = RNN(cells, stateful=True)
        self.g_emb = tf.random.normal((1, 1, lstm_size)) * 0.1
        self.skip_targets = tf.constant([1.0 - skip_target, skip_target])

        self.max_layer_choice = 0
        self.bias_dict = {}
        for mutable in self.mutables:
            if isinstance(mutable, LayerChoice):
                if self.max_layer_choice == 0:
                    self.max_layer_choice = len(mutable)
                assert self.max_layer_choice == len(mutable), \
                    'ENAS mutator requires all layer choices to have the same number of candidates.'
                if 'reduce' in mutable.key:
                    bias = []
                    for choice in mutable.choices:
                        if 'conv' in str(type(choice)).lower():
                            bias.append(branch_bias)
                        else:
                            bias.append(-branch_bias)
                    self.bias_dict[mutable.key] = tf.constant(bias)

        # exposed for trainer
        self.sample_log_prob = 0
        self.sample_entropy = 0
        self.sample_skip_penalty = 0

        # internal nn layers
        self.embedding = Embedding(self.max_layer_choice + 1, lstm_size)
        self.soft = Dense(self.max_layer_choice, use_bias=False)
        self.attn_anchor = Dense(lstm_size, use_bias=False)
        self.attn_query = Dense(lstm_size, use_bias=False)
        self.v_attn = Dense(1, use_bias=False)
        assert entropy_reduction in ['sum', 'mean'], 'Entropy reduction must be one of sum and mean.'
        self.entropy_reduction = tf.reduce_sum if entropy_reduction == 'sum' else tf.reduce_mean
        self.cross_entropy_loss = SparseCategoricalCrossentropy(from_logits=True, reduction=Reduction.NONE)

        self._first_sample = True

    def sample_search(self):
        self._initialize()
        self._sample(self.mutables)
        self._first_sample = False
        return self._choices

    def sample_final(self):
        return self.sample_search()

    def _sample(self, tree):
        mutable = tree.mutable
        if isinstance(mutable, LayerChoice) and mutable.key not in self._choices:
            self._choices[mutable.key] = self._sample_layer_choice(mutable)
        elif isinstance(mutable, InputChoice) and mutable.key not in self._choices:
            self._choices[mutable.key] = self._sample_input_choice(mutable)
        for child in tree.children:
            self._sample(child)
        if self.cell_exit_extra_step and isinstance(mutable, MutableScope) and mutable.key not in self._anchors_hid:
            self._anchors_hid[mutable.key] = self.lstm(self._inputs, 1)

    def _initialize(self):
        self._choices = {}
        self._anchors_hid = {}
        self._inputs = self.g_emb
        # seems the `input_shape` parameter of RNN does not work
        # workaround it by omitting `reset_states` for the first run
        if not self._first_sample:
            self.lstm.reset_states()
        self.sample_log_prob = 0
        self.sample_entropy = 0
        self.sample_skip_penalty = 0

    def _sample_layer_choice(self, mutable):
        logit = self.soft(self.lstm(self._inputs))
        if self.temperature is not None:
            logit /= self.temperature
        if self.tanh_constant is not None:
            logit = self.tanh_constant * tf.tanh(logit)
        if mutable.key in self.bias_dict:
            logit += self.bias_dict[mutable.key]
        softmax_logit = tf.math.log(tf.nn.softmax(logit, axis=-1))
        branch_id = tf.reshape(tf.random.categorical(softmax_logit, num_samples=1), [1])
        log_prob = self.cross_entropy_loss(branch_id, logit)
        self.sample_log_prob += self.entropy_reduction(log_prob)
        entropy = log_prob * tf.math.exp(-log_prob)
        self.sample_entropy += self.entropy_reduction(entropy)
        self._inputs = tf.reshape(self.embedding(branch_id), [1, 1, -1])
        mask = tf.one_hot(branch_id, self.max_layer_choice)
        return tf.cast(tf.reshape(mask, [-1]), tf.bool)

    def _sample_input_choice(self, mutable):
        query, anchors = [], []
        for label in mutable.choose_from:
            if label not in self._anchors_hid:
                self._anchors_hid[label] = self.lstm(self._inputs)
            query.append(self.attn_anchor(self._anchors_hid[label]))
            anchors.append(self._anchors_hid[label])
        query = tf.concat(query, axis=0)
        query = tf.tanh(query + self.attn_query(anchors[-1]))
        query = self.v_attn(query)
        if self.temperature is not None:
            query /= self.temperature
        if self.tanh_constant is not None:
            query = self.tanh_constant * tf.tanh(query)

        if mutable.n_chosen is None:
            logit = tf.concat([-query, query], axis=1)
            softmax_logit = tf.math.log(tf.nn.softmax(logit, axis=-1))
            skip = tf.reshape(tf.random.categorical(softmax_logit, num_samples=1), [-1])
            skip_prob = tf.math.sigmoid(logit)
            kl = tf.reduce_sum(skip_prob * tf.math.log(skip_prob / self.skip_targets))
            self.sample_skip_penalty += kl
            log_prob = self.cross_entropy_loss(skip, logit)
            skip = tf.cast(skip, tf.float32)
            inputs = tf.tensordot(skip, tf.concat(anchors, 0), 1) / (1. + tf.reduce_sum(skip))
            self._inputs = tf.reshape(inputs, [1, 1, -1])
        else:
            assert mutable.n_chosen == 1, 'Input choice must select exactly one or any in ENAS.'
            logit = tf.reshape(query, [1, -1])
            softmax_logit = tf.math.log(tf.nn.softmax(logit, axis=-1))
            index = tf.reshape(tf.random.categorical(softmax_logit, num_samples=1), [-1])
            skip = tf.reshape(tf.one_hot(index, mutable.n_candidates), [-1])
            # when the size is 1, tf does not accept a tensor as the target here,
            # complaining the shape is wrong; a numpy array works fine
            log_prob = self.cross_entropy_loss(index.numpy(), logit)
            self._inputs = tf.reshape(anchors[index.numpy()[0]], [1, 1, -1])

        self.sample_log_prob += self.entropy_reduction(log_prob)
        entropy = log_prob * tf.exp(-log_prob)
        self.sample_entropy += self.entropy_reduction(entropy)
        assert len(skip) == mutable.n_candidates, (skip, mutable.n_candidates, mutable.n_chosen)
        return tf.cast(skip, tf.bool)
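To make the sampling flow concrete, a sketch of how a trainer drives the mutator (model is assumed to be defined as in the examples above):

mutator = EnasMutator(model)
mutator.reset()          # runs sample_search() and fills the decision cache used at forward time
arch = mutator.export()  # sample_final(): dict mapping mutable keys to boolean choice masks
print(mutator.sample_log_prob, mutator.sample_entropy)  # REINFORCE statistics for this sample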
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import logging

import tensorflow as tf
from tensorflow.data import Dataset
from tensorflow.keras.optimizers import Adam

from nni.nas.tensorflow.utils import AverageMeterGroup, fill_zero_grads

from .mutator import EnasMutator

logger = logging.getLogger(__name__)


log_frequency = 100
entropy_weight = 0.0001
skip_weight = 0.8
baseline_decay = 0.999
child_steps = 500
mutator_lr = 0.00035
mutator_steps = 50
mutator_steps_aggregate = 20
aux_weight = 0.4
test_arc_per_epoch = 1


class EnasTrainer:
    def __init__(self, model, loss, metrics, reward_function, optimizer, batch_size, num_epochs,
                 dataset_train, dataset_valid):
        self.model = model
        self.loss = loss
        self.metrics = metrics
        self.reward_function = reward_function
        self.optimizer = optimizer
        self.batch_size = batch_size
        self.num_epochs = num_epochs

        x, y = dataset_train
        split = int(len(x) * 0.9)
        self.train_set = Dataset.from_tensor_slices((x[:split], y[:split]))
        self.valid_set = Dataset.from_tensor_slices((x[split:], y[split:]))
        self.test_set = Dataset.from_tensor_slices(dataset_valid)

        self.mutator = EnasMutator(model)
        self.mutator_optim = Adam(learning_rate=mutator_lr)
        self.baseline = 0.

    def train(self, validate=True):
        for epoch in range(self.num_epochs):
            logger.info("Epoch %d Training", epoch + 1)
            self.train_one_epoch(epoch)
            logger.info("Epoch %d Validating", epoch + 1)
            self.validate_one_epoch(epoch)

    def validate(self):
        self.validate_one_epoch(-1)

    def train_one_epoch(self, epoch):
        train_loader, valid_loader = self._create_train_loader()

        # Sample model and train
        meters = AverageMeterGroup()
        for step in range(1, child_steps + 1):
            x, y = next(train_loader)
            self.mutator.reset()

            with tf.GradientTape() as tape:
                logits = self.model(x, training=True)
                if isinstance(logits, tuple):
                    logits, aux_logits = logits
                    aux_loss = self.loss(y, aux_logits)  # Keras losses take (y_true, y_pred)
                else:
                    aux_loss = 0.
                metrics = self.metrics(y, logits)
                loss = self.loss(y, logits) + aux_weight * aux_loss

            grads = tape.gradient(loss, self.model.trainable_weights)
            grads = fill_zero_grads(grads, self.model.trainable_weights)
            grads, _ = tf.clip_by_global_norm(grads, 5.0)
            self.optimizer.apply_gradients(zip(grads, self.model.trainable_weights))

            metrics['loss'] = tf.reduce_mean(loss).numpy()
            meters.update(metrics)

            if log_frequency and step % log_frequency == 0:
                logger.info("Model Epoch [%d/%d] Step [%d/%d] %s", epoch + 1,
                            self.num_epochs, step, child_steps, meters)

        # Train sampler (mutator)
        meters = AverageMeterGroup()
        for mutator_step in range(1, mutator_steps + 1):
            grads_list = []
            for step in range(1, mutator_steps_aggregate + 1):
                with tf.GradientTape() as tape:
                    x, y = next(valid_loader)
                    self.mutator.reset()

                    logits = self.model(x, training=False)
                    metrics = self.metrics(y, logits)
                    reward = self.reward_function(y, logits) + entropy_weight * self.mutator.sample_entropy
                    self.baseline = self.baseline * baseline_decay + reward * (1 - baseline_decay)
                    loss = self.mutator.sample_log_prob * (reward - self.baseline)
                    loss += skip_weight * self.mutator.sample_skip_penalty

                meters.update({
                    'reward': reward,
                    'loss': tf.reduce_mean(loss).numpy(),
                    'ent': self.mutator.sample_entropy.numpy(),
                    'log_prob': self.mutator.sample_log_prob.numpy(),
                    'baseline': self.baseline,
                    'skip': self.mutator.sample_skip_penalty,
                })

                cur_step = step + (mutator_step - 1) * mutator_steps_aggregate
                if log_frequency and cur_step % log_frequency == 0:
                    logger.info("RL Epoch [%d/%d] Step [%d/%d] [%d/%d] %s", epoch + 1, self.num_epochs,
                                mutator_step, mutator_steps, step, mutator_steps_aggregate,
                                meters)

                grads = tape.gradient(loss, self.mutator.trainable_weights)
                grads = fill_zero_grads(grads, self.mutator.trainable_weights)
                grads_list.append(grads)

            total_grads = [tf.math.add_n(weight_grads) for weight_grads in zip(*grads_list)]
            total_grads, _ = tf.clip_by_global_norm(total_grads, 5.0)
            self.mutator_optim.apply_gradients(zip(total_grads, self.mutator.trainable_weights))

    def validate_one_epoch(self, epoch):
        test_loader = self._create_validate_loader()

        for arc_id in range(test_arc_per_epoch):
            meters = AverageMeterGroup()
            for x, y in test_loader:
                self.mutator.reset()
                logits = self.model(x)
                if isinstance(logits, tuple):
                    logits, _ = logits
                metrics = self.metrics(y, logits)
                loss = self.loss(y, logits)
                metrics['loss'] = tf.reduce_mean(loss).numpy()
                meters.update(metrics)
            logger.info("Test Epoch [%d/%d] Arc [%d/%d] Summary %s",
                        epoch + 1, self.num_epochs, arc_id + 1, test_arc_per_epoch,
                        meters.summary())

    def _create_train_loader(self):
        # `repeat()` is needed: the mutator phase draws mutator_steps * mutator_steps_aggregate
        # batches, which can exceed one pass over the 10% validation split
        train_set = self.train_set.shuffle(1000000).repeat().batch(self.batch_size)
        valid_set = self.valid_set.shuffle(1000000).repeat().batch(self.batch_size)
        return iter(train_set), iter(valid_set)

    def _create_validate_loader(self):
        return iter(self.test_set.shuffle(1000000).batch(self.batch_size))
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import logging

from tensorflow.keras import Model

from .utils import global_mutable_counting

_logger = logging.getLogger(__name__)


class Mutable(Model):
    def __init__(self, key=None):
        super().__init__()
        if key is None:
            self._key = '{}_{}'.format(type(self).__name__, global_mutable_counting())
        elif isinstance(key, str):
            self._key = key
        else:
            self._key = str(key)
            _logger.warning('Key "%s" is not a string, converted to string.', key)
        self.init_hook = None
        self.forward_hook = None

    def __deepcopy__(self, memodict=None):
        raise NotImplementedError("Deep copy doesn't work for mutables.")

    def __call__(self, *args, **kwargs):
        self._check_built()
        return super().__call__(*args, **kwargs)

    def set_mutator(self, mutator):
        if 'mutator' in self.__dict__:
            raise RuntimeError('`set_mutator` is called more than once. '
                               'Did you parse the search space multiple times? '
                               'Or did you apply multiple fixed architectures?')
        self.__dict__['mutator'] = mutator

    def call(self, *inputs):
        raise NotImplementedError('Method `call` of Mutable must be overridden')

    @property
    def key(self):
        return self._key

    @property
    def name(self):
        return self._name if hasattr(self, '_name') else self._key

    @name.setter
    def name(self, name):
        self._name = name

    def _check_built(self):
        if not hasattr(self, 'mutator'):
            raise ValueError(
                "Mutator not set for {}. You might have forgotten to initialize and apply your mutator. "
                "Or did you initialize a mutable on the fly in the forward pass? Move it to `__init__` "
                "so that the trainer can locate all your mutables. See NNI docs for more details.".format(self))

    def __repr__(self):
        return '{} ({})'.format(self.name, self.key)


class MutableScope(Mutable):
    def __call__(self, *args, **kwargs):
        try:
            self._check_built()
            self.mutator.enter_mutable_scope(self)
            return super().__call__(*args, **kwargs)
        finally:
            self.mutator.exit_mutable_scope(self)


class LayerChoice(Mutable):
    def __init__(self, op_candidates, reduction='sum', return_mask=False, key=None):
        super().__init__(key=key)
        self.length = len(op_candidates)
        self.choices = op_candidates
        self.reduction = reduction
        self.return_mask = return_mask
        self._built = False

    def call(self, *inputs):
        if not self._built:
            for op in self.choices:
                if len(inputs) > 1:  # FIXME: not tested
                    op.build([inp.shape for inp in inputs])
                elif len(inputs) == 1:
                    op.build(inputs[0].shape)
            self._built = True
        out, mask = self.mutator.on_forward_layer_choice(self, *inputs)
        if self.return_mask:
            return out, mask
        return out

    def __len__(self):
        return len(self.choices)


class InputChoice(Mutable):
    NO_KEY = ''

    def __init__(self, n_candidates=None, choose_from=None, n_chosen=None,
                 reduction='sum', return_mask=False, key=None):
        super().__init__(key=key)
        assert n_candidates is not None or choose_from is not None, \
            'At least one of `n_candidates` and `choose_from` must not be None.'
        if choose_from is not None and n_candidates is None:
            n_candidates = len(choose_from)
        elif choose_from is None and n_candidates is not None:
            choose_from = [self.NO_KEY] * n_candidates
        assert n_candidates == len(choose_from), 'Number of candidates must be equal to the length of `choose_from`.'
        assert n_candidates > 0, 'Number of candidates must be greater than 0.'
        assert n_chosen is None or 0 <= n_chosen <= n_candidates, \
            'Expected selected number must be None or no more than the number of candidates.'

        self.n_candidates = n_candidates
        self.choose_from = choose_from.copy()
        self.n_chosen = n_chosen
        self.reduction = reduction
        self.return_mask = return_mask

    def call(self, optional_inputs):
        optional_input_list = optional_inputs
        if isinstance(optional_inputs, dict):
            optional_input_list = [optional_inputs[tag] for tag in self.choose_from]
        assert isinstance(optional_input_list, list), \
            'Optional input list must be a list, not a {}.'.format(type(optional_input_list))
        assert len(optional_input_list) == self.n_candidates, \
            'Length of the input list must be equal to the number of candidates.'
        out, mask = self.mutator.on_forward_input_choice(self, optional_input_list)
        if self.return_mask:
            return out, mask
        return out
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import logging

import tensorflow as tf

from .base_mutator import BaseMutator

_logger = logging.getLogger(__name__)


class Mutator(BaseMutator):
    def __init__(self, model):
        super().__init__(model)
        self._cache = {}

    def sample_search(self):
        raise NotImplementedError('Method `sample_search` must be overridden')

    def sample_final(self):
        raise NotImplementedError('Method `sample_final` must be overridden for exporting')

    def reset(self):
        self._cache = self.sample_search()

    def export(self):
        return self.sample_final()

    # TODO: status
    # TODO: graph

    def on_forward_layer_choice(self, mutable, *inputs):
        mask = self._get_decision(mutable)
        assert len(mask) == len(mutable), \
            'Invalid mask, expected {} to be of length {}.'.format(mask, len(mutable))
        out = self._select_with_mask(lambda choice: choice(*inputs), mutable.choices, mask)
        return self._tensor_reduction(mutable.reduction, out), mask

    def on_forward_input_choice(self, mutable, tensor_list):
        mask = self._get_decision(mutable)
        assert len(mask) == mutable.n_candidates, \
            'Invalid mask, expected {} to be of length {}.'.format(mask, mutable.n_candidates)
        out = self._select_with_mask(lambda tensor: tensor, tensor_list, mask)
        return self._tensor_reduction(mutable.reduction, out), mask

    def _select_with_mask(self, map_fn, candidates, mask):
        if mask.dtype.is_bool:
            out = [map_fn(cand) for cand, m in zip(candidates, mask) if m]
        elif mask.dtype.is_floating:
            out = [map_fn(cand) * m for cand, m in zip(candidates, mask) if m]
        else:
            raise ValueError('Unrecognized mask, dtype is {}'.format(mask.dtype.name))
        return out

    def _tensor_reduction(self, reduction_type, tensor_list):
        if reduction_type == 'none':
            return tensor_list
        if not tensor_list:
            return None
        if len(tensor_list) == 1:
            return tensor_list[0]
        if reduction_type == 'sum':
            return sum(tensor_list)
        if reduction_type == 'mean':
            return sum(tensor_list) / len(tensor_list)
        if reduction_type == 'concat':
            return tf.concat(tensor_list, axis=0)
        raise ValueError('Unrecognized reduction policy: "{}"'.format(reduction_type))

    def _get_decision(self, mutable):
        if mutable.key not in self._cache:
            raise ValueError('"{}" not found in decision cache.'.format(mutable.key))
        result = self._cache[mutable.key]
        _logger.debug('Decision %s: %s', mutable.key, result)
        return result
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import tensorflow as tf

_counter = 0


def global_mutable_counting():
    global _counter
    _counter += 1
    return _counter


class AverageMeter:
    def __init__(self, name):
        self.name = name
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val):
        self.val = val
        self.sum += val
        self.count += 1
        self.avg = self.sum / self.count

    def __str__(self):
        return '{name} {val:4f} ({avg:4f})'.format(**self.__dict__)

    def summary(self):
        return '{name}: {avg:4f}'.format(**self.__dict__)


class AverageMeterGroup:
    def __init__(self):
        self.meters = {}

    def update(self, data):
        for k, v in data.items():
            if k not in self.meters:
                self.meters[k] = AverageMeter(k)
            self.meters[k].update(v)

    def __str__(self):
        return ' '.join(str(v) for v in self.meters.values())

    def summary(self):
        return ' '.join(v.summary() for v in self.meters.values())
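A tiny usage sketch of the meter group, with invented numbers, matching how the trainer logs metrics:

meters = AverageMeterGroup()
meters.update({'acc': 0.5, 'loss': 2.3})
meters.update({'acc': 0.7, 'loss': 1.7})
print(meters)            # acc 0.700000 (0.600000) loss 1.700000 (2.000000)
print(meters.summary())  # acc: 0.600000 loss: 2.000000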
class StructuredMutableTreeNode:
    def __init__(self, mutable):
        self.mutable = mutable
        self.children = []

    def add_child(self, mutable):
        self.children.append(StructuredMutableTreeNode(mutable))
        return self.children[-1]

    def type(self):
        return type(self.mutable)

    def __iter__(self):
        return self.traverse()

    def traverse(self, order="pre", deduplicate=True, memo=None):
        if memo is None:
            memo = set()
        assert order in ["pre", "post"]
        if order == "pre":
            if self.mutable is not None:
                if not deduplicate or self.mutable.key not in memo:
                    memo.add(self.mutable.key)
                    yield self.mutable
        for child in self.children:
            for m in child.traverse(order=order, deduplicate=deduplicate, memo=memo):
                yield m
        if order == "post":
            if self.mutable is not None:
                if not deduplicate or self.mutable.key not in memo:
                    memo.add(self.mutable.key)
                    yield self.mutable


def fill_zero_grads(grads, weights):
    ret = []
    for grad, weight in zip(grads, weights):
        if grad is not None:
            ret.append(grad)
        else:
            ret.append(tf.zeros_like(weight))
    return ret
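fill_zero_grads exists because, with only a sampled sub-architecture active, GradientTape returns None for the unused branches' weights; replacing None with zeros keeps the trainer's later per-weight add_n aggregation and gradient clipping well-defined. A sketch mirroring the trainer's usage (compute_loss is a hypothetical helper):

with tf.GradientTape() as tape:
    loss = compute_loss(model, x, y)  # hypothetical: forward pass + loss on one batch
grads = tape.gradient(loss, model.trainable_weights)  # may contain None entries
grads = fill_zero_grads(grads, model.trainable_weights)
grads, _ = tf.clip_by_global_norm(grads, 5.0)
optimizer.apply_gradients(zip(grads, model.trainable_weights))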