Unverified commit 0a742aff authored by SparkSnail, committed by GitHub

Merge pull request #249 from microsoft/master

merge master
parents 0fd38deb 76c819c0
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
__version__ = '999.0.0-developing'
from .env_vars import dispatcher_env_vars
if dispatcher_env_vars.SDK_PROCESS != 'dispatcher':
......
@@ -163,9 +163,11 @@ class ModelSpeedup:
        first, do mask/shape inference,
        second, replace modules
        """
        training = self.bound_model.training
        _logger.info("start to speed up the model")
        _logger.info("infer module masks...")
        self.infer_modules_masks()
        _logger.info("replace compressed modules...")
        self.replace_compressed_modules()
        self.bound_model.train(training)
        _logger.info("speedup done")
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from tensorflow.keras import Model
from .mutables import Mutable, MutableScope, InputChoice
from .utils import StructuredMutableTreeNode
class BaseMutator(Model):
def __init__(self, model):
super().__init__()
self.__dict__['model'] = model
self._structured_mutables = self._parse_search_space(self.model)
def _parse_search_space(self, module, root=None, prefix='', memo=None, nested_detection=None):
if memo is None:
memo = set()
if root is None:
root = StructuredMutableTreeNode(None)
if module not in memo:
memo.add(module)
if isinstance(module, Mutable):
if nested_detection is not None:
raise RuntimeError('Cannot have nested search space. Error at {} in {}'
.format(module, nested_detection))
module.name = prefix
module.set_mutator(self)
root = root.add_child(module)
if not isinstance(module, MutableScope):
nested_detection = module
if isinstance(module, InputChoice):
for k in module.choose_from:
if k != InputChoice.NO_KEY and k not in [m.key for m in memo if isinstance(m, Mutable)]:
raise RuntimeError('"{}" required by "{}" not found in keys that appeared before, and is not NO_KEY.'
.format(k, module.key))
for submodule in module.layers:
if not isinstance(submodule, Model):
continue
submodule_prefix = prefix + ('.' if prefix else '') + submodule.name
self._parse_search_space(submodule, root, submodule_prefix, memo=memo, nested_detection=nested_detection)
return root
@property
def mutables(self):
return self._structured_mutables
def undedup_mutables(self):
return self._structured_mutables.traverse(deduplicate=False)
def call(self, *inputs):
raise RuntimeError('Call is undefined for mutators.')
def __setattr__(self, name, value):
if name == 'model':
raise AttributeError("Attribute `model` can be set at most once, and you shouldn't use `self.model = model` to "
"include your network, as it will include all parameters in model into the mutator.")
return super().__setattr__(name, value)
def enter_mutable_scope(self, mutable_scope):
pass
def exit_mutable_scope(self, mutable_scope):
pass
def on_forward_layer_choice(self, mutable, *inputs):
raise NotImplementedError
def on_forward_input_choice(self, mutable, tensor_list):
raise NotImplementedError
def export(self):
raise NotImplementedError
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .mutator import EnasMutator
from .trainer import EnasTrainer
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import tensorflow as tf
from tensorflow.keras.layers import Dense, Embedding, LSTMCell, RNN
from tensorflow.keras.losses import SparseCategoricalCrossentropy, Reduction
from nni.nas.tensorflow.mutator import Mutator
from nni.nas.tensorflow.mutables import LayerChoice, InputChoice, MutableScope
class EnasMutator(Mutator):
def __init__(self, model,
lstm_size=64,
lstm_num_layers=1,
tanh_constant=1.5,
cell_exit_extra_step=False,
skip_target=0.4,
temperature=None,
branch_bias=0.25,
entropy_reduction='sum'):
super().__init__(model)
self.tanh_constant = tanh_constant
self.temperature = temperature
self.cell_exit_extra_step = cell_exit_extra_step
cells = [LSTMCell(units=lstm_size, use_bias=False) for _ in range(lstm_num_layers)]
self.lstm = RNN(cells, stateful=True)
self.g_emb = tf.random.normal((1, 1, lstm_size)) * 0.1
self.skip_targets = tf.constant([1.0 - skip_target, skip_target])
self.max_layer_choice = 0
self.bias_dict = {}
for mutable in self.mutables:
if isinstance(mutable, LayerChoice):
if self.max_layer_choice == 0:
self.max_layer_choice = len(mutable)
                assert self.max_layer_choice == len(mutable), \
                    "ENAS mutator requires all layer choices to have the same number of candidates."
if 'reduce' in mutable.key:
bias = []
for choice in mutable.choices:
if 'conv' in str(type(choice)).lower():
bias.append(branch_bias)
else:
bias.append(-branch_bias)
self.bias_dict[mutable.key] = tf.constant(bias)
# exposed for trainer
self.sample_log_prob = 0
self.sample_entropy = 0
self.sample_skip_penalty = 0
# internal nn layers
self.embedding = Embedding(self.max_layer_choice + 1, lstm_size)
self.soft = Dense(self.max_layer_choice, use_bias=False)
self.attn_anchor = Dense(lstm_size, use_bias=False)
self.attn_query = Dense(lstm_size, use_bias=False)
self.v_attn = Dense(1, use_bias=False)
assert entropy_reduction in ['sum', 'mean'], 'Entropy reduction must be one of sum and mean.'
self.entropy_reduction = tf.reduce_sum if entropy_reduction == 'sum' else tf.reduce_mean
self.cross_entropy_loss = SparseCategoricalCrossentropy(from_logits=True, reduction=Reduction.NONE)
self._first_sample = True
def sample_search(self):
self._initialize()
self._sample(self.mutables)
self._first_sample = False
return self._choices
def sample_final(self):
return self.sample_search()
def _sample(self, tree):
mutable = tree.mutable
if isinstance(mutable, LayerChoice) and mutable.key not in self._choices:
self._choices[mutable.key] = self._sample_layer_choice(mutable)
elif isinstance(mutable, InputChoice) and mutable.key not in self._choices:
self._choices[mutable.key] = self._sample_input_choice(mutable)
for child in tree.children:
self._sample(child)
if self.cell_exit_extra_step and isinstance(mutable, MutableScope) and mutable.key not in self._anchors_hid:
self._anchors_hid[mutable.key] = self.lstm(self._inputs, 1)
def _initialize(self):
self._choices = {}
self._anchors_hid = {}
self._inputs = self.g_emb
        # the `input_shape` parameter of RNN does not seem to work;
        # work around it by skipping `reset_states` on the first run
if not self._first_sample:
self.lstm.reset_states()
self.sample_log_prob = 0
self.sample_entropy = 0
self.sample_skip_penalty = 0
def _sample_layer_choice(self, mutable):
logit = self.soft(self.lstm(self._inputs))
if self.temperature is not None:
logit /= self.temperature
if self.tanh_constant is not None:
logit = self.tanh_constant * tf.tanh(logit)
if mutable.key in self.bias_dict:
logit += self.bias_dict[mutable.key]
softmax_logit = tf.math.log(tf.nn.softmax(logit, axis=-1))
branch_id = tf.reshape(tf.random.categorical(softmax_logit, num_samples=1), [1])
log_prob = self.cross_entropy_loss(branch_id, logit)
self.sample_log_prob += self.entropy_reduction(log_prob)
entropy = log_prob * tf.math.exp(-log_prob)
self.sample_entropy += self.entropy_reduction(entropy)
self._inputs = tf.reshape(self.embedding(branch_id), [1, 1, -1])
mask = tf.one_hot(branch_id, self.max_layer_choice)
return tf.cast(tf.reshape(mask, [-1]), tf.bool)
def _sample_input_choice(self, mutable):
query, anchors = [], []
for label in mutable.choose_from:
if label not in self._anchors_hid:
self._anchors_hid[label] = self.lstm(self._inputs)
query.append(self.attn_anchor(self._anchors_hid[label]))
anchors.append(self._anchors_hid[label])
query = tf.concat(query, axis=0)
query = tf.tanh(query + self.attn_query(anchors[-1]))
query = self.v_attn(query)
if self.temperature is not None:
query /= self.temperature
if self.tanh_constant is not None:
query = self.tanh_constant * tf.tanh(query)
if mutable.n_chosen is None:
logit = tf.concat([-query, query], axis=1)
softmax_logit = tf.math.log(tf.nn.softmax(logit, axis=-1))
skip = tf.reshape(tf.random.categorical(softmax_logit, num_samples=1), [-1])
skip_prob = tf.math.sigmoid(logit)
kl = tf.reduce_sum(skip_prob * tf.math.log(skip_prob / self.skip_targets))
self.sample_skip_penalty += kl
log_prob = self.cross_entropy_loss(skip, logit)
skip = tf.cast(skip, tf.float32)
inputs = tf.tensordot(skip, tf.concat(anchors, 0), 1) / (1. + tf.reduce_sum(skip))
self._inputs = tf.reshape(inputs, [1, 1, -1])
else:
assert mutable.n_chosen == 1, "Input choice must select exactly one or any in ENAS."
logit = tf.reshape(query, [1, -1])
softmax_logit = tf.math.log(tf.nn.softmax(logit, axis=-1))
index = tf.reshape(tf.random.categorical(softmax_logit, num_samples=1), [-1])
skip = tf.reshape(tf.one_hot(index, mutable.n_candidates), [-1])
            # when the size is 1, tf does not accept a tensor here, complaining the shape is wrong;
            # a numpy array works fine
            log_prob = self.cross_entropy_loss(index.numpy(), logit)
self._inputs = tf.reshape(anchors[index.numpy()[0]], [1, 1, -1])
self.sample_log_prob += self.entropy_reduction(log_prob)
entropy = log_prob * tf.exp(-log_prob)
self.sample_entropy += self.entropy_reduction(entropy)
assert len(skip) == mutable.n_candidates, (skip, mutable.n_candidates, mutable.n_chosen)
return tf.cast(skip, tf.bool)
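
A rough sketch of how a trainer drives this mutator (see EnasTrainer below for the real call sites); `model` (a network built from the mutables) and the input batch `x` are placeholders:

mutator = EnasMutator(model)          # placeholder: a model containing LayerChoice/InputChoice
mutator.reset()                       # runs sample_search() and caches one architecture
logits = model(x, training=True)      # mutables route through the cached decisions
# REINFORCE-style signals exposed for the controller update:
log_prob = mutator.sample_log_prob
entropy = mutator.sample_entropy
skip_penalty = mutator.sample_skip_penalty
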
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import logging
import tensorflow as tf
from tensorflow.data import Dataset
from tensorflow.keras.optimizers import Adam
from nni.nas.tensorflow.utils import AverageMeterGroup, fill_zero_grads
from .mutator import EnasMutator
logger = logging.getLogger(__name__)
log_frequency = 100
entropy_weight = 0.0001
skip_weight = 0.8
baseline_decay = 0.999
child_steps = 500
mutator_lr = 0.00035
mutator_steps = 50
mutator_steps_aggregate = 20
aux_weight = 0.4
test_arc_per_epoch = 1
class EnasTrainer:
def __init__(self, model, loss, metrics, reward_function, optimizer, batch_size, num_epochs,
dataset_train, dataset_valid):
self.model = model
self.loss = loss
self.metrics = metrics
self.reward_function = reward_function
self.optimizer = optimizer
self.batch_size = batch_size
self.num_epochs = num_epochs
x, y = dataset_train
split = int(len(x) * 0.9)
self.train_set = Dataset.from_tensor_slices((x[:split], y[:split]))
self.valid_set = Dataset.from_tensor_slices((x[split:], y[split:]))
self.test_set = Dataset.from_tensor_slices(dataset_valid)
self.mutator = EnasMutator(model)
self.mutator_optim = Adam(learning_rate=mutator_lr)
self.baseline = 0.
def train(self, validate=True):
for epoch in range(self.num_epochs):
logger.info("Epoch %d Training", epoch + 1)
self.train_one_epoch(epoch)
logger.info("Epoch %d Validating", epoch + 1)
self.validate_one_epoch(epoch)
def validate(self):
self.validate_one_epoch(-1)
def train_one_epoch(self, epoch):
train_loader, valid_loader = self._create_train_loader()
# Sample model and train
meters = AverageMeterGroup()
for step in range(1, child_steps + 1):
x, y = next(train_loader)
self.mutator.reset()
with tf.GradientTape() as tape:
logits = self.model(x, training=True)
if isinstance(logits, tuple):
logits, aux_logits = logits
                    aux_loss = self.loss(y, aux_logits)
else:
aux_loss = 0.
metrics = self.metrics(y, logits)
loss = self.loss(y, logits) + aux_weight * aux_loss
grads = tape.gradient(loss, self.model.trainable_weights)
grads = fill_zero_grads(grads, self.model.trainable_weights)
grads, _ = tf.clip_by_global_norm(grads, 5.0)
self.optimizer.apply_gradients(zip(grads, self.model.trainable_weights))
metrics['loss'] = tf.reduce_mean(loss).numpy()
meters.update(metrics)
if log_frequency and step % log_frequency == 0:
logger.info("Model Epoch [%d/%d] Step [%d/%d] %s", epoch + 1,
self.num_epochs, step, child_steps, meters)
# Train sampler (mutator)
meters = AverageMeterGroup()
for mutator_step in range(1, mutator_steps + 1):
grads_list = []
for step in range(1, mutator_steps_aggregate + 1):
with tf.GradientTape() as tape:
x, y = next(valid_loader)
self.mutator.reset()
logits = self.model(x, training=False)
metrics = self.metrics(y, logits)
reward = self.reward_function(y, logits) + entropy_weight * self.mutator.sample_entropy
self.baseline = self.baseline * baseline_decay + reward * (1 - baseline_decay)
loss = self.mutator.sample_log_prob * (reward - self.baseline)
loss += skip_weight * self.mutator.sample_skip_penalty
meters.update({
'reward': reward,
'loss': tf.reduce_mean(loss).numpy(),
'ent': self.mutator.sample_entropy.numpy(),
'log_prob': self.mutator.sample_log_prob.numpy(),
'baseline': self.baseline,
'skip': self.mutator.sample_skip_penalty,
})
cur_step = step + (mutator_step - 1) * mutator_steps_aggregate
if log_frequency and cur_step % log_frequency == 0:
logger.info("RL Epoch [%d/%d] Step [%d/%d] [%d/%d] %s", epoch + 1, self.num_epochs,
mutator_step, mutator_steps, step, mutator_steps_aggregate,
meters)
grads = tape.gradient(loss, self.mutator.trainable_weights)
grads = fill_zero_grads(grads, self.mutator.trainable_weights)
grads_list.append(grads)
total_grads = [tf.math.add_n(weight_grads) for weight_grads in zip(*grads_list)]
total_grads, _ = tf.clip_by_global_norm(total_grads, 5.0)
self.mutator_optim.apply_gradients(zip(total_grads, self.mutator.trainable_weights))
def validate_one_epoch(self, epoch):
test_loader = self._create_validate_loader()
for arc_id in range(test_arc_per_epoch):
meters = AverageMeterGroup()
for x, y in test_loader:
self.mutator.reset()
logits = self.model(x)
if isinstance(logits, tuple):
logits, _ = logits
                metrics = self.metrics(y, logits)
loss = self.loss(y, logits)
metrics['loss'] = tf.reduce_mean(loss).numpy()
meters.update(metrics)
logger.info("Test Epoch [%d/%d] Arc [%d/%d] Summary %s",
epoch + 1, self.num_epochs, arc_id + 1, test_arc_per_epoch,
meters.summary())
    def _create_train_loader(self):
        train_set = self.train_set.shuffle(1000000).batch(self.batch_size)
        valid_set = self.valid_set.shuffle(1000000).batch(self.batch_size)
        return iter(train_set), iter(valid_set)
def _create_validate_loader(self):
return iter(self.test_set.shuffle(1000000).batch(self.batch_size))
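
A hypothetical wiring of the trainer above; the dataset, the metrics callable (which must return a dict, as train_one_epoch assumes), and the optimizer settings are illustrative only:

import tensorflow as tf

def accuracy(y_true, logits):
    # placeholder metrics callable returning a dict of scalars
    pred = tf.argmax(logits, axis=-1, output_type=tf.int32)
    return {'acc': tf.reduce_mean(tf.cast(pred == tf.cast(y_true, tf.int32), tf.float32)).numpy()}

(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()  # placeholder data
trainer = EnasTrainer(model,  # placeholder: a Keras model built from mutables
                      loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                      metrics=accuracy,
                      reward_function=lambda y, logits: accuracy(y, logits)['acc'],
                      optimizer=tf.keras.optimizers.SGD(learning_rate=0.05, momentum=0.9),
                      batch_size=64, num_epochs=10,
                      dataset_train=(x_train, y_train), dataset_valid=(x_test, y_test))
trainer.train()
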
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import logging
from tensorflow.keras import Model
from .utils import global_mutable_counting
_logger = logging.getLogger(__name__)
class Mutable(Model):
def __init__(self, key=None):
super().__init__()
if key is None:
self._key = '{}_{}'.format(type(self).__name__, global_mutable_counting())
elif isinstance(key, str):
self._key = key
else:
self._key = str(key)
_logger.warning('Key "%s" is not string, converted to string.', key)
self.init_hook = None
self.forward_hook = None
def __deepcopy__(self, memodict=None):
raise NotImplementedError("Deep copy doesn't work for mutables.")
def __call__(self, *args, **kwargs):
self._check_built()
return super().__call__(*args, **kwargs)
def set_mutator(self, mutator):
if 'mutator' in self.__dict__:
            raise RuntimeError('`set_mutator` is called more than once. '
'Did you parse the search space multiple times? '
'Or did you apply multiple fixed architectures?')
self.__dict__['mutator'] = mutator
def call(self, *inputs):
raise NotImplementedError('Method `call` of Mutable must be overridden')
@property
def key(self):
return self._key
@property
def name(self):
return self._name if hasattr(self, '_name') else self._key
@name.setter
def name(self, name):
self._name = name
def _check_built(self):
if not hasattr(self, 'mutator'):
raise ValueError(
"Mutator not set for {}. You might have forgotten to initialize and apply your mutator. "
"Or did you initialize a mutable on the fly in forward pass? Move to `__init__` "
"so that trainer can locate all your mutables. See NNI docs for more details.".format(self))
def __repr__(self):
return '{} ({})'.format(self.name, self.key)
class MutableScope(Mutable):
def __call__(self, *args, **kwargs):
try:
self._check_built()
self.mutator.enter_mutable_scope(self)
return super().__call__(*args, **kwargs)
finally:
self.mutator.exit_mutable_scope(self)
class LayerChoice(Mutable):
def __init__(self, op_candidates, reduction='sum', return_mask=False, key=None):
super().__init__(key=key)
self.length = len(op_candidates)
self.choices = op_candidates
self.reduction = reduction
self.return_mask = return_mask
self._built = False
def call(self, *inputs):
if not self._built:
for op in self.choices:
if len(inputs) > 1: # FIXME: not tested
op.build([inp.shape for inp in inputs])
elif len(inputs) == 1:
op.build(inputs[0].shape)
self._built = True
out, mask = self.mutator.on_forward_layer_choice(self, *inputs)
if self.return_mask:
return out, mask
return out
def __len__(self):
return len(self.choices)
class InputChoice(Mutable):
NO_KEY = ''
def __init__(self, n_candidates=None, choose_from=None, n_chosen=None, reduction='sum', return_mask=False, key=None):
super().__init__(key=key)
assert n_candidates is not None or choose_from is not None, \
'At least one of `n_candidates` and `choose_from` must be not None.'
if choose_from is not None and n_candidates is None:
n_candidates = len(choose_from)
elif choose_from is None and n_candidates is not None:
choose_from = [self.NO_KEY] * n_candidates
assert n_candidates == len(choose_from), 'Number of candidates must be equal to the length of `choose_from`.'
assert n_candidates > 0, 'Number of candidates must be greater than 0.'
assert n_chosen is None or 0 <= n_chosen <= n_candidates, \
'Expected selected number must be None or no more than number of candidates.'
self.n_candidates = n_candidates
self.choose_from = choose_from.copy()
self.n_chosen = n_chosen
self.reduction = reduction
self.return_mask = return_mask
def call(self, optional_inputs):
optional_input_list = optional_inputs
if isinstance(optional_inputs, dict):
optional_input_list = [optional_inputs[tag] for tag in self.choose_from]
assert isinstance(optional_input_list, list), \
'Optional input list must be a list, not a {}.'.format(type(optional_input_list))
assert len(optional_inputs) == self.n_candidates, \
'Length of the input list must be equal to number of candidates.'
out, mask = self.mutator.on_forward_input_choice(self, optional_input_list)
if self.return_mask:
return out, mask
return out
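
An illustrative sketch of the two mutables above in a model definition; the candidate layers and shapes are made up, and both candidates must produce compatible outputs for the default 'sum' reduction:

import tensorflow as tf
from tensorflow.keras import Model
from tensorflow.keras.layers import Conv2D
from nni.nas.tensorflow.mutables import LayerChoice, InputChoice

class Cell(Model):
    def __init__(self):
        super().__init__()
        # the mutator decides which candidate op runs on each forward pass
        self.op = LayerChoice([Conv2D(16, 3, padding='same'),
                               Conv2D(16, 5, padding='same')])
        # ... and which of two candidate inputs to keep (exactly one here)
        self.input_switch = InputChoice(n_candidates=2, n_chosen=1)

    def call(self, x, prev):
        return self.input_switch([self.op(x), prev])
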
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import logging
import tensorflow as tf
from .base_mutator import BaseMutator
_logger = logging.getLogger(__name__)
class Mutator(BaseMutator):
def __init__(self, model):
super().__init__(model)
self._cache = {}
def sample_search(self):
raise NotImplementedError('Method `sample_search` must be overridden')
def sample_final(self):
        raise NotImplementedError('Method `sample_final` must be overridden for exporting')
def reset(self):
self._cache = self.sample_search()
def export(self):
return self.sample_final()
# TODO: status
# TODO: graph
def on_forward_layer_choice(self, mutable, *inputs):
mask = self._get_decision(mutable)
assert len(mask) == len(mutable), \
'Invalid mask, expected {} to be of length {}.'.format(mask, len(mutable))
out = self._select_with_mask(lambda choice: choice(*inputs), mutable.choices, mask)
return self._tensor_reduction(mutable.reduction, out), mask
def on_forward_input_choice(self, mutable, tensor_list):
mask = self._get_decision(mutable)
assert len(mask) == mutable.n_candidates, \
'Invalid mask, expected {} to be of length {}.'.format(mask, mutable.n_candidates)
out = self._select_with_mask(lambda tensor: tensor, tensor_list, mask)
return self._tensor_reduction(mutable.reduction, out), mask
def _select_with_mask(self, map_fn, candidates, mask):
if mask.dtype.is_bool:
out = [map_fn(cand) for cand, m in zip(candidates, mask) if m]
elif mask.dtype.is_floating:
out = [map_fn(cand) * m for cand, m in zip(candidates, mask) if m]
else:
raise ValueError('Unrecognized mask, dtype is {}'.format(mask.dtype.name))
return out
def _tensor_reduction(self, reduction_type, tensor_list):
if reduction_type == 'none':
return tensor_list
if not tensor_list:
return None
if len(tensor_list) == 1:
return tensor_list[0]
if reduction_type == 'sum':
return sum(tensor_list)
if reduction_type == 'mean':
return sum(tensor_list) / len(tensor_list)
if reduction_type == 'concat':
return tf.concat(tensor_list, axis=0)
        raise ValueError('Unrecognized reduction policy: "{}"'.format(reduction_type))
def _get_decision(self, mutable):
if mutable.key not in self._cache:
raise ValueError('"{}" not found in decision cache.'.format(mutable.key))
result = self._cache[mutable.key]
_logger.debug('Decision %s: %s', mutable.key, result)
return result
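
To make the contract of `sample_search` concrete, here is a minimal hypothetical mutator that picks one candidate uniformly at random for every mutable; the boolean masks match what `_select_with_mask` above expects:

import tensorflow as tf
from nni.nas.tensorflow.mutator import Mutator
from nni.nas.tensorflow.mutables import LayerChoice, InputChoice

class RandomMutator(Mutator):
    def sample_search(self):
        result = {}
        for mutable in self.mutables:
            if isinstance(mutable, LayerChoice):
                # one-hot boolean mask over the candidate ops
                idx = tf.random.uniform([], maxval=len(mutable), dtype=tf.int32)
                result[mutable.key] = tf.cast(tf.one_hot(idx, len(mutable)), tf.bool)
            elif isinstance(mutable, InputChoice):
                # one-hot boolean mask over the candidate inputs
                idx = tf.random.uniform([], maxval=mutable.n_candidates, dtype=tf.int32)
                result[mutable.key] = tf.cast(tf.one_hot(idx, mutable.n_candidates), tf.bool)
        return result

    def sample_final(self):
        return self.sample_search()
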
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import tensorflow as tf
_counter = 0
def global_mutable_counting():
global _counter
_counter += 1
return _counter
class AverageMeter:
def __init__(self, name):
self.name = name
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def update(self, val):
self.val = val
self.sum += val
self.count += 1
self.avg = self.sum / self.count
def __str__(self):
return '{name} {val:4f} ({avg:4f})'.format(**self.__dict__)
def summary(self):
return '{name}: {avg:4f}'.format(**self.__dict__)
class AverageMeterGroup:
def __init__(self):
self.meters = {}
def update(self, data):
for k, v in data.items():
if k not in self.meters:
self.meters[k] = AverageMeter(k)
self.meters[k].update(v)
def __str__(self):
return ' '.join(str(v) for v in self.meters.values())
def summary(self):
return ' '.join(v.summary() for v in self.meters.values())
class StructuredMutableTreeNode:
def __init__(self, mutable):
self.mutable = mutable
self.children = []
def add_child(self, mutable):
self.children.append(StructuredMutableTreeNode(mutable))
return self.children[-1]
def type(self):
return type(self.mutable)
def __iter__(self):
return self.traverse()
def traverse(self, order="pre", deduplicate=True, memo=None):
if memo is None:
memo = set()
assert order in ["pre", "post"]
if order == "pre":
if self.mutable is not None:
if not deduplicate or self.mutable.key not in memo:
memo.add(self.mutable.key)
yield self.mutable
for child in self.children:
for m in child.traverse(order=order, deduplicate=deduplicate, memo=memo):
yield m
if order == "post":
if self.mutable is not None:
if not deduplicate or self.mutable.key not in memo:
memo.add(self.mutable.key)
yield self.mutable
def fill_zero_grads(grads, weights):
ret = []
for grad, weight in zip(grads, weights):
if grad is not None:
ret.append(grad)
else:
ret.append(tf.zeros_like(weight))
return ret
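
A short sketch of the helpers above in use, mirroring the call sites in EnasTrainer; `model`, `x`, `y`, `loss_fn`, and `optimizer` are placeholders:

import tensorflow as tf

meters = AverageMeterGroup()
with tf.GradientTape() as tape:
    loss = loss_fn(y, model(x))                          # placeholder loss/model/batch
grads = tape.gradient(loss, model.trainable_weights)
grads = fill_zero_grads(grads, model.trainable_weights)  # None grads (unused subgraphs) -> zeros
optimizer.apply_gradients(zip(grads, model.trainable_weights))
meters.update({'loss': float(tf.reduce_mean(loss))})
print(meters)  # "loss <val> (<running avg>)"
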
@@ -61,7 +61,7 @@ class PdType:
class CategoricalPd(Pd):
    """
    Categorical probability distribution
    """
    def __init__(self, logits, mask_npinf, nsteps, size, is_act_model):
        self.logits = logits
......
@@ -10,9 +10,11 @@ from torchvision.models.vgg import vgg16
from torchvision.models.resnet import resnet18
from unittest import TestCase, main
from nni.compression.torch import L1FilterPruner, apply_compression_results
from nni.compression.speedup.torch import ModelSpeedup
torch.manual_seed(0)
class BackboneModel1(nn.Module):
    def __init__(self):
        super().__init__()
......
@@ -58,7 +60,10 @@ class BigModel(torch.nn.Module):
        x = self.fc3(x)
        return x
dummy_input = torch.randn(2, 1, 28, 28)
SPARSITY = 0.5
MODEL_FILE, MASK_FILE = './11_model.pth', './l1_mask.pth'
def prune_model_l1(model):
    config_list = [{
        'sparsity': SPARSITY,
@@ -66,14 +71,14 @@ def prune_model_l1(model):
    }]
    pruner = L1FilterPruner(model, config_list)
    pruner.compress()
    pruner.export_model(model_path=MODEL_FILE, mask_path=MASK_FILE)
class SpeedupTestCase(TestCase):
    def test_speedup_vgg16(self):
        prune_model_l1(vgg16())
        model = vgg16()
        model.train()
        ms = ModelSpeedup(model, torch.randn(2, 3, 32, 32), MASK_FILE)
        ms.speedup_model()
        orig_model = vgg16()
@@ -88,20 +93,33 @@ class SpeedupTestCase(TestCase):
    def test_speedup_bigmodel(self):
        prune_model_l1(BigModel())
        model = BigModel()
apply_compression_results(model, MASK_FILE, 'cpu')
model.eval()
mask_out = model(dummy_input)
        model.train()
        ms = ModelSpeedup(model, dummy_input, MASK_FILE)
        ms.speedup_model()
assert model.training
model.eval()
speedup_out = model(dummy_input)
if not torch.allclose(mask_out, speedup_out, atol=1e-07):
print('input:', dummy_input.size(), torch.abs(dummy_input).sum((2,3)))
print('mask_out:', mask_out)
print('speedup_out:', speedup_out)
raise RuntimeError('model speedup inference result is incorrect!')
        orig_model = BigModel()
        assert model.backbone2.conv1.out_channels == int(orig_model.backbone2.conv1.out_channels * SPARSITY)
        assert model.backbone2.conv2.in_channels == int(orig_model.backbone2.conv2.in_channels * SPARSITY)
        assert model.backbone2.conv2.out_channels == int(orig_model.backbone2.conv2.out_channels * SPARSITY)
        assert model.backbone2.fc1.in_features == int(orig_model.backbone2.fc1.in_features * SPARSITY)
    def tearDown(self):
        os.remove(MODEL_FILE)
        os.remove(MASK_FILE)

if __name__ == '__main__':
    main()
@@ -168,6 +168,7 @@ def launch_test(config_file, training_service, test_case_config):
    trial_stats = get_trial_stats(TRIAL_JOBS_URL)
    print(json.dumps(trial_stats, indent=4), flush=True)
    if status != 'DONE' or trial_stats['SUCCEEDED'] + trial_stats['EARLY_STOPPED'] < max_trial_num:
        print_experiment_log(experiment_id=experiment_id)
        print_trial_job_log(training_service, TRIAL_JOBS_URL)
        raise AssertionError('Failed to finish in maxExecDuration')
......
jobs:
- job: 'integration_test_remote_linux_to_linux'
  timeoutInMinutes: 120
  steps:
......
jobs:
- job: 'integration_test_remote_windows_to_linux'
  timeoutInMinutes: 120
  steps:
@@ -23,6 +23,7 @@ jobs:
      sshEndpoint: $(end_point)
      runOptions: inline
      inline: cd /tmp/nnitest/$(Build.BuildId)/nni-remote/deployment/pypi;make build
      failOnStdErr: false
    continueOnError: true
    displayName: 'build nni bdist_wheel'
  - task: SSH@0
......