Unverified Commit cd3a912a authored by SparkSnail's avatar SparkSnail Committed by GitHub
Browse files

Merge pull request #218 from microsoft/master

merge master
parents a0846f2a e9cba778
# Copyright (c) Microsoft Corporation. All rights reserved.
#
# MIT License
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and
# associated documentation files (the "Software"), to deal in the Software without restriction,
# including without limitation the rights to use, copy, modify, merge, publish, distribute,
# sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all copies or
# substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT
# NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT
# OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
# ==================================================================================================
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .trial import *
from .smartparam import *
......
# Copyright (c) Microsoft Corporation
# All rights reserved.
#
# MIT License
#
# Permission is hereby granted, free of charge,
# to any person obtaining a copy of this software and associated
# documentation files (the "Software"), to deal in the Software without restriction,
# including without limitation the rights to use, copy, modify, merge, publish,
# distribute, sublicense, and/or sell copies of the Software, and
# to permit persons to whom the Software is furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included
# in all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING
# BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
'''
__main__.py
'''
......
# Copyright (c) Microsoft Corporation. All rights reserved.
#
# MIT License
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and
# associated documentation files (the "Software"), to deal in the Software without restriction,
# including without limitation the rights to use, copy, modify, merge, publish, distribute,
# sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all copies or
# substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT
# NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT
# OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
# ==================================================================================================
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
"""
Assessor analyzes trial's intermediate results (e.g., periodically evaluated accuracy on test dataset)
......
# Copyright (c) Microsoft Corporation
# All rights reserved.
#
# MIT License
#
# Permission is hereby granted, free of charge,
# to any person obtaining a copy of this software and associated
# documentation files (the "Software"), to deal in the Software without restriction,
# including without limitation the rights to use, copy, modify, merge, publish,
# distribute, sublicense, and/or sell copies of the Software, and
# to permit persons to whom the Software is furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included
# in all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING
# BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
"""
batch_tuner.py including:
class BatchTuner
......
# Copyright (c) Microsoft Corporation
# All rights reserved.
#
# MIT License
#
# Permission is hereby granted, free of charge,
# to any person obtaining a copy of this software and associated
# documentation files (the "Software"), to deal in the Software without restriction,
# including without limitation the rights to use, copy, modify, merge, publish,
# distribute, sublicense, and/or sell copies of the Software, and
# to permit persons to whom the Software is furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included
# in all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING
# BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
'''
bohb_advisor.py
'''
......
# Copyright (c) Microsoft Corporation. All rights reserved.
#
# MIT License
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and
# associated documentation files (the "Software"), to deal in the Software without restriction,
# including without limitation the rights to use, copy, modify, merge, publish, distribute,
# sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all copies or
# substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT
# NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT
# OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
# ==================================================================================================
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from datetime import datetime
from io import TextIOBase
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .compressor import LayerInfo, Compressor, Pruner, Quantizer
from .builtin_pruners import *
from .builtin_quantizers import *
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import logging
import numpy as np
import tensorflow as tf
......@@ -155,18 +158,18 @@ class FPGMPruner(Pruner):
return self.mask_dict.get(layer.name)
try:
weight = tf.stop_gradient(tf.transpose(weight, [2, 3, 0, 1]))
masks = np.ones(weight.shape)
num_kernels = weight.shape[0] * weight.shape[1]
num_prune = int(num_kernels * config.get('sparsity'))
if num_kernels < 2 or num_prune < 1:
w = tf.stop_gradient(tf.transpose(tf.reshape(weight, (-1, weight.shape[-1])), [1, 0]))
masks = np.ones(w.shape)
num_filters = w.shape[0]
num_prune = int(num_filters * config.get('sparsity'))
if num_filters < 2 or num_prune < 1:
return masks
min_gm_idx = self._get_min_gm_kernel_idx(weight, num_prune)
min_gm_idx = self._get_min_gm_kernel_idx(w, num_prune)
for idx in min_gm_idx:
masks[tuple(idx)] = 0.
masks[idx] = 0.
finally:
masks = np.transpose(masks, [2, 3, 0, 1])
masks = tf.reshape(tf.transpose(masks, [1, 0]), weight.shape)
masks = tf.Variable(masks)
self.mask_dict.update({op_name: masks})
self.epoch_pruned_layers.add(layer.name)
......@@ -174,22 +177,17 @@ class FPGMPruner(Pruner):
return masks
def _get_min_gm_kernel_idx(self, weight, n):
assert len(weight.shape) >= 3
assert weight.shape[0] * weight.shape[1] > 2
dist_list = []
for in_i in range(weight.shape[0]):
for out_i in range(weight.shape[1]):
dist_sum = self._get_distance_sum(weight, in_i, out_i)
dist_list.append((dist_sum, (in_i, out_i)))
for out_i in range(weight.shape[0]):
dist_sum = self._get_distance_sum(weight, out_i)
dist_list.append((dist_sum, out_i))
min_gm_kernels = sorted(dist_list, key=lambda x: x[0])[:n]
return [x[1] for x in min_gm_kernels]
def _get_distance_sum(self, weight, in_idx, out_idx):
w = tf.reshape(weight, (-1, weight.shape[-2], weight.shape[-1]))
anchor_w = tf.tile(tf.expand_dims(weight[in_idx, out_idx], 0), [w.shape[0], 1, 1])
x = w - anchor_w
x = tf.math.reduce_sum((x*x), (-2, -1))
def _get_distance_sum(self, weight, out_idx):
anchor_w = tf.tile(tf.expand_dims(weight[out_idx], 0), [weight.shape[0], 1])
x = weight - anchor_w
x = tf.math.reduce_sum((x*x), -1)
x = tf.math.sqrt(x)
return tf.math.reduce_sum(x)
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import logging
import tensorflow as tf
from .compressor import Quantizer
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import logging
import tensorflow as tf
from . import default_layers
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from tensorflow import keras
supported_layers = {
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .compressor import LayerInfo, Compressor, Pruner, Quantizer
from .builtin_pruners import *
from .builtin_quantizers import *
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import logging
import torch
from .compressor import Pruner
......@@ -47,7 +50,7 @@ class LevelPruner(Pruner):
k = int(weight.numel() * config['sparsity'])
if k == 0:
return torch.ones(weight.shape).type_as(weight)
threshold = torch.topk(w_abs.view(-1), k, largest=False).values.max()
threshold = torch.topk(w_abs.view(-1), k, largest=False)[0].max()
mask = torch.gt(w_abs, threshold).type_as(weight)
self.mask_dict.update({op_name: mask})
self.if_init_list.update({op_name: False})
......@@ -108,7 +111,7 @@ class AGP_Pruner(Pruner):
return mask
# if we want to generate a new mask, we should update the weight first
w_abs = weight.abs() * mask
threshold = torch.topk(w_abs.view(-1), k, largest=False).values.max()
threshold = torch.topk(w_abs.view(-1), k, largest=False)[0].max()
new_mask = torch.gt(w_abs, threshold).type_as(weight)
self.mask_dict.update({op_name: new_mask})
self.if_init_list.update({op_name: False})
......@@ -215,9 +218,9 @@ class FPGMPruner(Pruner):
masks = torch.ones(weight.size()).type_as(weight)
try:
num_kernels = weight.size(0) * weight.size(1)
num_prune = int(num_kernels * config.get('sparsity'))
if num_kernels < 2 or num_prune < 1:
num_filters = weight.size(0)
num_prune = int(num_filters * config.get('sparsity'))
if num_filters < 2 or num_prune < 1:
return masks
min_gm_idx = self._get_min_gm_kernel_idx(weight, num_prune)
for idx in min_gm_idx:
......@@ -233,13 +236,12 @@ class FPGMPruner(Pruner):
dist_list = []
for out_i in range(weight.size(0)):
for in_i in range(weight.size(1)):
dist_sum = self._get_distance_sum(weight, out_i, in_i)
dist_list.append((dist_sum, (out_i, in_i)))
dist_sum = self._get_distance_sum(weight, out_i)
dist_list.append((dist_sum, out_i))
min_gm_kernels = sorted(dist_list, key=lambda x: x[0])[:n]
return [x[1] for x in min_gm_kernels]
def _get_distance_sum(self, weight, out_idx, in_idx):
def _get_distance_sum(self, weight, out_idx):
"""
Calculate the total distance between a specified filter (by out_idx and in_idx) and
all other filters.
......@@ -257,24 +259,18 @@ class FPGMPruner(Pruner):
out_idx: int
output channel index of specified filter, this method calculates the total distance
between this specified filter and all other filters.
in_idx: int
input channel index of specified filter
Returns
-------
float32
The total distance
"""
logger.debug('weight size: %s', weight.size())
if len(weight.size()) == 4: # Conv2d
w = weight.view(-1, weight.size(-2), weight.size(-1))
anchor_w = weight[out_idx, in_idx].unsqueeze(0).expand(w.size(0), w.size(1), w.size(2))
elif len(weight.size()) == 3: # Conv1d
w = weight.view(-1, weight.size(-1))
anchor_w = weight[out_idx, in_idx].unsqueeze(0).expand(w.size(0), w.size(1))
else:
raise RuntimeError('unsupported layer type')
assert len(weight.size()) in [3, 4], 'unsupported weight shape'
w = weight.view(weight.size(0), -1)
anchor_w = w[out_idx].unsqueeze(0).expand(w.size(0), w.size(1))
x = w - anchor_w
x = (x * x).sum((-2, -1))
x = (x * x).sum(-1)
x = torch.sqrt(x)
return x.sum()
......@@ -336,7 +332,7 @@ class L1FilterPruner(Pruner):
if k == 0:
return torch.ones(weight.shape).type_as(weight)
w_abs_structured = w_abs.view(filters, -1).sum(dim=1)
threshold = torch.topk(w_abs_structured.view(-1), k, largest=False).values.max()
threshold = torch.topk(w_abs_structured.view(-1), k, largest=False)[0].max()
mask = torch.gt(w_abs_structured, threshold)[:, None, None, None].expand_as(weight).type_as(weight)
finally:
self.mask_dict.update({layer.name: mask})
......@@ -370,10 +366,10 @@ class SlimPruner(Pruner):
config = config_list[0]
for (layer, config) in self.detect_modules_to_compress():
assert layer.type == 'BatchNorm2d', 'SlimPruner only supports 2d batch normalization layer pruning'
weight_list.append(layer.module.weight.data.clone())
weight_list.append(layer.module.weight.data.abs().clone())
all_bn_weights = torch.cat(weight_list)
k = int(all_bn_weights.shape[0] * config['sparsity'])
self.global_threshold = torch.topk(all_bn_weights.view(-1), k, largest=False).values.max()
self.global_threshold = torch.topk(all_bn_weights.view(-1), k, largest=False)[0].max()
def calc_mask(self, layer, config):
"""
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import logging
import torch
from .compressor import Quantizer
......@@ -22,6 +25,80 @@ class NaiveQuantizer(Quantizer):
return weight.div(scale).type(torch.int8).type(orig_type).mul(scale)
def update_ema(biased_ema, value, decay, step):
    """
    Advance an exponential moving average by one step.

    Parameters
    ----------
    biased_ema : float
        running (biased) average carried over from the previous step
    value : float
        newly observed statistic
    decay : float
        weight given to the previous average, larger means smoother curve
    step : int
        current step, used as the exponent of the bias correction

    Returns
    -------
    float, float
        the updated biased average and its bias-corrected counterpart
    """
    updated = biased_ema * decay + (1 - decay) * value
    # Bias correction: divide out the weight that is still attributed to the
    # initial value after `step` updates.
    correction = 1 - decay ** step
    return updated, updated / correction
def update_quantization_param(bits, rmin, rmax):
    """
    Calculate the `scale` and `zero_point` of an affine quantization scheme.

    Parameters
    ----------
    bits : int
        quantization bits length
    rmin : float or 0-dim torch.Tensor
        min value of real value
    rmax : float or 0-dim torch.Tensor
        max value of real value

    Returns
    -------
    scale, zero_point
        `scale` has the type produced by the arithmetic on `rmin`/`rmax`
        (tensor for tensor inputs, float for float inputs); `zero_point` is the
        rounded zero point, or the integer bound `qmin`/`qmax` when the
        unrounded value falls outside the quantized range.
    """
    # Extend the [min, max] interval to ensure that it contains 0.
    # Otherwise, we would not meet the requirement that 0 be an exactly
    # representable value.
    rmin = min(rmin, 0)
    rmax = max(rmax, 0)

    # The min and max quantized values.
    qmin = 0
    qmax = (1 << bits) - 1

    # First determine the scale.
    scale = (rmax - rmin) / (qmax - qmin)
    if scale == 0:
        # Degenerate range (rmin == rmax == 0, e.g. an all-zero tensor).
        # A zero scale would turn the divisions below, and every later
        # `real_val / scale`, into NaN. Any non-zero scale represents 0
        # exactly, so fall back to 1 while keeping the type of `scale`.
        scale = scale + 1.0

    # Zero-point computation.
    initial_zero_point = qmin - rmin / scale

    # Now we need to nudge the zero point to be an integer inside [qmin, qmax].
    if initial_zero_point < qmin:
        nudged_zero_point = qmin
    elif initial_zero_point > qmax:
        nudged_zero_point = qmax
    elif isinstance(initial_zero_point, torch.Tensor):
        nudged_zero_point = torch.round(initial_zero_point)
    else:
        # Plain Python numbers: torch.round only accepts tensors. The builtin
        # round() uses the same round-half-to-even rule.
        nudged_zero_point = float(round(initial_zero_point))

    return scale, nudged_zero_point
def get_bits_length(config, quant_type):
    """
    Look up the bit length configured for one quantization type.

    Parameters
    ----------
    config : dict
        quantization config; its "quant_bits" entry is either an int shared by
        all quantization types or a dict keyed by quantization type
    quant_type : str
        quantization type to look up when "quant_bits" is a dict

    Returns
    -------
    int or None
        the configured bit length; None when "quant_bits" is a dict that has no
        entry for `quant_type`
    """
    bits = config["quant_bits"]
    return bits if isinstance(bits, int) else bits.get(quant_type)
class QAT_Quantizer(Quantizer):
"""Quantizer using the quantization-aware training (QAT) scheme, as defined in:
Quantization and Training of Neural Networks for Efficient Integer-Arithmetic-Only Inference
......@@ -29,23 +106,119 @@ class QAT_Quantizer(Quantizer):
"""
def __init__(self, model, config_list):
"""
config_list: supported keys:
- q_bits
Parameters
----------
layer : LayerInfo
the layer to quantize
config_list : list of dict
list of configurations for quantization
supported keys for dict:
- quant_types : list of string
type of quantization you want to apply, currently support 'weight', 'input', 'output'
- quant_bits : int or dict of {str : int}
bits length of quantization; key is the quantization type, value is the length, e.g. {'weight': 8},
when the type is int, all quantization types share the same bits length
- quant_start_step : int
disable quantization until the model has run for a certain number of steps; this allows the network to enter a more stable
state where activation quantization ranges do not exclude a significant fraction of values. Default value is 0
- op_types : list of string
types of nn.module you want to apply quantization, eg. 'Conv2d'
"""
super().__init__(model, config_list)
self.steps = 1
modules_to_compress = self.detect_modules_to_compress()
for layer, config in modules_to_compress:
layer.module.register_buffer("zero_point", None)
layer.module.register_buffer("scale", None)
if "output" in config.get("quant_types", []):
layer.module.register_buffer('ema_decay', torch.Tensor([0.99]))
layer.module.register_buffer('tracked_min_biased', torch.zeros(1))
layer.module.register_buffer('tracked_min', torch.zeros(1))
layer.module.register_buffer('tracked_max_biased', torch.zeros(1))
layer.module.register_buffer('tracked_max', torch.zeros(1))
def quantize_weight(self, weight, config, **kwargs):
if config['q_bits'] <= 1:
def _quantize(self, bits, op, real_val):
"""
quantize real value.
Parameters
----------
bits : int
quantization bits length
op : torch.nn.module
target module
real_val : float
real value to be quantized
Returns
-------
float
"""
transformed_val = op.zero_point + real_val / op.scale
qmin = 0
qmax = (1 << bits) - 1
clamped_val = torch.clamp(transformed_val, qmin, qmax)
quantized_val = torch.round(clamped_val)
return quantized_val
def _dequantize(self, op, quantized_val):
"""
dequantize quantized value.
Because we simulate quantization in training process, all the computations still happen as float point computations, which means we
first quantize tensors then dequantize them. For more details, please refer to the paper.
Parameters
----------
op : torch.nn.Module
target module
quantized_val : float
quantized_val value to be dequantized
Returns
-------
float
"""
real_val = op.scale * (quantized_val - op.zero_point)
return real_val
def quantize_weight(self, weight, config, op, **kwargs):
weight_bits = get_bits_length(config, 'weight')
quant_start_step = config.get('quant_start_step', 0)
assert weight_bits >= 1, "quant bits length should be at least 1"
if quant_start_step > self.steps:
return weight
a = torch.min(weight)
b = torch.max(weight)
n = pow(2, config['q_bits'])
scale = (b-a)/(n-1)
zero_point = a
out = torch.round((weight - zero_point)/scale)
out = out*scale + zero_point
orig_type = weight.dtype
return out.type(orig_type)
rmin, rmax = torch.min(weight), torch.max(weight)
op.scale, op.zero_point = update_quantization_param(weight_bits, rmin, rmax)
out = self._quantize(weight_bits, op, weight)
out = self._dequantize(op, out)
return out
def quantize_output(self, output, config, op, **kwargs):
    """
    Simulate quantization of a layer's output.

    Parameters
    ----------
    output : torch.Tensor
        output of the wrapped module
    config : dict
        quantization config; "quant_bits" and the optional "quant_start_step"
        (default 0) are read
    op : torch.nn.Module
        target module; its tracked min/max statistics, `scale` and `zero_point`
        are updated in place

    Returns
    -------
    torch.Tensor
        `output` unchanged while `quant_start_step` is greater than the current
        step count, otherwise the quantized-then-dequantized output
    """
    bits = get_bits_length(config, 'output')
    start_step = config.get('quant_start_step', 0)
    assert bits >= 1, "quant bits length should be at least 1"

    # Quantization is disabled until the configured start step is reached.
    if start_step > self.steps:
        return output

    batch_min = torch.min(output)
    batch_max = torch.max(output)
    # Track the output range with an exponential moving average.
    op.tracked_min_biased, op.tracked_min = update_ema(
        op.tracked_min_biased, batch_min, op.ema_decay, self.steps)
    op.tracked_max_biased, op.tracked_max = update_ema(
        op.tracked_max_biased, batch_max, op.ema_decay, self.steps)
    op.scale, op.zero_point = update_quantization_param(bits, op.tracked_min, op.tracked_max)

    return self._dequantize(op, self._quantize(bits, op, output))
def fold_bn(self, config, **kwargs):
    """Placeholder for batch-norm folding; currently does nothing."""
    # TODO simulate folded weight
    return None
def step(self):
    """
    Advance the step counter by one.

    The counter is compared against `quant_start_step`, so quantization only
    happens after a certain number of steps.
    """
    self.steps = self.steps + 1
class DoReFaQuantizer(Quantizer):
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import logging
import torch
from . import default_layers
......@@ -304,6 +307,12 @@ class Quantizer(Compressor):
assert layer._forward is None, 'Each model can only be compressed once'
assert "quant_types" in config, 'must provide quant_types in config'
assert isinstance(config["quant_types"], list), 'quant_types must be list type'
assert "quant_bits" in config, 'must provide quant_bits in config'
assert isinstance(config["quant_bits"], int) or isinstance(config["quant_bits"], dict), 'quant_bits must be dict type or int type'
if isinstance(config["quant_bits"], dict):
for quant_type in config["quant_types"]:
assert quant_type in config["quant_bits"], 'bits length for %s must be specified in quant_bits dict' % quant_type
if 'weight' in config["quant_types"]:
if not _check_weight(layer.module):
......@@ -312,7 +321,7 @@ class Quantizer(Compressor):
def new_forward(*inputs):
if 'input' in config["quant_types"]:
inputs = self.quantize_input(inputs, config=config, op=layer.module, op_type=layer.type, op_name=layer.name)
inputs = straight_through_quantize_input.apply(inputs, self, config, layer)
if 'weight' in config["quant_types"] and _check_weight(layer.module):
weight = layer.module.weight.data
......@@ -324,12 +333,32 @@ class Quantizer(Compressor):
result = layer._forward(*inputs)
if 'output' in config["quant_types"]:
result = self.quantize_output(result, config, op=layer.module, op_type=layer.type, op_name=layer.name)
result = straight_through_quantize_output.apply(result, self, config, layer)
return result
layer.module.forward = new_forward
class straight_through_quantize_output(torch.autograd.Function):
    """
    Autograd function that quantizes a layer's output in the forward pass and
    hands the incoming gradient back unchanged in the backward pass.
    """

    @staticmethod
    def forward(ctx, output, quantizer, config, layer):
        # The quantizer instance performs the actual quantization.
        target = layer.module
        return quantizer.quantize_output(output, config, op=target, op_type=layer.type, op_name=layer.name)

    @staticmethod
    def backward(ctx, grad_output):
        # Straight-through estimator: one gradient slot per forward argument,
        # only the first (the output tensor) receives a gradient.
        passthrough = (grad_output, None, None, None)
        return passthrough
class straight_through_quantize_input(torch.autograd.Function):
    """
    Autograd function that quantizes a layer's input in the forward pass and
    hands the incoming gradient back unchanged in the backward pass.
    """

    @staticmethod
    def forward(ctx, inputs, quantizer, config, layer):
        # The quantizer instance performs the actual quantization.
        target = layer.module
        return quantizer.quantize_input(inputs, config, op=target, op_type=layer.type, op_name=layer.name)

    @staticmethod
    def backward(ctx, grad_output):
        # Straight-through estimator: one gradient slot per forward argument,
        # only the first (the inputs) receives a gradient.
        passthrough = (grad_output, None, None, None)
        return passthrough
def _check_weight(module):
try:
return isinstance(module.weight, torch.nn.Parameter) and isinstance(module.weight.data, torch.Tensor)
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
weighted_modules = [
'Conv1d', 'Conv2d', 'Conv3d', 'ConvTranspose1d', 'ConvTranspose2d', 'ConvTranspose3d',
'Linear', 'Bilinear',
......
# Copyright (c) Microsoft Corporation
# All rights reserved.
#
# MIT License
#
# Permission is hereby granted, free of charge,
# to any person obtaining a copy of this software and associated
# documentation files (the "Software"), to deal in the Software without restriction,
# including without limitation the rights to use, copy, modify, merge, publish,
# distribute, sublicense, and/or sell copies of the Software, and
# to permit persons to whom the Software is furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included
# in all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING
# BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
ModuleName = {
'TPE': 'nni.hyperopt_tuner.hyperopt_tuner',
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .curvefitting_assessor import CurvefittingAssessor
# Copyright (c) Microsoft Corporation
# All rights reserved.
#
# MIT License
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
# documentation files (the "Software"), to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and
# to permit persons to whom the Software is furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING
# BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import logging
import datetime
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment