"docs/archive_en_US/Tutorial/HowToUseDocker.md" did not exist on "14c1b31c784e14a498e78631bb7f40f0ca3a9151"
Commit b40e3db7 authored by quzha's avatar quzha
Browse files

Merge branch 'master' of github.com:Microsoft/nni into dev-retiarii

parents efa4e31c 95f731e4
authorName: default
experimentName: auto_rocksdb_SMAC
trialConcurrency: 1
maxExecDuration: 12h
maxTrialNum: 256
#choice: local, remote, pai
trainingServicePlatform: local
searchSpacePath: search_space.json
#choice: true, false
useAnnotation: false
tuner:
  #choice: TPE, Random, Anneal, Evolution, BatchTuner, MetisTuner
  #SMAC (SMAC should be installed through nnictl)
  builtinTunerName: SMAC
  classArgs:
    #choice: maximize, minimize
    optimize_mode: maximize
trial:
  command: python3 main.py
  codeDir: .
  gpuNum: 0
authorName: default
experimentName: auto_rocksdb_TPE
trialConcurrency: 1
maxExecDuration: 12h
maxTrialNum: 256
#choice: local, remote, pai
trainingServicePlatform: local
searchSpacePath: search_space.json
#choice: true, false
useAnnotation: false
tuner:
  #choice: TPE, Random, Anneal, Evolution, BatchTuner, MetisTuner
  #SMAC (SMAC should be installed through nnictl)
  builtinTunerName: TPE
  classArgs:
    #choice: maximize, minimize
    optimize_mode: maximize
trial:
  command: python3 main.py
  codeDir: .
  gpuNum: 0
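The two configurations above differ only in the tuner section (SMAC vs. TPE); the 12-hour budget, 256-trial cap, and trial command are shared. A minimal launch sketch, assuming the files are saved as config_smac.yml and config_tpe.yml next to search_space.json (the file names are illustrative):

    # SMAC is an optional dependency; per the config comment it is installed through nnictl
    nnictl package install --name SMAC
    nnictl create --config config_smac.yml

    # the TPE variant needs no extra installation
    nnictl create --config config_tpe.yml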
{
    "write_buffer_size": {
        "_type": "quniform",
        "_value": [2097152, 16777216, 1048576]
    },
    "min_write_buffer_number_to_merge": {
        "_type": "quniform",
        "_value": [2, 16, 1]
    },
    "level0_file_num_compaction_trigger": {
        "_type": "quniform",
        "_value": [2, 16, 1]
    }
}
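Each of the three RocksDB parameters above uses the quniform search-space type, whose _value list is [low, high, q]: sample uniformly from [low, high], then round to a multiple of q. A minimal sketch of the sampling rule (illustrative, not NNI's exact implementation):

    import random

    def sample_quniform(low, high, q):
        # Draw uniformly, snap to the nearest multiple of q,
        # and clip back into [low, high].
        value = round(random.uniform(low, high) / q) * q
        return min(max(value, low), high)

    # write_buffer_size: 2 MiB to 16 MiB in 1 MiB steps
    print(sample_quniform(2097152, 16777216, 1048576))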
@@ -219,8 +219,7 @@ def run_epoch(batches, answer_net, is_training):
             loss, _, = sess.run(
                 [answer_net.loss, answer_net.train_op], feed_dict=feed_dict)
             if count % 100 == 0:
-                logger.debug('%d %g except:%g, loss:%g' %
-                             (count, used, used / count * len(batches), loss))
+                logger.debug('%d %g except:%g, loss:%g', count, used, used / count * len(batches), loss)
             loss_sum += loss
         else:
             feed_dict = {answer_net.query_word: query,
@@ -240,8 +239,7 @@ def run_epoch(batches, answer_net, is_training):
             contexts += context
             ids = np.concatenate((ids, sample_id))
             if count % 100 == 0:
-                logger.debug('%d %g except:%g' %
-                             (count, used, used / count * len(batches)))
+                logger.debug('%d %g except:%g', count, used, used / count * len(batches))
     loss = loss_sum / len(batches)
     if is_training:
         return loss
@@ -333,7 +331,7 @@ def train_with_graph(p_graph, qp_pairs, dev_qp_pairs):
             train_batches = data.get_batches(qp_pairs, cfg.batch_size)
             train_loss = run_epoch(train_batches, train_model, True)
             logger.debug('epoch ' + str(epoch) +
-                         ' loss: ' + str(train_loss))
+                         ' loss: ', str(train_loss))
             dev_batches = list(data.get_batches(
                 dev_qp_pairs, cfg.batch_size))
             _, position1, position2, ids, contexts = run_epoch(
@@ -369,8 +367,7 @@ def train_with_graph(p_graph, qp_pairs, dev_qp_pairs):
             with open(os.path.join(save_path, 'epoch%d.score' % epoch), 'wb') as file:
                 pickle.dump(
                     (position1, position2, ids, contexts), file)
-            logger.debug('epoch %d acc %g bestacc %g' %
-                         (epoch, acc, bestacc))
+            logger.debug('epoch %d acc %g bestacc %g', epoch, acc, bestacc)
         if patience <= iter:
             break
     logger.debug('save done.')
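The common thread in the four hunks above is switching logger.debug from eager %-formatting to lazy argument passing, so the message string is only built when a DEBUG record is actually emitted. A minimal illustration:

    import logging

    logger = logging.getLogger(__name__)
    count, loss = 100, 0.25

    # Eager: the string is formatted even if DEBUG records are discarded.
    logger.debug('count:%d loss:%g' % (count, loss))

    # Lazy: formatting is deferred until a handler emits the record.
    logger.debug('count:%d loss:%g', count, loss)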
@@ -4,6 +4,7 @@
 import logging
 import torch
 from schema import And, Optional
+import copy

 from nni.compression.pytorch.utils.config_validation import CompressorSchema
 from .constants import MASKER_DICT
@@ -53,7 +54,7 @@ class ADMMPruner(OneshotPruner):
     row : float
         Penalty parameters for ADMM training.
     base_algo : str
-        Base pruning algorithm. `level`, `l1` or `l2`, by default `l1`. Given the sparsity distribution among the ops,
+        Base pruning algorithm. `level`, `l1`, `l2` or `fpgm`, by default `l1`. Given the sparsity distribution among the ops,
         the assigned `base_algo` is used to decide which filters/channels/weights to prune.
     """
@@ -87,7 +88,7 @@ class ADMMPruner(OneshotPruner):
                 Optional('op_types'): [str],
                 Optional('op_names'): [str],
             }], model, _logger)
-        elif self._base_algo in ['l1', 'l2']:
+        elif self._base_algo in ['l1', 'l2', 'fpgm']:
             schema = CompressorSchema([{
                 'sparsity': And(float, lambda n: 0 < n < 1),
                 'op_types': ['Conv2d'],
@@ -96,7 +97,7 @@ class ADMMPruner(OneshotPruner):
         schema.validate(config_list)

-    def _projection(self, weight, sparsity):
+    def _projection(self, weight, sparsity, wrapper):
         '''
         Return the Euclidean projection of the weight matrix according to the pruning mode.
@@ -106,31 +107,17 @@ class ADMMPruner(OneshotPruner):
             original matrix
         sparsity : float
             the ratio of parameters which need to be set to zero
+        wrapper: PrunerModuleWrapper
+            layer wrapper of this layer

         Returns
         -------
         tensor
             the projected matrix
         '''
-        w_abs = weight.abs()
-        if self._base_algo == 'level':
-            k = int(weight.numel() * sparsity)
-            if k == 0:
-                mask_weight = torch.ones(weight.shape).type_as(weight)
-            else:
-                threshold = torch.topk(w_abs.view(-1), k, largest=False)[0].max()
-                mask_weight = torch.gt(w_abs, threshold).type_as(weight)
-        elif self._base_algo in ['l1', 'l2']:
-            filters = weight.size(0)
-            num_prune = int(filters * sparsity)
-            if filters < 2 or num_prune < 1:
-                mask_weight = torch.ones(weight.size()).type_as(weight).detach()
-            else:
-                w_abs_structured = w_abs.view(filters, -1).sum(dim=1)
-                threshold = torch.topk(w_abs_structured.view(-1), num_prune, largest=False)[0].max()
-                mask_weight = torch.gt(w_abs_structured, threshold)[:, None, None, None].expand_as(weight).type_as(weight)
-        return weight.data.mul(mask_weight)
+        wrapper_copy = copy.deepcopy(wrapper)
+        wrapper_copy.module.weight.data = weight
+        return weight.data.mul(self.masker.calc_mask(sparsity, wrapper_copy)['weight_mask'])

     def compress(self):
         """
@@ -179,7 +166,7 @@ class ADMMPruner(OneshotPruner):
             # U_i^{k+1} = U^k + W_i^{k+1} - Z_i^{k+1}
             for i, wrapper in enumerate(self.get_modules_wrapper()):
                 z = wrapper.module.weight.data + U[i]
-                Z[i] = self._projection(z, wrapper.config['sparsity'])
+                Z[i] = self._projection(z, wrapper.config['sparsity'], wrapper)
                 U[i] = U[i] + wrapper.module.weight.data - Z[i]

             # apply prune
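For context on the ADMMPruner changes above: _projection used to hard-code the level and l1/l2 masking rules, and now delegates to the masker selected by base_algo, which is how fpgm support arrives without duplicating mask logic. The call site implements the ADMM Z-update, a Euclidean projection of W + U onto the sparsity constraint set. A minimal sketch of that projection for magnitude (level) pruning, distilled from the removed code:

    import torch

    def project_to_sparsity(weight: torch.Tensor, sparsity: float) -> torch.Tensor:
        # Zero the smallest-magnitude entries: the Euclidean projection
        # onto {W : ||W||_0 <= (1 - sparsity) * W.numel()}.
        k = int(weight.numel() * sparsity)
        if k == 0:
            return weight.clone()
        threshold = torch.topk(weight.abs().view(-1), k, largest=False)[0].max()
        return weight * weight.abs().gt(threshold).type_as(weight)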
@@ -80,7 +80,7 @@ class AutoCompressPruner(Pruner):
     optimize_mode : str
         optimize mode, `maximize` or `minimize`, by default `maximize`.
     base_algo : str
-        Base pruning algorithm. `level`, `l1` or `l2`, by default `l1`. Given the sparsity distribution among the ops,
+        Base pruning algorithm. `level`, `l1`, `l2` or `fpgm`, by default `l1`. Given the sparsity distribution among the ops,
         the assigned `base_algo` is used to decide which filters/channels/weights to prune.
     start_temperature : float
         Start temperature of the simulated annealing process.
@@ -151,7 +151,7 @@ class AutoCompressPruner(Pruner):
                 Optional('op_types'): [str],
                 Optional('op_names'): [str],
             }], model, _logger)
-        elif self._base_algo in ['l1', 'l2']:
+        elif self._base_algo in ['l1', 'l2', 'fpgm']:
             schema = CompressorSchema([{
                 'sparsity': And(float, lambda n: 0 < n < 1),
                 'op_types': ['Conv2d'],
@@ -2,10 +2,11 @@
 # Licensed under the MIT license.

-from .one_shot import LevelPruner, L1FilterPruner, L2FilterPruner
+from .one_shot import LevelPruner, L1FilterPruner, L2FilterPruner, FPGMPruner

 PRUNER_DICT = {
     'level': LevelPruner,
     'l1': L1FilterPruner,
-    'l2': L2FilterPruner
+    'l2': L2FilterPruner,
+    'fpgm': FPGMPruner
 }
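PRUNER_DICT is the single name-to-class lookup shared by the dictionary-driven pruners, so registering 'fpgm' here is the only wiring the new base algorithm needs. A hypothetical usage sketch (model and config_list are placeholders):

    base_algo = 'fpgm'
    pruner_cls = PRUNER_DICT[base_algo]      # -> FPGMPruner
    pruner = pruner_cls(model, config_list)  # same constructor shape for every entry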
@@ -73,7 +73,7 @@ class NetAdaptPruner(Pruner):
     optimize_mode : str
         optimize mode, `maximize` or `minimize`, by default `maximize`.
     base_algo : str
-        Base pruning algorithm. `level`, `l1` or `l2`, by default `l1`. Given the sparsity distribution among the ops,
+        Base pruning algorithm. `level`, `l1`, `l2` or `fpgm`, by default `l1`. Given the sparsity distribution among the ops,
         the assigned `base_algo` is used to decide which filters/channels/weights to prune.
     sparsity_per_iteration : float
         sparsity to prune in each iteration.
@@ -125,7 +125,7 @@ class NetAdaptPruner(Pruner):
                 Optional('op_types'): [str],
                 Optional('op_names'): [str],
             }], model, _logger)
-        elif self._base_algo in ['l1', 'l2']:
+        elif self._base_algo in ['l1', 'l2', 'fpgm']:
             schema = CompressorSchema([{
                 'sparsity': And(float, lambda n: 0 < n < 1),
                 'op_types': ['Conv2d'],
@@ -149,7 +149,7 @@ class NetAdaptPruner(Pruner):
             return config_list_updated

         # if op_name is not in self._config_list_generated, create a new json item
-        if self._base_algo in ['l1', 'l2']:
+        if self._base_algo in ['l1', 'l2', 'fpgm']:
             config_list_updated.append(
                 {'sparsity': sparsity, 'op_types': ['Conv2d'], 'op_names': [op_name]})
         elif self._base_algo == 'level':
@@ -68,7 +68,7 @@ class SensitivityPruner(Pruner):
         >>> loss.backward()
         >>> optimizer.step()
     base_algo: str
-        base pruning algorithm. `level`, `l1` or `l2`, by default `l1`.
+        base pruning algorithm. `level`, `l1`, `l2` or `fpgm`, by default `l1`.
     sparsity_proportion_calc: function
         This function generate the sparsity proportion between the conv layers according to the
         sensitivity analysis results. We provide a default function to quantify the sparsity
@@ -150,7 +150,7 @@ class SensitivityPruner(Pruner):
                 Optional('op_types'): [str],
                 Optional('op_names'): [str],
             }], model, _logger)
-        elif self.base_algo in ['l1', 'l2']:
+        elif self.base_algo in ['l1', 'l2', 'fpgm']:
             schema = CompressorSchema([{
                 'sparsity': And(float, lambda n: 0 < n < 1),
                 'op_types': ['Conv2d'],
@@ -54,7 +54,7 @@ class SimulatedAnnealingPruner(Pruner):
     optimize_mode : str
         Optimize mode, `maximize` or `minimize`, by default `maximize`.
     base_algo : str
-        Base pruning algorithm. `level`, `l1` or `l2`, by default `l1`. Given the sparsity distribution among the ops,
+        Base pruning algorithm. `level`, `l1`, `l2` or `fpgm`, by default `l1`. Given the sparsity distribution among the ops,
         the assigned `base_algo` is used to decide which filters/channels/weights to prune.
     start_temperature : float
         Start temperature of the simulated annealing process.
@@ -120,7 +120,7 @@ class SimulatedAnnealingPruner(Pruner):
                 Optional('op_types'): [str],
                 Optional('op_names'): [str],
             }], model, _logger)
-        elif self._base_algo in ['l1', 'l2']:
+        elif self._base_algo in ['l1', 'l2', 'fpgm']:
             schema = CompressorSchema([{
                 'sparsity': And(float, lambda n: 0 < n < 1),
                 'op_types': ['Conv2d'],
@@ -152,7 +152,7 @@ class SimulatedAnnealingPruner(Pruner):
         # a layer with more weights will have no less pruning rate
         for idx, wrapper in enumerate(self.get_modules_wrapper()):
             # L1Filter Pruner requires to specify op_types
-            if self._base_algo in ['l1', 'l2']:
+            if self._base_algo in ['l1', 'l2', 'fpgm']:
                 config_list.append(
                     {'sparsity': sparsities[idx], 'op_types': ['Conv2d'], 'op_names': [wrapper.name]})
             elif self._base_algo == 'level':
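The loop above generates one config entry per wrapped layer, and for all filter-level base algorithms (l1, l2, and now fpgm) the entry shape is identical. A hypothetical example of the result for two Conv2d layers (layer names and sparsity values are illustrative):

    config_list = [
        {'sparsity': 0.31, 'op_types': ['Conv2d'], 'op_names': ['features.conv1']},
        {'sparsity': 0.58, 'op_types': ['Conv2d'], 'op_names': ['features.conv2']},
    ]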
@@ -73,9 +73,9 @@ def update_quantization_param(bits, rmin, rmax):
     ----------
     bits : int
         quantization bits length
-    rmin : float
+    rmin : Tensor
         min value of real value
-    rmax : float
+    rmax : Tensor
         max value of real value

     Returns
@@ -85,12 +85,17 @@ def update_quantization_param(bits, rmin, rmax):
     # extend the [min, max] interval to ensure that it contains 0.
     # Otherwise, we would not meet the requirement that 0 be an exactly
     # representable value.
-    rmin = min(rmin, 0)
-    rmax = max(rmax, 0)
+    if rmin.is_cuda:
+        rmin = torch.min(rmin, torch.Tensor([0]).cuda())
+        rmax = torch.max(rmax, torch.Tensor([0]).cuda())
+        qmin = torch.Tensor([0]).cuda()
+        qmax = torch.Tensor([(1 << bits) - 1]).cuda()
+    else:
+        rmin = torch.min(rmin, torch.Tensor([0]))
+        rmax = torch.max(rmax, torch.Tensor([0]))
+        qmin = torch.Tensor([0])
+        qmax = torch.Tensor([(1 << bits) - 1])

-    # the min and max quantized values, as floating-point values
-    qmin = 0
-    qmax = (1 << bits) - 1
     # First determine the scale.
     scale = (rmax - rmin) / (qmax - qmin)
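For reference, the parameters derived here follow standard affine quantization: the real range [rmin, rmax] (stretched to contain 0) is mapped onto the integer range [qmin, qmax] = [0, 2^bits - 1]. A self-contained sketch of the math, consistent with the hunk above but not copied from the source:

    import torch

    def affine_quant_params(bits, rmin, rmax):
        # Stretch the real range so that 0 is exactly representable.
        rmin = torch.min(rmin, torch.zeros_like(rmin))
        rmax = torch.max(rmax, torch.zeros_like(rmax))
        qmin, qmax = 0.0, float((1 << bits) - 1)
        scale = (rmax - rmin) / (qmax - qmin)  # real units per quantized step
        zero_point = torch.clamp(qmin - rmin / scale, qmin, qmax)  # quantized value of real 0
        return scale, zero_point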
@@ -143,11 +148,11 @@ class QAT_Quantizer(Quantizer):
             types of nn.module you want to apply quantization, eg. 'Conv2d'
         """
         super().__init__(model, config_list, optimizer)
-        self.steps = 1
         modules_to_compress = self.get_modules_to_compress()
+        self.bound_model.register_buffer("steps", torch.Tensor([1]))
         for layer, config in modules_to_compress:
-            layer.module.register_buffer("zero_point", None)
-            layer.module.register_buffer("scale", None)
+            layer.module.register_buffer("zero_point", torch.Tensor([0.0]))
+            layer.module.register_buffer("scale", torch.Tensor([1.0]))
             if "output" in config.get("quant_types", []):
                 layer.module.register_buffer('ema_decay', torch.Tensor([0.99]))
                 layer.module.register_buffer('tracked_min_biased', torch.zeros(1))
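Moving steps, scale, and zero_point from plain attributes (or None) into registered buffers means they are saved with state_dict and follow the module across devices, which the old Python-attribute counter did not. A quick illustration:

    import torch

    conv = torch.nn.Conv2d(3, 8, 3)
    conv.register_buffer('scale', torch.Tensor([1.0]))

    # Buffers appear in checkpoints and move with .to()/.cuda(),
    # unlike a plain attribute such as the removed self.steps.
    print('scale' in conv.state_dict())  # True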
@@ -187,13 +192,17 @@ class QAT_Quantizer(Quantizer):
             quantization bits length
         op : torch.nn.Module
             target module
-        real_val : float
+        real_val : Tensor
             real value to be quantized

         Returns
         -------
-        float
+        Tensor
         """
+        if real_val.is_cuda:
+            op.zero_point = op.zero_point.cuda()
+            op.scale = op.scale.cuda()
         transformed_val = op.zero_point + real_val / op.scale
         qmin = 0
         qmax = (1 << bits) - 1
@@ -229,7 +238,8 @@ class QAT_Quantizer(Quantizer):
         quant_start_step = config.get('quant_start_step', 0)
         assert weight_bits >= 1, "quant bits length should be at least 1"
-        if quant_start_step > self.steps:
+        # we dont update weight in evaluation stage
+        if quant_start_step > self.bound_model.steps or not wrapper.training:
             return weight

         # if bias exists, quantize bias to uint32
@@ -258,15 +268,17 @@ class QAT_Quantizer(Quantizer):
         quant_start_step = config.get('quant_start_step', 0)
         assert output_bits >= 1, "quant bits length should be at least 1"
-        if quant_start_step > self.steps:
+        if quant_start_step > self.bound_model.steps:
             return output

-        current_min, current_max = torch.min(output), torch.max(output)
-        module.tracked_min_biased, module.tracked_min = update_ema(module.tracked_min_biased, current_min,
-                                                                   module.ema_decay, self.steps)
-        module.tracked_max_biased, module.tracked_max = update_ema(module.tracked_max_biased, current_max,
-                                                                   module.ema_decay, self.steps)
-        module.scale, module.zero_point = update_quantization_param(output_bits, module.tracked_min, module.tracked_max)
+        # we dont update output quantization parameters in evaluation stage
+        if wrapper.training:
+            current_min, current_max = torch.min(output), torch.max(output)
+            module.tracked_min_biased, module.tracked_min = update_ema(module.tracked_min_biased, current_min,
+                                                                       module.ema_decay, self.bound_model.steps)
+            module.tracked_max_biased, module.tracked_max = update_ema(module.tracked_max_biased, current_max,
+                                                                       module.ema_decay, self.bound_model.steps)
+            module.scale, module.zero_point = update_quantization_param(output_bits, module.tracked_min, module.tracked_max)
         out = self._quantize(output_bits, module, output)
         out = self._dequantize(module, out)
         return out
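update_ema above maintains running estimates of the activation min/max. The usual bias-corrected exponential moving average matching this call signature looks like the following (an assumed shape for the helper, not copied from the source):

    def update_ema(biased_ema, value, decay, step):
        # EMA with Adam-style bias correction for early steps.
        biased_ema = biased_ema * decay + (1 - decay) * value
        corrected_ema = biased_ema / (1 - decay ** step)
        return biased_ema, corrected_ema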
@@ -279,7 +291,7 @@ class QAT_Quantizer(Quantizer):
         """
         override `compressor` `step` method, quantization only happens after certain number of steps
         """
-        self.steps += 1
+        self.bound_model.steps += 1

 class DoReFaQuantizer(Quantizer):
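Taken together, these QAT changes key all bookkeeping off bound_model.steps and wrapper.training, so evaluation passes no longer perturb the learned quantization ranges. A rough usage sketch (the config shape and compress/step API come from the source; the model, loader, and criterion are placeholders):

    quantizer = QAT_Quantizer(model, [{
        'quant_types': ['weight', 'output'],
        'quant_bits': {'weight': 8, 'output': 8},
        'op_types': ['Conv2d'],
    }], optimizer)
    quantizer.compress()

    for data, target in loader:      # hypothetical training loop
        loss = criterion(model(data), target)
        loss.backward()
        optimizer.step()
        quantizer.step()             # advances bound_model.steps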