Unverified commit 070df4a0 authored by liuzhe-lz, committed by GitHub

Merge pull request #4291 from microsoft/v2.5

merge v2.5 back to master
parents 821706b8 6a082fe9
......@@ -83,16 +83,16 @@ def quantization_aware_training_example(train_loader, test_loader, device):
model = NaiveModel()
configure_list = [{
-'quant_types': ['weight', 'output'],
-'quant_bits': {'weight':8, 'output':8},
+'quant_types': ['input', 'weight'],
+'quant_bits': {'input':8, 'weight':8},
'op_names': ['conv1']
}, {
'quant_types': ['output'],
'quant_bits': {'output':8},
'op_names': ['relu1']
}, {
-'quant_types': ['weight', 'output'],
-'quant_bits': {'weight':8, 'output':8},
+'quant_types': ['input', 'weight'],
+'quant_bits': {'input':8, 'weight':8},
'op_names': ['conv2']
}, {
'quant_types': ['output'],
......
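The hunk above switches conv1 and conv2 from weight+output quantization to input+weight quantization. A minimal sketch of applying the updated configure_list, assuming a hypothetical stand-in for the example's NaiveModel; the MNIST-shaped dummy input and the four-argument QAT_Quantizer call mirror the test hunk near the end of this diff:

import torch
import torch.nn as nn
from nni.algorithms.compression.pytorch.quantization import QAT_Quantizer

class NaiveModel(nn.Module):  # stand-in for the example's model
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.relu1 = nn.ReLU()
        self.conv2 = nn.Conv2d(20, 50, 5)
    def forward(self, x):
        return self.conv2(self.relu1(self.conv1(x)))

model = NaiveModel()
configure_list = [
    {'quant_types': ['input', 'weight'], 'quant_bits': {'input': 8, 'weight': 8},
     'op_names': ['conv1', 'conv2']},
    {'quant_types': ['output'], 'quant_bits': {'output': 8}, 'op_names': ['relu1']},
]
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
dummy_input = torch.randn(1, 1, 28, 28)
quantizer = QAT_Quantizer(model, configure_list, optimizer, dummy_input)
quantizer.compress()  # wraps the listed ops with fake-quantization logic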
......@@ -2,7 +2,7 @@
# download automlbenchmark repository
if [ ! -d './automlbenchmark' ] ; then
-git clone https://github.com/openml/automlbenchmark.git --branch stable --depth 1
+git clone https://github.com/openml/automlbenchmark.git --branch v1.6 --depth 1
fi
# install dependencies
......
......@@ -384,6 +384,7 @@ class ADMMPruner(IterativePruner):
for i, wrapper in enumerate(self.get_modules_wrapper()):
z = wrapper.module.weight.data + self.U[i]
self.Z[i] = self._projection(z, wrapper.config['sparsity'], wrapper)
+torch.cuda.empty_cache()
self.U[i] = self.U[i] + wrapper.module.weight.data - self.Z[i]
# apply prune
......
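The only change here inserts torch.cuda.empty_cache() between the ADMM projection and the dual-variable update; it returns unused cached blocks to the GPU driver so per-layer projections on large models are less likely to fail from allocator fragmentation (it never frees tensors that are still referenced, and is a no-op when CUDA is uninitialized). A self-contained sketch of the pattern, with hypothetical names:

import torch

device = 'cuda' if torch.cuda.is_available() else 'cpu'
weights = [torch.randn(256, 256, device=device) for _ in range(4)]
u = [torch.zeros_like(w) for w in weights]
z = [torch.zeros_like(w) for w in weights]

def project(t, sparsity=0.5):
    # hypothetical magnitude projection: zero out the smallest-magnitude half
    k = int(t.numel() * sparsity)
    threshold = t.abs().view(-1).kthvalue(k).values
    return torch.where(t.abs() > threshold, t, torch.zeros_like(t))

for i, w in enumerate(weights):
    z[i] = project(w + u[i])        # ADMM auxiliary update
    torch.cuda.empty_cache()        # release cached, unreferenced GPU blocks
    u[i] = u[i] + w - z[i]          # dual-variable update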
......@@ -26,7 +26,6 @@ __all__ = ['NaiveQuantizer', 'QAT_Quantizer', 'DoReFaQuantizer', 'BNNQuantizer',
logger = logging.getLogger(__name__)
class NaiveQuantizer(Quantizer):
"""quantize weight to 8 bits
"""
......@@ -676,17 +675,20 @@ class QAT_Quantizer(Quantizer):
for layer, _ in modules_to_compress:
name, module = layer.name, layer.module
if name not in calibration_config:
-if hasattr(module, 'weight_bits') or hasattr(module, 'output_bits') or hasattr(module, 'input_bits'):
+if module.layer_quant_setting.weight or module.layer_quant_setting.input or module.layer_quant_setting.output:
logger.warning(f"Can not find module {name}'s parameter in input config.")
continue
-if hasattr(module, 'weight_bits'):
-assert calibration_config[name]['weight_bits'] == module.weight_bits, f"weight bits of module {name} fail to match"
-if hasattr(module, 'input_bits'):
-assert calibration_config[name]['input_bits'] == module.input_bits, f"input bits of module {name} fail to match"
+if module.layer_quant_setting.weight:
+assert calibration_config[name]['weight_bits'] == module.layer_quant_setting.weight.bits, \
+f"weight bits of module {name} fail to match"
+if module.layer_quant_setting.input:
+assert calibration_config[name]['input_bits'] == module.layer_quant_setting.input.bits, \
+f"input bits of module {name} fail to match"
module.tracked_min_input.data = torch.tensor([calibration_config[name]['tracked_min_input']])
module.tracked_max_input.data = torch.tensor([calibration_config[name]['tracked_max_input']])
-if hasattr(module, 'output_bits'):
-assert calibration_config[name]['output_bits'] == module.output_bits, f"output bits of module {name} fail to match"
+if module.layer_quant_setting.output:
+assert calibration_config[name]['output_bits'] == module.layer_quant_setting.output.bits, \
+f"output bits of module {name} fail to match"
module.tracked_min_output.data = torch.tensor([calibration_config[name]['tracked_min_output']])
module.tracked_max_output.data = torch.tensor([calibration_config[name]['tracked_max_output']])
......@@ -716,11 +718,13 @@ class QAT_Quantizer(Quantizer):
self._unwrap_model()
calibration_config = {}
-for name, module in self.bound_model.named_modules():
-if hasattr(module, 'weight_bits') or hasattr(module, 'output_bits'):
+modules_to_compress = self.get_modules_to_compress()
+for layer, _ in modules_to_compress:
+name, module = layer.name, layer.module
+if hasattr(module.layer_quant_setting, 'weight') or hasattr(module.layer_quant_setting, 'output'):
calibration_config[name] = {}
-if hasattr(module, 'weight_bits'):
-calibration_config[name]['weight_bits'] = int(module.weight_bits)
+if module.layer_quant_setting.weight:
+calibration_config[name]['weight_bits'] = int(module.layer_quant_setting.weight.bits)
calibration_config[name]['weight_scale'] = module.weight_scale
calibration_config[name]['weight_zero_point'] = module.weight_zero_point
......@@ -738,13 +742,14 @@ class QAT_Quantizer(Quantizer):
module.register_parameter('bias', actual_bias)
else:
setattr(module, 'bias', None)
-if hasattr(module, 'input_bits'):
-calibration_config[name]['input_bits'] = int(module.input_bits)
+if module.layer_quant_setting.input:
+calibration_config[name]['input_bits'] = int(module.layer_quant_setting.input.bits)
calibration_config[name]['tracked_min_input'] = float(module.tracked_min_input)
calibration_config[name]['tracked_max_input'] = float(module.tracked_max_input)
-if hasattr(module, 'output_bits'):
-calibration_config[name]['output_bits'] = int(module.output_bits)
+if module.layer_quant_setting.output:
+calibration_config[name]['output_bits'] = int(module.layer_quant_setting.output.bits)
calibration_config[name]['tracked_min_output'] = float(module.tracked_min_output)
calibration_config[name]['tracked_max_output'] = float(module.tracked_max_output)
self._del_simulated_attr(module)
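Both QAT hunks above migrate from loose per-module attributes (weight_bits, input_bits, output_bits) to a structured layer_quant_setting, while keeping the exported calibration_config keys unchanged. Continuing the earlier sketch, a hedged example of retrieving that config (paths are placeholders):

# After QAT finetuning, export the model and its calibration data.
calibration_config = quantizer.export_model('model.pth', 'calibration.pth')
# Per the hunk above, each entry keeps the same keys as before, e.g.
# calibration_config['conv1'] includes 'weight_bits', 'input_bits',
# 'tracked_min_input', 'output_bits', 'tracked_max_output', ...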
......@@ -1157,7 +1162,7 @@ class LsqQuantizer(Quantizer):
calibration_config = {}
for name, module in self.bound_model.named_modules():
-if hasattr(module, 'input_bits') or hasattr(module, 'output_bits'):
+if hasattr(module, 'input_bits') or hasattr(module, 'weight_bits') or hasattr(module, 'output_bits'):
calibration_config[name] = {}
if hasattr(module, 'weight_bits'):
calibration_config[name]['weight_bits'] = int(module.weight_bits)
......@@ -1177,6 +1182,11 @@ class LsqQuantizer(Quantizer):
module.register_parameter('bias', actual_bias)
else:
setattr(module, 'bias', None)
+if hasattr(module, 'input_bits'):
+calibration_config[name]['input_bits'] = int(module.input_bits)
+abs_max_input = float(module.input_scale * module.input_qmax)
+calibration_config[name]['tracked_min_input'] = -abs_max_input
+calibration_config[name]['tracked_max_input'] = abs_max_input
if hasattr(module, 'output_bits'):
calibration_config[name]['output_bits'] = int(module.output_bits)
abs_max_output = float(module.output_scale * module.output_qmax)
......
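The added LsqQuantizer block reconstructs the input range from the learned step size: for a signed b-bit scheme, qmax = 2**(b - 1) - 1, so the representable interval is [-scale * qmax, scale * qmax]. A worked example with illustrative numbers:

bits = 8
qmax = 2 ** (bits - 1) - 1     # 127 for signed 8-bit
input_scale = 0.02             # hypothetical learned LSQ step size
abs_max_input = input_scale * qmax
print(-abs_max_input, abs_max_input)   # -2.54 2.54 -> tracked_min/max_input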
......@@ -14,9 +14,6 @@ from .tools import TaskGenerator
class PruningScheduler(BasePruningScheduler):
-def __init__(self, pruner: Pruner, task_generator: TaskGenerator, finetuner: Callable[[Module], None] = None,
-speed_up: bool = False, dummy_input: Tensor = None, evaluator: Optional[Callable[[Module], float]] = None,
-reset_weight: bool = False):
"""
Parameters
----------
......@@ -37,6 +34,9 @@ class PruningScheduler(BasePruningScheduler):
reset_weight
If set True, the model weight will reset to the origin model weight at the end of each iteration step.
"""
+def __init__(self, pruner: Pruner, task_generator: TaskGenerator, finetuner: Callable[[Module], None] = None,
+speed_up: bool = False, dummy_input: Tensor = None, evaluator: Optional[Callable[[Module], float]] = None,
+reset_weight: bool = False):
self.pruner = pruner
self.task_generator = task_generator
self.finetuner = finetuner
......
......@@ -80,7 +80,7 @@ class GlobalSparsityAllocator(SparsityAllocator):
stay_metric = torch.topk(metric.view(-1), stay_metric_num, largest=False)[0]
sub_thresholds[name] = stay_metric.max()
if expend_times > 1:
-stay_metric = stay_metric.expand(stay_metric_num, int(layer_weight_num / metric.numel())).view(-1)
+stay_metric = stay_metric.expand(int(layer_weight_num / metric.numel()), stay_metric_num).contiguous().view(-1)
metric_list.append(stay_metric)
total_prune_num = int(total_sparsity * total_weight_num)
......
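This one-line fix matters twice over: a 1-D tensor can only be expand()ed along a new broadcastable leading dimension (the old argument order put the repeat count in the trailing position, which is invalid), and view(-1) raises on the non-contiguous result of expand() unless contiguous() is called first. A runnable illustration with arbitrary numbers:

import torch

stay_metric = torch.tensor([0.1, 0.2, 0.3])    # one score per pruning group
repeat = 4                                     # weights represented by each score
expanded = stay_metric.expand(repeat, stay_metric.numel())   # (3,) -> (4, 3)
flat = expanded.contiguous().view(-1)          # plain .view(-1) would raise here
print(flat.shape)                              # torch.Size([12])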
......@@ -79,5 +79,5 @@ def get_quant_shape(shape, quant_type, quant_scheme):
if is_per_channel(quant_scheme):
quant_shape = [1 if idx != default_idx else s for idx, s in enumerate(shape)]
else:
-quant_shape = []
+quant_shape = [1]
return quant_shape
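The intent of the fix: per-channel quantization keeps one scale per channel, while per-tensor quantization keeps a single scale, which should live in a broadcastable one-element shape [1] rather than the empty shape []. Illustrative values, assuming a hypothetical Conv2d weight with the channel axis at index 0:

shape = [64, 32, 3, 3]                  # hypothetical Conv2d weight shape
per_channel = [s if idx == 0 else 1 for idx, s in enumerate(shape)]
print(per_channel)                      # [64, 1, 1, 1] -> one scale per channel
per_tensor = [1]                        # single broadcastable scale; was [] before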
......@@ -110,6 +110,8 @@ def replace_prelu(prelu, masks):
in_mask = in_masks[0]
weight_mask = weight_mask['weight']
+if weight_mask.size(0) == 1:
+return prelu
pruned_in, remained_in = convert_to_coarse_mask(in_mask, 1)
pruned_out, remained_out = convert_to_coarse_mask(output_mask, 1)
n_remained_in = weight_mask.size(0) - pruned_in.size(0)
......@@ -221,6 +223,7 @@ def replace_batchnorm1d(norm, masks):
affine=norm.affine,
track_running_stats=norm.track_running_stats)
# assign weights
+if norm.affine:
new_norm.weight.data = torch.index_select(norm.weight.data, 0, remained_in)
new_norm.bias.data = torch.index_select(norm.bias.data, 0, remained_in)
......@@ -264,6 +267,7 @@ def replace_batchnorm2d(norm, masks):
affine=norm.affine,
track_running_stats=norm.track_running_stats)
# assign weights
+if norm.affine:
new_norm.weight.data = torch.index_select(norm.weight.data, 0, remained_in)
new_norm.bias.data = torch.index_select(norm.bias.data, 0, remained_in)
......
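Both BatchNorm replacements now copy weight and bias only when the layer is affine: with affine=False, PyTorch registers both parameters as None, so the unconditional index_select would raise. For example:

import torch.nn as nn

bn = nn.BatchNorm2d(16, affine=False)
print(bn.weight is None, bn.bias is None)   # True True
# bn.weight.data would raise AttributeError, hence the new `if norm.affine:` guard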
......@@ -23,11 +23,7 @@ _logger.setLevel(logging.INFO)
class ModelSpeedup:
"""
This class is to speedup the model with provided weight mask.
"""
-def __init__(self, model, dummy_input, masks_file, map_location=None,
-batch_dim=0, confidence=8):
"""
Parameters
----------
model : pytorch model
......@@ -45,6 +41,9 @@ class ModelSpeedup:
confidence: the confidence coefficient of the sparsity inference. This value is
actually used as the batchsize of the dummy_input.
"""
+def __init__(self, model, dummy_input, masks_file, map_location=None,
+batch_dim=0, confidence=8):
assert confidence > 1
# The auto inference will change the values of the parameters in the model
# so we need make a copy before the mask inference
......
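As with PruningScheduler earlier in this diff, the constructor docstring becomes the class docstring and __init__ moves below it; the assert rejects confidence <= 1 because that value doubles as the dummy batch size during sparsity inference. A hedged usage sketch (the model, mask file, and input shape are placeholders):

import torch
import torch.nn as nn
from nni.compression.pytorch import ModelSpeedup

model = nn.Sequential(nn.Conv2d(3, 16, 3), nn.ReLU(), nn.Conv2d(16, 8, 3))
dummy_input = torch.rand(1, 3, 224, 224)
ms = ModelSpeedup(model, dummy_input, 'masks.pth', confidence=8)  # placeholder mask path
ms.speedup_model()   # replaces masked layers with physically smaller ones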
import os
+ENV_NASBENCHMARK_DIR = 'NASBENCHMARK_DIR'
ENV_NNI_HOME = 'NNI_HOME'
ENV_XDG_CACHE_HOME = 'XDG_CACHE_HOME'
DEFAULT_CACHE_DIR = '~/.cache'
......@@ -10,7 +11,7 @@ def _get_nasbenchmark_dir():
nni_home = os.path.expanduser(
os.getenv(ENV_NNI_HOME,
os.path.join(os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR), 'nni')))
-return os.path.join(nni_home, 'nasbenchmark')
+return os.getenv(ENV_NASBENCHMARK_DIR, os.path.join(nni_home, 'nasbenchmark'))
DATABASE_DIR = _get_nasbenchmark_dir()
......
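With this change, NASBENCHMARK_DIR, when set, takes precedence over the NNI_HOME / XDG_CACHE_HOME derivation. For example (the path is hypothetical):

import os

os.environ['NASBENCHMARK_DIR'] = '/data/nasbenchmark'   # hypothetical location
# _get_nasbenchmark_dir() now returns '/data/nasbenchmark';
# left unset, it falls back to e.g. '~/.cache/nni/nasbenchmark'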
import argparse
if __name__ == '__main__':
parser = argparse.ArgumentParser('NAS benchmark downloader')
parser.add_argument('benchmark_name', choices=['nasbench101', 'nasbench201', 'nds'])
args = parser.parse_args()
from .utils import download_benchmark
download_benchmark(args.benchmark_name)
......@@ -381,17 +381,8 @@ class GraphConverter:
# step #1: generate graph ir for this method
method_ir_graph = Graph(model=ir_model, graph_id=-100, name='temp_graph', _internal=True)
-method_node_index = self.handle_graph_nodes(script_module, script_method.graph, module,
+self.handle_graph_nodes(script_module, script_method.graph, module,
module_name, ir_model, method_ir_graph, shared_module_index)
-for _output in script_method.graph.outputs():
-method_ir_graph._add_output(_convert_name(_output.debugName()))
-predecessor_node_outputs = [o for o in _output.node().outputs()]
-if len(predecessor_node_outputs) == 1:
-src_node_idx = None
-else:
-src_node_idx = predecessor_node_outputs.index(_output)
-method_ir_graph.add_edge(head=(method_node_index[_output.node()], src_node_idx),
-tail=(method_ir_graph.output_node, None))
self.refine_graph(method_ir_graph)
# step #2: merge this graph to its module graph
......@@ -491,18 +482,24 @@ class GraphConverter:
for node in sm_graph.nodes():
handle_single_node(node)
-if node_index == {}:
-# here is an example that the ir_graph is empty
+if node_index != {}:
+for _output in sm_graph.outputs():
+ir_graph._add_output(_convert_name(_output.debugName()))
+predecessor_node_outputs = [o for o in _output.node().outputs()]
+if len(predecessor_node_outputs) == 1:
+src_node_idx = None
+else:
+src_node_idx = predecessor_node_outputs.index(_output)
+ir_graph.add_edge(head=(node_index[_output.node()], src_node_idx),
+tail=(ir_graph.output_node, None))
+else:
+# here is an example that the ir_graph and node_index is empty
# graph(%self : __torch__.torchmodels.googlenet.GoogLeNet,
# %x.1 : Tensor): return (%x.1)
-# add a noop_identity node to handle this situation
-self.global_seq += 1
-ni_node = ir_graph.add_node(build_full_name(module_name, 'noop_identity', self.global_seq), 'noop_identity')
-ir_graph.add_edge(head=(ir_graph.input_node, 0), tail=(ni_node, None))
-ir_graph.add_edge(head=(ni_node, None), tail=(ir_graph.output_node, None))
-for _output in sm_graph.outputs():
-node_index[_output.node()] = ni_node
+# add an edge from head to tail to handle this situation
+ir_graph.add_edge(head=(ir_graph.input_node, 0), tail=(ir_graph.output_node, None))
return node_index
def merge_aten_slices(self, ir_graph):
"""
......@@ -625,20 +622,8 @@ class GraphConverter:
ir_graph = Graph(model=ir_model, graph_id=self.global_graph_id, name=module_name, _internal=True)
# handle graph nodes
-node_index = self.handle_graph_nodes(script_module, sm_graph, module,
+self.handle_graph_nodes(script_module, sm_graph, module,
module_name, ir_model, ir_graph)
-# handle graph outputs
-for _output in sm_graph.outputs():
-ir_graph._add_output(_convert_name(_output.debugName()))
-predecessor_node_outputs = [o for o in _output.node().outputs()]
-if len(predecessor_node_outputs) == 1:
-src_node_idx = None
-else:
-src_node_idx = predecessor_node_outputs.index(_output)
-ir_graph.add_edge(head=(node_index[_output.node()], src_node_idx),
-tail=(ir_graph.output_node, None))
self.refine_graph(ir_graph)
ir_graph._register()
......@@ -690,7 +675,7 @@ class GraphConverterWithShape(GraphConverter):
Known issues
------------
1. `InputChoice` and `ValueChoice` not supported yet.
-2. Currently random inputs are feeded while tracing layerchoice.
+2. Currently random inputs are fed while tracing layerchoice.
If forward path of candidates depends on input data, then wrong path will be traced.
This will result in incomplete shape info.
"""
......
-from typing import Any, Union, Optional, List
-import torch
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+from typing import Any, List, Optional, Union
+import torch
from pytorch_lightning.accelerators.accelerator import Accelerator
+from pytorch_lightning.plugins.environments import ClusterEnvironment
from pytorch_lightning.plugins.training_type.training_type_plugin import TrainingTypePlugin
+from pytorch_lightning.trainer.connectors.accelerator_connector import AcceleratorConnector
from pytorch_lightning.trainer import Trainer
from pytorch_lightning.plugins import Plugin
-from pytorch_lightning.plugins.environments import ClusterEnvironment
-from pytorch_lightning.trainer.connectors.accelerator_connector import AcceleratorConnector
from ....serializer import serialize_cls
......@@ -69,9 +70,8 @@ class BypassPlugin(TrainingTypePlugin):
# bypass device placement from pytorch lightning
pass
-def setup(self, model: torch.nn.Module) -> torch.nn.Module:
-self.model_to_device()
-return self.model
+def setup(self) -> None:
+pass
@property
def is_global_zero(self) -> bool:
......@@ -100,8 +100,9 @@ def get_accelerator_connector(
deterministic: bool = False,
precision: int = 32,
amp_backend: str = 'native',
-amp_level: str = 'O2',
-plugins: Optional[Union[List[Union[Plugin, ClusterEnvironment, str]], Plugin, ClusterEnvironment, str]] = None,
+amp_level: Optional[str] = None,
+plugins: Optional[Union[List[Union[TrainingTypePlugin, ClusterEnvironment, str]],
+TrainingTypePlugin, ClusterEnvironment, str]] = None,
**other_trainier_kwargs) -> AcceleratorConnector:
gpu_ids = Trainer()._parse_devices(gpus, auto_select_gpus, tpu_cores)
return AcceleratorConnector(
......
......@@ -7,7 +7,7 @@ from typing import Dict, List, Optional, Union
import torch.nn as nn
import torch.optim as optim
-import pytorch_lightning as pl
+import torchmetrics
from torch.utils.data import DataLoader
import nni
......@@ -19,7 +19,7 @@ from ....serializer import serialize_cls
@serialize_cls
class _MultiModelSupervisedLearningModule(LightningModule):
-def __init__(self, criterion: nn.Module, metrics: Dict[str, pl.metrics.Metric],
+def __init__(self, criterion: nn.Module, metrics: Dict[str, torchmetrics.Metric],
n_models: int = 0,
learning_rate: float = 0.001,
weight_decay: float = 0.,
......@@ -119,7 +119,7 @@ class MultiModelSupervisedLearningModule(_MultiModelSupervisedLearningModule):
Class for optimizer (not an instance). default: ``Adam``
"""
-def __init__(self, criterion: nn.Module, metrics: Dict[str, pl.metrics.Metric],
+def __init__(self, criterion: nn.Module, metrics: Dict[str, torchmetrics.Metric],
learning_rate: float = 0.001,
weight_decay: float = 0.,
optimizer: optim.Optimizer = optim.Adam):
......@@ -180,7 +180,7 @@ class _RegressionModule(MultiModelSupervisedLearningModule):
learning_rate: float = 0.001,
weight_decay: float = 0.,
optimizer: optim.Optimizer = optim.Adam):
-super().__init__(criterion, {'mse': pl.metrics.MeanSquaredError},
+super().__init__(criterion, {'mse': torchmetrics.MeanSquaredError},
learning_rate=learning_rate, weight_decay=weight_decay, optimizer=optimizer)
......
......@@ -9,6 +9,7 @@ from typing import Dict, NoReturn, Union, Optional, List, Type
import pytorch_lightning as pl
import torch.nn as nn
import torch.optim as optim
+import torchmetrics
from torch.utils.data import DataLoader
import nni
......@@ -140,7 +141,7 @@ def _check_dataloader(dataloader):
### The following are some commonly used Lightning modules ###
class _SupervisedLearningModule(LightningModule):
-def __init__(self, criterion: nn.Module, metrics: Dict[str, pl.metrics.Metric],
+def __init__(self, criterion: nn.Module, metrics: Dict[str, torchmetrics.Metric],
learning_rate: float = 0.001,
weight_decay: float = 0.,
optimizer: optim.Optimizer = optim.Adam,
......@@ -213,7 +214,7 @@ class _SupervisedLearningModule(LightningModule):
return {name: self.trainer.callback_metrics['val_' + name].item() for name in self.metrics}
-class _AccuracyWithLogits(pl.metrics.Accuracy):
+class _AccuracyWithLogits(torchmetrics.Accuracy):
def update(self, pred, target):
return super().update(nn.functional.softmax(pred), target)
......@@ -278,7 +279,7 @@ class _RegressionModule(_SupervisedLearningModule):
weight_decay: float = 0.,
optimizer: optim.Optimizer = optim.Adam,
export_onnx: bool = True):
-super().__init__(criterion, {'mse': pl.metrics.MeanSquaredError},
+super().__init__(criterion, {'mse': torchmetrics.MeanSquaredError},
learning_rate=learning_rate, weight_decay=weight_decay, optimizer=optimizer,
export_onnx=export_onnx)
......
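All of these hunks swap pl.metrics for the standalone torchmetrics package: the pl.metrics namespace was deprecated in PyTorch Lightning 1.3 and removed in later releases, while torchmetrics exposes the same classes, so the migration is essentially a rename. A quick sketch of the replacement API (as of torchmetrics 0.x):

import torch
import torchmetrics

acc = torchmetrics.Accuracy()
acc.update(torch.tensor([0, 1, 1]), torch.tensor([0, 1, 0]))
print(acc.compute())   # tensor(0.6667)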
......@@ -219,7 +219,8 @@ class RetiariiExperiment(Experiment):
elif self.config.execution_engine == 'cgo':
from ..execution.cgo_engine import CGOExecutionEngine
# assert self.config.trial_gpu_number==1, "trial_gpu_number must be 1 to use CGOExecutionEngine"
assert self.config.training_service.platform == 'remote', \
"CGO execution engine currently only supports remote training service"
+assert self.config.batch_waiting_time is not None
devices = self._construct_devices()
engine = CGOExecutionEngine(devices,
......@@ -273,11 +274,10 @@ class RetiariiExperiment(Experiment):
devices = []
if hasattr(self.config.training_service, 'machine_list'):
for machine in self.config.training_service.machine_list:
+assert machine.gpu_indices is not None, \
+'gpu_indices must be set in RemoteMachineConfig for CGO execution engine'
for gpu_idx in machine.gpu_indices:
devices.append(GPUDevice(machine.host, gpu_idx))
else:
for gpu_idx in self.config.training_service.gpu_indices:
devices.append(GPUDevice('local', gpu_idx))
return devices
def _create_dispatcher(self):
......
......@@ -254,6 +254,13 @@ class AtenFloordiv(PyTorchOperation):
return f'{output} = {inputs[0]} // {inputs[1]}'
+class AtenMul(PyTorchOperation):
+_ori_type_name = ['aten::mul']
+def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
+return f'{output} = {inputs[0]} * {inputs[1]}'
class AtenLen(PyTorchOperation):
_ori_type_name = ['aten::len']
......@@ -491,7 +498,8 @@ class AtenAvgpool2d(PyTorchOperation):
class ToDevice(PyTorchOperation):
_artificial_op_name = "ToDevice"
-def __init__(self, type_name: str, parameters: Dict[str, Any], _internal: bool = False):
+def __init__(self, type_name: str, parameters: Dict[str, Any], _internal: bool = False,
+attributes: Dict[str, Any] = None):
self.type = "ToDevice"
self.device = parameters['device']
self.overridden_device_repr = None
......
......@@ -57,6 +57,8 @@ def parse_path(experiment_config, config_path):
expand_path(experiment_config['assessor'], 'codeDir')
if experiment_config.get('advisor'):
expand_path(experiment_config['advisor'], 'codeDir')
+if experiment_config['advisor'].get('classArgs') and experiment_config['advisor']['classArgs'].get('config_space'):
+expand_path(experiment_config['advisor']['classArgs'], 'config_space')
if experiment_config.get('machineList'):
for index in range(len(experiment_config['machineList'])):
expand_path(experiment_config['machineList'][index], 'sshKeyPath')
......@@ -95,8 +97,8 @@ def parse_path(experiment_config, config_path):
if experiment_config.get('advisor'):
parse_relative_path(root_path, experiment_config['advisor'], 'codeDir')
# for BOHB when delivering a ConfigSpace file directly
-if experiment_config.get('advisor').get('classArgs') and experiment_config.get('advisor').get('classArgs').get('config_space'):
-parse_relative_path(root_path, experiment_config.get('advisor').get('classArgs'), 'config_space')
+if experiment_config['advisor'].get('classArgs') and experiment_config['advisor']['classArgs'].get('config_space'):
+parse_relative_path(root_path, experiment_config['advisor']['classArgs'], 'config_space')
if experiment_config.get('machineList'):
for index in range(len(experiment_config['machineList'])):
......
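These two hunks make nnictl treat config_space under the advisor's classArgs like other path fields, expanding it and resolving it relative to the config file, which matters for BOHB when a ConfigSpace file is delivered directly. Illustratively, the config shape being handled (values are hypothetical):

experiment_config = {
    'advisor': {
        'builtinAdvisorName': 'BOHB',
        'classArgs': {'config_space': './search_space.pcs'},   # now expanded to an absolute path
    },
}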
......@@ -97,10 +97,10 @@ class QuantizationSpeedupTestCase(TestCase):
model = BackboneModel()
configure_list = {
-'conv1':{'weight_bit':8, 'activation_bit':8},
-'conv2':{'weight_bit':32, 'activation_bit':32},
-'fc1':{'weight_bit':16, 'activation_bit':16},
-'fc2':{'weight_bit':8, 'activation_bit':8}
+'conv1':{'weight_bits':8, 'output_bits':8},
+'conv2':{'weight_bits':32, 'output_bits':32},
+'fc1':{'weight_bits':16, 'output_bits':16},
+'fc2':{'weight_bits':8, 'output_bits':8}
}
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
......@@ -126,16 +126,16 @@ class QuantizationSpeedupTestCase(TestCase):
model = BackboneModel()
configure_list = [{
-'quant_types': ['weight', 'output'],
-'quant_bits': {'weight':8, 'output':8},
+'quant_types': ['input', 'weight'],
+'quant_bits': {'input':8, 'weight':8},
'op_names': ['conv1']
}, {
'quant_types': ['output'],
'quant_bits': {'output':8},
'op_names': ['relu1']
}, {
-'quant_types': ['weight', 'output'],
-'quant_bits': {'weight':8, 'output':8},
+'quant_types': ['input', 'weight'],
+'quant_bits': {'input':8, 'weight':8},
'op_names': ['conv2']
}, {
'quant_types': ['output'],
......@@ -145,8 +145,9 @@ class QuantizationSpeedupTestCase(TestCase):
]
# finetune the model by using QAT
+dummy_input = torch.randn(1, 1, 28, 28)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
-quantizer = QAT_Quantizer(model, configure_list, optimizer)
+quantizer = QAT_Quantizer(model, configure_list, optimizer, dummy_input)
quantizer.compress()
model.to(self.device)
......@@ -178,13 +179,13 @@ class QuantizationSpeedupTestCase(TestCase):
model = vgg16()
configure_list = {
-'features.0':{'weight_bit':8, 'activation_bit':8},
-'features.1':{'weight_bit':32, 'activation_bit':32},
-'features.2':{'weight_bit':16, 'activation_bit':16},
-'features.4':{'weight_bit':8, 'activation_bit':8},
-'features.7':{'weight_bit':8, 'activation_bit':8},
-'features.8':{'weight_bit':8, 'activation_bit':8},
-'features.11':{'weight_bit':8, 'activation_bit':8}
+'features.0':{'weight_bits':8, 'output_bits':8},
+'features.1':{'weight_bits':32, 'output_bits':32},
+'features.2':{'weight_bits':16, 'output_bits':16},
+'features.4':{'weight_bits':8, 'output_bits':8},
+'features.7':{'weight_bits':8, 'output_bits':8},
+'features.8':{'weight_bits':8, 'output_bits':8},
+'features.11':{'weight_bits':8, 'output_bits':8}
}
model.to(self.device)
......