Unverified commit b9a7a95d authored by SparkSnail, committed by GitHub

Merge pull request #223 from microsoft/master

merge master
parents f9ee589c 0c7f22fb
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from nni.nas.pytorch.spos import SPOSEvolution
from network import ShuffleNetV2OneShot
class EvolutionWithFlops(SPOSEvolution):
"""
This tuner extends the evolution tuner by limiting the FLOPs of the candidates it generates.
It needs a function to compute the FLOPs of a candidate.
"""
def __init__(self, flops_limit=330E6, **kwargs):
super().__init__(**kwargs)
self.model = ShuffleNetV2OneShot()
self.flops_limit = flops_limit
def _is_legal(self, cand):
if not super()._is_legal(cand):
return False
if self.model.get_candidate_flops(cand) > self.flops_limit:
return False
return True
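A minimal usage sketch, assuming the tuner is constructed directly (in an NNI experiment it is normally referenced from the experiment config); all argument values below are illustrative and mirror the SPOSEvolution defaults shown later in this diff.
# Illustrative only: construct the FLOPs-constrained evolution tuner.
tuner = EvolutionWithFlops(flops_limit=330E6, max_epochs=20, num_select=10,
                           num_population=50, m_prob=0.1,
                           num_crossover=25, num_mutation=25)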
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import torch
import torch.nn as nn
class CrossEntropyLabelSmooth(nn.Module):
def __init__(self, num_classes, epsilon):
super(CrossEntropyLabelSmooth, self).__init__()
self.num_classes = num_classes
self.epsilon = epsilon
self.logsoftmax = nn.LogSoftmax(dim=1)
def forward(self, inputs, targets):
log_probs = self.logsoftmax(inputs)
targets = torch.zeros_like(log_probs).scatter_(1, targets.unsqueeze(1), 1)
targets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes
loss = (-targets * log_probs).mean(0).sum()
return loss
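A quick sanity check of the smoothed loss; the class count, batch size, and epsilon are arbitrary illustration values.
# Illustrative usage of CrossEntropyLabelSmooth on random data.
example_criterion = CrossEntropyLabelSmooth(num_classes=10, epsilon=0.1)
example_logits = torch.randn(8, 10)              # batch of 8 predictions
example_targets = torch.randint(0, 10, (8,))     # integer class labels
example_loss = example_criterion(example_logits, example_targets)   # scalar tensor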
def accuracy(output, target, topk=(1, 5)):
""" Computes the precision@k for the specified values of k """
maxk = max(topk)
batch_size = target.size(0)
_, pred = output.topk(maxk, 1, True, True)
pred = pred.t()
# one-hot case
if target.ndimension() > 1:
target = target.max(1)[1]
correct = pred.eq(target.view(1, -1).expand_as(pred))
res = dict()
for k in topk:
correct_k = correct[:k].view(-1).float().sum(0)
res["acc{}".format(k)] = correct_k.mul_(1.0 / batch_size).item()
return res
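The accuracy helper returns a dict keyed by "acc1" and "acc5"; a hedged usage sketch on random data follows (the numbers themselves are meaningless).
# Illustrative usage of accuracy().
example_output = torch.randn(8, 10)
example_target = torch.randint(0, 10, (8,))
print(accuracy(example_output, example_target, topk=(1, 5)))   # {'acc1': ..., 'acc5': ...}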
......@@ -3,7 +3,7 @@
import logging
import torch
from .compressor import Quantizer
from .compressor import Quantizer, QuantGrad, QuantType
__all__ = ['NaiveQuantizer', 'QAT_Quantizer', 'DoReFaQuantizer']
......@@ -240,4 +240,34 @@ class DoReFaQuantizer(Quantizer):
def quantize(self, input_ri, q_bits):
scale = pow(2, q_bits)-1
output = torch.round(input_ri*scale)/scale
return output
\ No newline at end of file
return output
class ClipGrad(QuantGrad):
@staticmethod
def quant_backward(tensor, grad_output, quant_type):
if quant_type == QuantType.QUANT_OUTPUT:
grad_output[torch.abs(tensor) > 1] = 0
return grad_output
class BNNQuantizer(Quantizer):
"""Binarized Neural Networks, as defined in:
Binarized Neural Networks: Training Deep Neural Networks with Weights and Activations Constrained to +1 or -1
(https://arxiv.org/abs/1602.02830)
"""
def __init__(self, model, config_list):
super().__init__(model, config_list)
self.quant_grad = ClipGrad
def quantize_weight(self, weight, config, **kwargs):
out = torch.sign(weight)
# remove zeros
out[out == 0] = 1
return out
def quantize_output(self, output, config, **kwargs):
out = torch.sign(output)
# remove zeros
out[out == 0] = 1
return out
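For intuition, the binarize-forward / clipped-gradient-backward pair above can be written as one standalone autograd function; this is an illustrative sketch, not part of the NNI API.
# Standalone sketch mirroring quantize_weight/quantize_output + ClipGrad (illustrative).
class BinarizeSTE(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        out = torch.sign(x)
        out[out == 0] = 1            # map zeros to +1, as the quantizer does
        return out

    @staticmethod
    def backward(ctx, grad_output):
        x, = ctx.saved_tensors
        grad = grad_output.clone()
        grad[torch.abs(x) > 1] = 0   # clip the straight-through gradient outside [-1, 1]
        return grad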
......@@ -16,6 +16,7 @@ class LayerInfo:
self._forward = None
class Compressor:
"""
Abstract base PyTorch compressor
......@@ -193,10 +194,16 @@ class Pruner(Compressor):
layer._forward = layer.module.forward
def new_forward(*inputs):
mask = self.calc_mask(layer, config)
# apply mask to weight
old_weight = layer.module.weight.data
mask = self.calc_mask(layer, config)
layer.module.weight.data = old_weight.mul(mask)
mask_weight = mask['weight']
layer.module.weight.data = old_weight.mul(mask_weight)
# apply mask to bias
if mask.__contains__('bias') and hasattr(layer.module, 'bias') and layer.module.bias is not None:
old_bias = layer.module.bias.data
mask_bias = mask['bias']
layer.module.bias.data = old_bias.mul(mask_bias)
# calculate forward
ret = layer._forward(*inputs)
return ret
......@@ -224,12 +231,14 @@ class Pruner(Compressor):
for name, m in self.bound_model.named_modules():
if name == "":
continue
mask = self.mask_dict.get(name)
if mask is not None:
mask_sum = mask.sum().item()
mask_num = mask.numel()
masks = self.mask_dict.get(name)
if masks is not None:
mask_sum = masks['weight'].sum().item()
mask_num = masks['weight'].numel()
_logger.info('Layer: %s Sparsity: %.2f', name, 1 - mask_sum / mask_num)
m.weight.data = m.weight.data.mul(mask)
m.weight.data = m.weight.data.mul(masks['weight'])
if masks.__contains__('bias') and hasattr(m, 'bias') and m.bias is not None:
m.bias.data = m.bias.data.mul(masks['bias'])
else:
_logger.info('Layer: %s NOT compressed', name)
torch.save(self.bound_model.state_dict(), model_path)
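Since the export saves a state dict with the masks already multiplied into the weights (and biases), the file reloads with plain PyTorch; the model class and path below are placeholders.
# Reloading the exported, already-masked weights (model class and path are placeholders).
pruned_model = MyModel()
pruned_model.load_state_dict(torch.load('model_path.pth'))
pruned_model.eval()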
......@@ -258,7 +267,6 @@ class Quantizer(Compressor):
"""
Subclasses should override this method to quantize weight.
This method is effectively hooked to :meth:`forward` of the model.
Parameters
----------
weight : Tensor
......@@ -272,7 +280,6 @@ class Quantizer(Compressor):
"""
Subclasses should override this method to quantize output.
This method is effectively hooked to :meth:`forward` of the model.
Parameters
----------
output : Tensor
......@@ -286,7 +293,6 @@ class Quantizer(Compressor):
"""
Subclasses should override this method to quantize input.
This method is effectively hooked to :meth:`forward` of the model.
Parameters
----------
inputs : Tensor
......@@ -300,7 +306,6 @@ class Quantizer(Compressor):
def _instrument_layer(self, layer, config):
"""
Create a wrapper forward function to replace the original one.
Parameters
----------
layer : LayerInfo
......@@ -365,7 +370,6 @@ class QuantGrad(torch.autograd.Function):
"""
This method should be overridden by subclasses to provide a customized backward function;
the default implementation is the Straight-Through Estimator.
Parameters
----------
tensor : Tensor
......@@ -375,7 +379,6 @@ class QuantGrad(torch.autograd.Function):
quant_type : QuantType
the type of quantization; it can be `QuantType.QUANT_INPUT`, `QuantType.QUANT_WEIGHT`, or `QuantType.QUANT_OUTPUT`.
You can define different behavior for each type.
Returns
-------
tensor
......@@ -399,3 +402,4 @@ def _check_weight(module):
return isinstance(module.weight.data, torch.Tensor)
except AttributeError:
return False
\ No newline at end of file
......@@ -17,6 +17,7 @@ class LotteryTicketPruner(Pruner):
4. Reset the remaining parameters to their values in theta_0, creating the winning ticket f(x;m*theta_0).
5. Repeat steps 2, 3, and 4.
"""
def __init__(self, model, config_list, optimizer, lr_scheduler=None, reset_weights=True):
"""
Parameters
......@@ -55,7 +56,8 @@ class LotteryTicketPruner(Pruner):
assert 'prune_iterations' in config, 'prune_iterations must exist in your config'
assert 'sparsity' in config, 'sparsity must exist in your config'
if prune_iterations is not None:
assert prune_iterations == config['prune_iterations'], 'The values of prune_iterations must be equal in your config'
assert prune_iterations == config[
'prune_iterations'], 'The values of prune_iterations must be equal in your config'
prune_iterations = config['prune_iterations']
return prune_iterations
......@@ -67,8 +69,8 @@ class LotteryTicketPruner(Pruner):
if print_mask:
print('mask: ', mask)
# calculate current sparsity
mask_num = mask.sum().item()
mask_size = mask.numel()
mask_num = mask['weight'].sum().item()
mask_size = mask['weight'].numel()
print('sparsity: ', 1 - mask_num / mask_size)
torch.set_printoptions(profile='default')
......@@ -84,11 +86,11 @@ class LotteryTicketPruner(Pruner):
curr_sparsity = self._calc_sparsity(sparsity)
assert self.mask_dict.get(op_name) is not None
curr_mask = self.mask_dict.get(op_name)
w_abs = weight.abs() * curr_mask
w_abs = weight.abs() * curr_mask['weight']
k = int(w_abs.numel() * curr_sparsity)
threshold = torch.topk(w_abs.view(-1), k, largest=False).values.max()
mask = torch.gt(w_abs, threshold).type_as(weight)
return mask
return {'weight': mask}
def calc_mask(self, layer, config):
"""
......
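A hedged sketch of the five-step loop from the docstring; train, prune_smallest, and reset_to are hypothetical helpers standing in for a training routine, magnitude pruning, and the rewind-to-theta_0 step.
# Illustrative outline of the lottery ticket procedure (hypothetical helpers).
theta_0 = {k: v.clone() for k, v in model.state_dict().items()}   # step 1: keep initial weights
for _ in range(prune_iterations):                                  # as configured via 'prune_iterations'
    train(model)                                                   # step 2: train for j iterations
    masks = prune_smallest(model, fraction)                        # step 3: prune p^(1/n) of surviving weights
    reset_to(model, theta_0, masks)                                # step 4: rewind survivors to theta_0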
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .evolution import SPOSEvolution
from .mutator import SPOSSupernetTrainingMutator
from .trainer import SPOSSupernetTrainer
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import json
import logging
import os
import re
from collections import deque
import numpy as np
from nni.tuner import Tuner
from nni.nas.pytorch.classic_nas.mutator import LAYER_CHOICE, INPUT_CHOICE
_logger = logging.getLogger(__name__)
class SPOSEvolution(Tuner):
def __init__(self, max_epochs=20, num_select=10, num_population=50, m_prob=0.1,
num_crossover=25, num_mutation=25):
"""
Initialize SPOS Evolution Tuner.
Parameters
----------
max_epochs : int
Maximum number of epochs to run.
num_select : int
Number of candidates that survive each epoch.
num_population : int
Number of candidates at the start of each epoch. If candidates generated by
crossover and mutation are not enough, the rest will be filled with random
candidates.
m_prob : float
The probability of mutation.
num_crossover : int
Number of candidates generated by crossover in each epoch.
num_mutation : int
Number of candidates generated by mutation in each epoch.
"""
assert num_population >= num_select
self.max_epochs = max_epochs
self.num_select = num_select
self.num_population = num_population
self.m_prob = m_prob
self.num_crossover = num_crossover
self.num_mutation = num_mutation
self.epoch = 0
self.candidates = []
self.search_space = None
self.random_state = np.random.RandomState(0)
# async status
self._to_evaluate_queue = deque()
self._sending_parameter_queue = deque()
self._pending_result_ids = set()
self._reward_dict = dict()
self._id2candidate = dict()
self._st_callback = None
def update_search_space(self, search_space):
"""
Handle the initialization/update event of search space.
"""
self._search_space = search_space
self._next_round()
def _next_round(self):
_logger.info("Epoch %d, generating...", self.epoch)
if self.epoch == 0:
self._get_random_population()
self.export_results(self.candidates)
else:
best_candidates = self._select_top_candidates()
self.export_results(best_candidates)
if self.epoch >= self.max_epochs:
return
self.candidates = self._get_mutation(best_candidates) + self._get_crossover(best_candidates)
self._get_random_population()
self.epoch += 1
def _random_candidate(self):
chosen_arch = dict()
for key, val in self._search_space.items():
if val["_type"] == LAYER_CHOICE:
choices = val["_value"]
index = self.random_state.randint(len(choices))
chosen_arch[key] = {"_value": choices[index], "_idx": index}
elif val["_type"] == INPUT_CHOICE:
raise NotImplementedError("Input choice is not implemented yet.")
return chosen_arch
def _add_to_evaluate_queue(self, cand):
_logger.info("Generate candidate %s, adding to eval queue.", self._get_architecture_repr(cand))
self._reward_dict[self._hashcode(cand)] = 0.
self._to_evaluate_queue.append(cand)
def _get_random_population(self):
while len(self.candidates) < self.num_population:
cand = self._random_candidate()
if self._is_legal(cand):
_logger.info("Random candidate generated.")
self._add_to_evaluate_queue(cand)
self.candidates.append(cand)
def _get_crossover(self, best):
result = []
for _ in range(10 * self.num_crossover):
cand_p1 = best[self.random_state.randint(len(best))]
cand_p2 = best[self.random_state.randint(len(best))]
assert cand_p1.keys() == cand_p2.keys()
cand = {k: cand_p1[k] if self.random_state.randint(2) == 0 else cand_p2[k]
for k in cand_p1.keys()}
if self._is_legal(cand):
result.append(cand)
self._add_to_evaluate_queue(cand)
if len(result) >= self.num_crossover:
break
_logger.info("Found %d architectures with crossover.", len(result))
return result
def _get_mutation(self, best):
result = []
for _ in range(10 * self.num_mutation):
cand = best[self.random_state.randint(len(best))].copy()
mutation_sample = np.random.random_sample(len(cand))
for s, k in zip(mutation_sample, cand):
if s < self.m_prob:
choices = self._search_space[k]["_value"]
index = self.random_state.randint(len(choices))
cand[k] = {"_value": choices[index], "_idx": index}
if self._is_legal(cand):
result.append(cand)
self._add_to_evaluate_queue(cand)
if len(result) >= self.num_mutation:
break
_logger.info("Found %d architectures with mutation.", len(result))
return result
def _get_architecture_repr(self, cand):
return re.sub(r"\".*?\": \{\"_idx\": (\d+), \"_value\": \".*?\"\}", r"\1",
self._hashcode(cand))
def _is_legal(self, cand):
if self._hashcode(cand) in self._reward_dict:
return False
return True
def _select_top_candidates(self):
reward_query = lambda cand: self._reward_dict[self._hashcode(cand)]
_logger.info("All candidate rewards: %s", list(map(reward_query, self.candidates)))
result = sorted(self.candidates, key=reward_query, reverse=True)[:self.num_select]
_logger.info("Best candidate rewards: %s", list(map(reward_query, result)))
return result
@staticmethod
def _hashcode(d):
return json.dumps(d, sort_keys=True)
def _bind_and_send_parameters(self):
"""
There are two types of resources: parameter ids and candidates. This function is called whenever
necessary to bind these resources together and send new trials via st_callback.
"""
result = []
while self._sending_parameter_queue and self._to_evaluate_queue:
parameter_id = self._sending_parameter_queue.popleft()
parameters = self._to_evaluate_queue.popleft()
self._id2candidate[parameter_id] = parameters
result.append(parameters)
self._pending_result_ids.add(parameter_id)
self._st_callback(parameter_id, parameters)
_logger.info("Send parameter [%d] %s.", parameter_id, self._get_architecture_repr(parameters))
return result
def generate_multiple_parameters(self, parameter_id_list, **kwargs):
"""
Callback function necessary to implement a tuner. This will put more parameter ids into the
parameter id queue.
"""
if "st_callback" in kwargs and self._st_callback is None:
self._st_callback = kwargs["st_callback"]
for parameter_id in parameter_id_list:
self._sending_parameter_queue.append(parameter_id)
self._bind_and_send_parameters()
return []  # intentionally empty: returning parameters here as well could cause over-sending
def receive_trial_result(self, parameter_id, parameters, value, **kwargs):
"""
Callback function. Receive a trial result.
"""
_logger.info("Candidate %d, reported reward %f", parameter_id, value)
self._reward_dict[self._hashcode(self._id2candidate[parameter_id])] = value
def trial_end(self, parameter_id, success, **kwargs):
"""
Callback function invoked when a trial ends and its resource is released.
"""
self._pending_result_ids.remove(parameter_id)
if not self._pending_result_ids and not self._to_evaluate_queue:
# a new epoch now
self._next_round()
assert self._st_callback is not None
self._bind_and_send_parameters()
def export_results(self, result):
"""
Export a number of candidates to the `checkpoints` directory.
Parameters
----------
result : list of dict
"""
os.makedirs("checkpoints", exist_ok=True)
for i, cand in enumerate(result):
converted = dict()
for cand_key, cand_val in cand.items():
onehot = [k == cand_val["_idx"] for k in range(len(self._search_space[cand_key]["_value"]))]
converted[cand_key] = onehot
with open(os.path.join("checkpoints", "%03d_%03d.json" % (self.epoch, i)), "w") as fp:
json.dump(converted, fp)
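For reference, the one-hot conversion above maps a chosen index to a boolean list over that mutable's choices; e.g. with four choices and _idx == 2:
# Example of the one-hot conversion performed in export_results.
cand_val = {"_idx": 2, "_value": "some_choice"}
onehot = [k == cand_val["_idx"] for k in range(4)]
# onehot == [False, False, True, False]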
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import logging
import numpy as np
from nni.nas.pytorch.random import RandomMutator
_logger = logging.getLogger(__name__)
class SPOSSupernetTrainingMutator(RandomMutator):
def __init__(self, model, flops_func=None, flops_lb=None, flops_ub=None,
flops_bin_num=7, flops_sample_timeout=500):
"""
Parameters
----------
model : nn.Module
flops_func : callable
Callable that takes a candidate from `sample_search` and returns its FLOPs. When `flops_func`
is None, FLOPs-related functionality is disabled.
flops_lb : number
Lower bound of flops.
flops_ub : number
Upper bound of flops.
flops_bin_num : number
Number of bins the FLOPs interval is divided into to keep sampling uniform. A bigger number gives a
more uniform distribution, but sampling becomes slower.
flops_sample_timeout : int
Maximum number of sampling attempts before giving up and using a random candidate.
"""
super().__init__(model)
self._flops_func = flops_func
if self._flops_func is not None:
self._flops_bin_num = flops_bin_num
self._flops_bins = [flops_lb + (flops_ub - flops_lb) / flops_bin_num * i for i in range(flops_bin_num + 1)]
self._flops_sample_timeout = flops_sample_timeout
def sample_search(self):
"""
Sample a candidate for training. When `flops_func` is not None, candidates will be sampled uniformly
relative to flops.
Returns
-------
dict
"""
if self._flops_func is not None:
for times in range(self._flops_sample_timeout):
idx = np.random.randint(self._flops_bin_num)
cand = super().sample_search()
if self._flops_bins[idx] <= self._flops_func(cand) <= self._flops_bins[idx + 1]:
_logger.debug("Sampled candidate flops %f in %d times.", cand, times)
return cand
_logger.warning("Failed to sample a flops-valid candidate within %d tries.", self._flops_sample_timeout)
return super().sample_search()
def sample_final(self):
"""
Implemented only to satisfy the Mutator interface.
"""
return self.sample_search()
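A hedged construction example: get_candidate_flops is the ShuffleNetV2OneShot helper already used by the tuner at the top of this diff, while the FLOPs bounds here are illustrative.
# Illustrative: FLOPs-uniform sampling for supernet training (bounds are made up).
supernet = ShuffleNetV2OneShot()
mutator = SPOSSupernetTrainingMutator(supernet,
                                      flops_func=supernet.get_candidate_flops,
                                      flops_lb=290E6, flops_ub=360E6)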
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import logging
import torch
from nni.nas.pytorch.trainer import Trainer
from nni.nas.pytorch.utils import AverageMeterGroup
from .mutator import SPOSSupernetTrainingMutator
logger = logging.getLogger(__name__)
class SPOSSupernetTrainer(Trainer):
"""
This trainer trains a supernet that can be used for evolution search.
"""
def __init__(self, model, loss, metrics,
optimizer, num_epochs, train_loader, valid_loader,
mutator=None, batch_size=64, workers=4, device=None, log_frequency=None,
callbacks=None):
assert torch.cuda.is_available()
super().__init__(model, mutator if mutator is not None else SPOSSupernetTrainingMutator(model),
loss, metrics, optimizer, num_epochs, None, None,
batch_size, workers, device, log_frequency, callbacks)
self.train_loader = train_loader
self.valid_loader = valid_loader
def train_one_epoch(self, epoch):
self.model.train()
meters = AverageMeterGroup()
for step, (x, y) in enumerate(self.train_loader):
self.optimizer.zero_grad()
self.mutator.reset()
logits = self.model(x)
loss = self.loss(logits, y)
loss.backward()
self.optimizer.step()
metrics = self.metrics(logits, y)
metrics["loss"] = loss.item()
meters.update(metrics)
if self.log_frequency is not None and step % self.log_frequency == 0:
logger.info("Epoch [%s/%s] Step [%s/%s] %s", epoch + 1,
self.num_epochs, step + 1, len(self.train_loader), meters)
def validate_one_epoch(self, epoch):
self.model.eval()
meters = AverageMeterGroup()
with torch.no_grad():
for step, (x, y) in enumerate(self.valid_loader):
self.mutator.reset()
logits = self.model(x)
loss = self.loss(logits, y)
metrics = self.metrics(logits, y)
metrics["loss"] = loss.item()
meters.update(metrics)
if self.log_frequency is not None and step % self.log_frequency == 0:
logger.info("Epoch [%s/%s] Validation Step [%s/%s] %s", epoch + 1,
self.num_epochs, step + 1, len(self.valid_loader), meters)
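Wiring it together for supernet training; the data loaders and hyperparameters are placeholders, and the loss and metrics reuse the CrossEntropyLabelSmooth and accuracy helpers from earlier in this diff.
# Illustrative wiring of the supernet trainer (loaders and hyperparameters are placeholders).
supernet = ShuffleNetV2OneShot()
criterion = CrossEntropyLabelSmooth(num_classes=1000, epsilon=0.1)
optimizer = torch.optim.SGD(supernet.parameters(), lr=0.5, momentum=0.9, weight_decay=4E-5)
trainer = SPOSSupernetTrainer(supernet, criterion, accuracy, optimizer,
                              num_epochs=120, train_loader=train_loader,
                              valid_loader=valid_loader,
                              mutator=SPOSSupernetTrainingMutator(supernet))
trainer.train()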
......@@ -136,12 +136,12 @@ class CompressorTestCase(TestCase):
model.conv2.weight.data = torch.tensor(w).float()
layer = torch_compressor.compressor.LayerInfo('conv2', model.conv2)
masks = pruner.calc_mask(layer, config_list[0])
assert all(torch.sum(masks, (1, 2, 3)).numpy() == np.array([45., 45., 45., 45., 0., 0., 45., 45., 45., 45.]))
assert all(torch.sum(masks['weight'], (1, 2, 3)).numpy() == np.array([45., 45., 45., 45., 0., 0., 45., 45., 45., 45.]))
pruner.update_epoch(1)
model.conv2.weight.data = torch.tensor(w).float()
masks = pruner.calc_mask(layer, config_list[1])
assert all(torch.sum(masks, (1, 2, 3)).numpy() == np.array([45., 45., 0., 0., 0., 0., 0., 0., 45., 45.]))
assert all(torch.sum(masks['weight'], (1, 2, 3)).numpy() == np.array([45., 45., 0., 0., 0., 0., 0., 0., 45., 45.]))
@tf2
def test_tf_fpgm_pruner(self):
......@@ -190,8 +190,8 @@ class CompressorTestCase(TestCase):
mask1 = pruner.calc_mask(layer1, config_list[0])
layer2 = torch_compressor.compressor.LayerInfo('conv2', model.conv2)
mask2 = pruner.calc_mask(layer2, config_list[1])
assert all(torch.sum(mask1, (1, 2, 3)).numpy() == np.array([0., 27., 27., 27., 27.]))
assert all(torch.sum(mask2, (1, 2, 3)).numpy() == np.array([0., 0., 0., 27., 27.]))
assert all(torch.sum(mask1['weight'], (1, 2, 3)).numpy() == np.array([0., 27., 27., 27., 27.]))
assert all(torch.sum(mask2['weight'], (1, 2, 3)).numpy() == np.array([0., 0., 0., 27., 27.]))
def test_torch_slim_pruner(self):
"""
......@@ -218,8 +218,10 @@ class CompressorTestCase(TestCase):
mask1 = pruner.calc_mask(layer1, config_list[0])
layer2 = torch_compressor.compressor.LayerInfo('bn2', model.bn2)
mask2 = pruner.calc_mask(layer2, config_list[0])
assert all(mask1.numpy() == np.array([0., 1., 1., 1., 1.]))
assert all(mask2.numpy() == np.array([0., 1., 1., 1., 1.]))
assert all(mask1['weight'].numpy() == np.array([0., 1., 1., 1., 1.]))
assert all(mask2['weight'].numpy() == np.array([0., 1., 1., 1., 1.]))
assert all(mask1['bias'].numpy() == np.array([0., 1., 1., 1., 1.]))
assert all(mask2['bias'].numpy() == np.array([0., 1., 1., 1., 1.]))
config_list = [{'sparsity': 0.6, 'op_types': ['BatchNorm2d']}]
model.bn1.weight.data = torch.tensor(w).float()
......@@ -230,8 +232,10 @@ class CompressorTestCase(TestCase):
mask1 = pruner.calc_mask(layer1, config_list[0])
layer2 = torch_compressor.compressor.LayerInfo('bn2', model.bn2)
mask2 = pruner.calc_mask(layer2, config_list[0])
assert all(mask1.numpy() == np.array([0., 0., 0., 1., 1.]))
assert all(mask2.numpy() == np.array([0., 0., 0., 1., 1.]))
assert all(mask1['weight'].numpy() == np.array([0., 0., 0., 1., 1.]))
assert all(mask2['weight'].numpy() == np.array([0., 0., 0., 1., 1.]))
assert all(mask1['bias'].numpy() == np.array([0., 0., 0., 1., 1.]))
assert all(mask2['bias'].numpy() == np.array([0., 0., 0., 1., 1.]))
def test_torch_QAT_quantizer(self):
model = TorchModel()
......
......@@ -65,7 +65,7 @@ tuner_schema_dict = {
'builtinTunerName': 'SMAC',
Optional('classArgs'): {
'optimize_mode': setChoice('optimize_mode', 'maximize', 'minimize'),
'config_dedup': setType('config_dedup', bool)
Optional('config_dedup'): setType('config_dedup', bool)
},
Optional('includeIntermediateResults'): setType('includeIntermediateResults', bool),
Optional('gpuIndices'): Or(int, And(str, lambda x: len([int(i) for i in x.split(',')]) > 0), error='gpuIndex format error!'),
......