Unverified commit 24fa4619, authored by QuanluZhang, committed by GitHub

Merge pull request #2081 from microsoft/v1.4

merge V1.4 back to master
parents aaaa2756 8ff039c2
@@ -25,7 +25,7 @@ The tool manages automated machine learning (AutoML) experiments, **dispatches a
* Researchers and data scientists who want to easily **implement and experiment with new AutoML algorithms**, be it a hyperparameter tuning algorithm, a neural architecture search algorithm, or a model compression algorithm.
* ML platform owners who want to **support AutoML in their platform**.
-### **NNI v1.3 has been released! &nbsp;<a href="#nni-released-reminder"><img width="48" src="docs/img/release_icon.png"></a>**
+### **NNI v1.4 has been released! &nbsp;<a href="#nni-released-reminder"><img width="48" src="docs/img/release_icon.png"></a>**
## **NNI capabilities in a glance**
NNI provides a command-line tool as well as a user-friendly WebUI for managing training experiments. With the extensible API, you can customize your own AutoML algorithms and training services. To make it easy for new users, NNI also provides a set of built-in state-of-the-art AutoML algorithms and out-of-the-box support for popular training platforms.
@@ -233,7 +233,7 @@ The following example is built on TensorFlow 1.x. Make sure **TensorFlow 1.x is
* Download the examples by cloning the source code.
```bash
-git clone -b v1.3 https://github.com/Microsoft/nni.git
+git clone -b v1.4 https://github.com/Microsoft/nni.git
```
* Run the MNIST example.
...
@@ -60,4 +60,4 @@ ProxylessNasMutator also implements the forward logic of the mutables (i.e., Lay
## Reproduce Results
-Ongoing...
+To reproduce the result, we first ran the search. We found that although the search runs for many epochs, the chosen architecture converges within the first several epochs; this is probably caused by the hyper-parameters or the implementation, and we are working on it. The test accuracy of the found architecture is top-1: 72.31, top-5: 90.26.
# ChangeLog
+## Release 1.4 - 2/19/2020
+### Major Features
+#### Neural Architecture Search
+* Support the [C-DARTS](https://github.com/microsoft/nni/blob/v1.4/docs/en_US/NAS/CDARTS.md) algorithm and add [an example](https://github.com/microsoft/nni/tree/v1.4/examples/nas/cdarts) that uses it
+* Support a preliminary version of [ProxylessNAS](https://github.com/microsoft/nni/blob/v1.4/docs/en_US/NAS/Proxylessnas.md) and the corresponding [example](https://github.com/microsoft/nni/tree/v1.4/examples/nas/proxylessnas)
+* Add unit tests for the NAS framework
+#### Model Compression
+* Support DataParallel for compressing models, and provide [an example](https://github.com/microsoft/nni/blob/v1.4/examples/model_compress/multi_gpu.py) of using DataParallel
+* Support [model speedup](https://github.com/microsoft/nni/blob/v1.4/docs/en_US/Compressor/ModelSpeedup.md) for compressed models, in alpha version (see the usage sketch after this changelog)
+#### Training Service
+* Support complete PAI configurations by allowing users to specify a PAI config file path
+* Add example config YAML files for the new PAI mode (i.e., paiK8S)
+* Support deleting experiments using sshkey in remote mode (thanks to external contributor @tyusr)
+#### WebUI
+* WebUI refactor: adopt the fabric framework
+#### Others
+* Support running an [NNI experiment in the foreground](https://github.com/microsoft/nni/blob/v1.4/docs/en_US/Tutorial/Nnictl.md#manage-an-experiment), i.e., the `--foreground` argument in `nnictl create/resume/view`
+* Support canceling trials in the UNKNOWN state
+* Support large search spaces whose size can be up to 50 MB (thanks to external contributor @Sundrops)
+### Documentation
+* Improve [the index structure](https://nni.readthedocs.io/en/latest/) of the NNI readthedocs site
+* Improve the [documentation for NAS](https://github.com/microsoft/nni/blob/v1.4/docs/en_US/NAS/NasGuide.md)
+* Improve documentation for [the new PAI mode](https://github.com/microsoft/nni/blob/v1.4/docs/en_US/TrainingService/PaiMode.md)
+* Add QuickStart guides for [NAS](https://github.com/microsoft/nni/blob/v1.4/docs/en_US/NAS/QuickStart.md) and [model compression](https://github.com/microsoft/nni/blob/v1.4/docs/en_US/Compressor/QuickStart.md)
+* Improve documentation for [the supported EfficientNet](https://github.com/microsoft/nni/blob/v1.4/docs/en_US/TrialExample/EfficientNet.md)
+### Bug Fixes
+* Correctly support NaN in metric data (JSON compliant)
+* Fix the out-of-range bug of the `randint` type in search spaces
+* Fix the bug of wrong tensor device when exporting an ONNX model during model compression
+* Fix incorrect handling of nnimanagerIP in the new PAI mode (i.e., paiK8S)
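To make the new model speedup feature concrete, here is a minimal usage sketch, assuming the v1.4 import path and the `ModelSpeedup(model, dummy_input, masks_file)` constructor described in the linked ModelSpeedup doc; treat the exact module path and file names as assumptions and check that doc for the authoritative API:

```python
import torch
import torch.nn as nn
# Assumed v1.4 import path -- verify against docs/en_US/Compressor/ModelSpeedup.md.
from nni.compression.speedup.torch import ModelSpeedup

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(1, 8, kernel_size=3)
        self.fc = nn.Linear(8 * 26 * 26, 10)
    def forward(self, x):
        x = torch.relu(self.conv(x))
        return self.fc(x.view(x.size(0), -1))

# In practice `model` is a model already pruned by an NNI pruner, with its
# masks exported to 'mask.pth' (hypothetical file name).
model = Net()
dummy_input = torch.randn(1, 1, 28, 28)
ModelSpeedup(model, dummy_input, 'mask.pth').speedup_model()
```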
## Release 1.3 - 12/30/2019
### Major Features
...
@@ -19,7 +19,7 @@ Installation on Linux and macOS follow the same instruction below.
Prerequisites: `python 64-bit >=3.5`, `git`, `wget`
```bash
-git clone -b v1.3 https://github.com/Microsoft/nni.git
+git clone -b v1.4 https://github.com/Microsoft/nni.git
cd nni
./install.sh
```
@@ -35,7 +35,7 @@ The following example is built on TensorFlow 1.x. Make sure **TensorFlow 1.x is
* Download the examples by cloning the source code.
```bash
-git clone -b v1.3 https://github.com/Microsoft/nni.git
+git clone -b v1.4 https://github.com/Microsoft/nni.git
```
* Run the MNIST example.
...
@@ -19,7 +19,7 @@ Anaconda or Miniconda is highly recommended to manage multiple Python environmen
Prerequisites: `python 64-bit >=3.5`, `git`, `PowerShell`.
```bash
-git clone -b v1.3 https://github.com/Microsoft/nni.git
+git clone -b v1.4 https://github.com/Microsoft/nni.git
cd nni
powershell -ExecutionPolicy Bypass -file install.ps1
```
@@ -31,7 +31,7 @@ The following example is built on TensorFlow 1.x. Make sure **TensorFlow 1.x is
* Download the examples by cloning the source code.
```bash
-git clone -b v1.3 https://github.com/Microsoft/nni.git
+git clone -b v1.4 https://github.com/Microsoft/nni.git
```
* Run the MNIST example.
...
@@ -28,7 +28,7 @@ author = 'Microsoft'
# The short X.Y version
version = ''
# The full version, including alpha/beta/rc tags
-release = 'v1.3'
+release = 'v1.4'
# -- General configuration ---------------------------------------------------
...
@@ -55,7 +55,7 @@ def test(model, device, test_loader):
def main():
    torch.manual_seed(0)
-    device = torch.device('cpu')
+    device = torch.device('cuda')
    trans = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
    train_loader = torch.utils.data.DataLoader(
@@ -66,7 +66,7 @@ def main():
        batch_size=1000, shuffle=True)
    model = Mnist()
-    model.to(device)
+    model = model.to(device)
    '''you can change this to LevelPruner to implement it
    pruner = LevelPruner(configure_list)
@@ -82,14 +82,14 @@ def main():
    pruner = AGP_Pruner(model, configure_list)
    model = pruner.compress()
+    model = model.to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
    for epoch in range(10):
        pruner.update_epoch(epoch)
        print('# Epoch {} #'.format(epoch))
        train(model, device, train_loader, optimizer)
        test(model, device, test_loader)
-    pruner.export_model('model.pth', 'mask.pth', 'model.onnx', [1, 1, 28, 28])
+    pruner.export_model('model.pth', 'mask.pth', 'model.onnx', [1, 1, 28, 28], device)
if __name__ == '__main__':
...
@@ -34,6 +34,7 @@ if __name__ == "__main__":
    # configurations for search
    parser.add_argument("--checkpoint_path", default='./search_mobile_net.pt', type=str)
    parser.add_argument("--arch_path", default='./arch_path.pt', type=str)
+    parser.add_argument("--no-warmup", dest='warmup', action='store_false')
    # configurations for retrain
    parser.add_argument("--exported_arch_path", default=None, type=str)
@@ -54,7 +55,7 @@ if __name__ == "__main__":
    # move network to GPU if available
    if torch.cuda.is_available():
-        device = torch.device('cuda:0')
+        device = torch.device('cuda')
    else:
        device = torch.device('cpu')
@@ -86,7 +87,7 @@ if __name__ == "__main__":
                          train_loader=data_provider.train,
                          valid_loader=data_provider.valid,
                          device=device,
-                          warmup=True,
+                          warmup=args.warmup,
                          ckpt_path=args.checkpoint_path,
                          arch_path=args.arch_path)
@@ -102,4 +103,4 @@ if __name__ == "__main__":
            "exported_arch_path {} should be a file.".format(args.exported_arch_path)
        apply_fixed_architecture(model, args.exported_arch_path, device=device)
        trainer = Retrain(model, optimizer, device, data_provider, n_epochs=300)
        trainer.run()
\ No newline at end of file
@@ -4,7 +4,6 @@
'use strict';
import * as assert from 'assert';
-import * as JSON5 from 'json5';
import { Deferred } from 'ts-deferred';
import * as component from '../common/component';
@@ -132,7 +131,7 @@ class NNIDataStore implements DataStore {
    }
    public async storeMetricData(trialJobId: string, data: string): Promise<void> {
-        const metrics: MetricData = JSON5.parse(data);
+        const metrics: MetricData = JSON.parse(data);
        // REQUEST_PARAMETER is used to request new parameters for multiphase trial job,
        // it is not metrics, so it is skipped here.
        if (metrics.type === 'REQUEST_PARAMETER') {
@@ -141,7 +140,7 @@ class NNIDataStore implements DataStore {
        }
        assert(trialJobId === metrics.trial_job_id);
        try {
-            await this.db.storeMetricData(trialJobId, JSON5.stringify({
+            await this.db.storeMetricData(trialJobId, JSON.stringify({
                trialJobId: metrics.trial_job_id,
                parameterId: metrics.parameter_id,
                type: metrics.type,
...
@@ -5,7 +5,6 @@
import * as assert from 'assert';
import * as fs from 'fs';
-import * as JSON5 from 'json5';
import * as path from 'path';
import * as sqlite3 from 'sqlite3';
import { Deferred } from 'ts-deferred';
@@ -203,10 +202,10 @@ class SqlDB implements Database {
    public storeMetricData(trialJobId: string, data: string): Promise<void> {
        const sql: string = 'insert into MetricData values (?,?,?,?,?,?)';
-        const json: MetricDataRecord = JSON5.parse(data);
-        const args: any[] = [Date.now(), json.trialJobId, json.parameterId, json.type, json.sequence, JSON5.stringify(json.data)];
-        this.log.trace(`storeMetricData: SQL: ${sql}, args: ${JSON5.stringify(args)}`);
+        const json: MetricDataRecord = JSON.parse(data);
+        const args: any[] = [Date.now(), json.trialJobId, json.parameterId, json.type, json.sequence, JSON.stringify(json.data)];
+        this.log.trace(`storeMetricData: SQL: ${sql}, args: ${JSON.stringify(args)}`);
        const deferred: Deferred<void> = new Deferred<void>();
        this.db.run(sql, args, (err: Error | null) => { this.resolve(deferred, err); });
...
@@ -17,7 +17,6 @@
    "express": "^4.16.3",
    "express-joi-validator": "^2.0.0",
    "js-base64": "^2.4.9",
-    "json5": "^2.1.1",
    "kubernetes-client": "^6.5.0",
    "rx": "^4.1.0",
    "sqlite3": "^4.0.2",
@@ -36,7 +35,6 @@
    "@types/express": "^4.16.0",
    "@types/glob": "^7.1.1",
    "@types/js-base64": "^2.3.1",
-    "@types/json5": "^0.0.30",
    "@types/mocha": "^5.2.5",
    "@types/node": "10.12.18",
    "@types/request": "^2.47.1",
...
@@ -34,7 +34,7 @@ export class NNIRestServer extends RestServer {
     */
    protected registerRestHandler(): void {
        this.app.use(express.static('static'));
-        this.app.use(bodyParser.json());
+        this.app.use(bodyParser.json({limit: '50mb'}));
        this.app.use(this.API_ROOT_URL, createRestHandler(this));
        this.app.use(this.LOGS_ROOT_URL, express.static(getLogDir()));
        this.app.get('*', (req: express.Request, res: express.Response) => {
...
@@ -157,10 +157,6 @@
  version "7.0.3"
  resolved "https://registry.yarnpkg.com/@types/json-schema/-/json-schema-7.0.3.tgz#bdfd69d61e464dcc81b25159c270d75a73c1a636"
-"@types/json5@^0.0.30":
-  version "0.0.30"
-  resolved "https://registry.yarnpkg.com/@types/json5/-/json5-0.0.30.tgz#44cb52f32a809734ca562e685c6473b5754a7818"
"@types/mime@*":
  version "2.0.0"
  resolved "https://registry.yarnpkg.com/@types/mime/-/mime-2.0.0.tgz#5a7306e367c539b9f6543499de8dd519fac37a8b"
@@ -2380,12 +2376,6 @@ json-stringify-safe@~5.0.1:
  version "5.0.1"
  resolved "https://registry.yarnpkg.com/json-stringify-safe/-/json-stringify-safe-5.0.1.tgz#1296a2d58fd45f19a0f6ce01d65701e2c735b6eb"
-json5@^2.1.1:
-  version "2.1.1"
-  resolved "https://registry.yarnpkg.com/json5/-/json5-2.1.1.tgz#81b6cb04e9ba496f1c7005d07b4368a2638f90b6"
-  dependencies:
-    minimist "^1.2.0"
jsonparse@^1.2.0:
  version "1.3.1"
  resolved "https://registry.yarnpkg.com/jsonparse/-/jsonparse-1.3.1.tgz#3f4dae4a91fac315f71062f8521cc239f1366280"
...
@@ -557,7 +557,8 @@ class BOHB(MsgDispatcherBase):
            Data type not supported
        """
        logger.debug('handle report metric data = %s', data)
+        if 'value' in data:
+            data['value'] = json_tricks.loads(data['value'])
        if data['type'] == MetricType.REQUEST_PARAMETER:
            assert multi_phase_enabled()
            assert data['trial_job_id'] is not None
@@ -627,6 +628,8 @@ class BOHB(MsgDispatcherBase):
        AssertionError
            data doesn't have the required keys 'parameter' and 'value'
        """
+        for entry in data:
+            entry['value'] = json_tricks.loads(entry['value'])
        _completed_num = 0
        for trial_info in data:
            logger.info("Importing data, current processing progress %s / %s", _completed_num, len(data))
...
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
+import logging
import torch
-from .infer_shape import CoarseMask, ModuleMasks
+from .infer_shape import ModuleMasks
+_logger = logging.getLogger(__name__)
replace_module = {
    'BatchNorm2d': lambda module, mask: replace_batchnorm2d(module, mask),
    'Conv2d': lambda module, mask: replace_conv2d(module, mask),
    'MaxPool2d': lambda module, mask: no_replace(module, mask),
+    'AvgPool2d': lambda module, mask: no_replace(module, mask),
    'ReLU': lambda module, mask: no_replace(module, mask),
    'Linear': lambda module, mask: replace_linear(module, mask)
}
@@ -16,6 +20,7 @@ def no_replace(module, mask):
    """
    No need to replace
    """
+    _logger.debug("no need to replace")
    return module
def replace_linear(linear, mask):
@@ -37,9 +42,8 @@ def replace_linear(linear, mask):
    assert mask.output_mask is None
    assert not mask.param_masks
    index = mask.input_mask.mask_index[-1]
-    print(mask.input_mask.mask_index)
    in_features = index.size()[0]
-    print('linear: ', in_features)
+    _logger.debug("replace linear with new in_features: %d", in_features)
    new_linear = torch.nn.Linear(in_features=in_features,
                                 out_features=linear.out_features,
                                 bias=linear.bias is not None)
@@ -67,7 +71,7 @@ def replace_batchnorm2d(norm, mask):
    assert 'weight' in mask.param_masks and 'bias' in mask.param_masks
    index = mask.param_masks['weight'].mask_index[0]
    num_features = index.size()[0]
-    print("replace batchnorm2d: ", num_features, index)
+    _logger.debug("replace batchnorm2d with num_features: %d", num_features)
    new_norm = torch.nn.BatchNorm2d(num_features=num_features,
                                    eps=norm.eps,
                                    momentum=norm.momentum,
@@ -106,6 +110,7 @@ def replace_conv2d(conv, mask):
    else:
        out_channels_index = mask.output_mask.mask_index[1]
        out_channels = out_channels_index.size()[0]
+    _logger.debug("replace conv2d with in_channels: %d, out_channels: %d", in_channels, out_channels)
    new_conv = torch.nn.Conv2d(in_channels=in_channels,
                               out_channels=out_channels,
                               kernel_size=conv.kernel_size,
@@ -128,6 +133,5 @@ def replace_conv2d(conv, mask):
    assert tmp_weight_data is not None, "Conv2d weight should be updated based on masks"
    new_conv.weight.data.copy_(tmp_weight_data)
    if conv.bias is not None:
-        print('final conv.bias is not None')
        new_conv.bias.data.copy_(conv.bias.data if tmp_bias_data is None else tmp_bias_data)
    return new_conv
@@ -158,7 +158,7 @@ class ModelSpeedup:
        """
        # TODO: scope name could be empty
        node_name = '.'.join([node.scopeName(), node.kind(), str(self.global_count)])
-        #print('node_name: ', node_name)
+        _logger.debug("expand non-prim node, node name: %s", node_name)
        self.global_count += 1
        op_type = node.kind()
@@ -173,7 +173,6 @@ class ModelSpeedup:
            input_name = _input.debugName()
            if input_name in output_to_node and output_to_node[input_name] in nodes:
                predecessor_node = output_to_node[input_name]
-                #print("predecessor_node: ", predecessor_node)
                if predecessor_node.kind().startswith('prim::'):
                    node_group.append(predecessor_node)
                    node_queue.put(predecessor_node)
@@ -211,6 +210,60 @@ class ModelSpeedup:
        out_shape = t_output.type().sizes()
        return {'in_shape': in_shape, 'out_shape': out_shape}
+    def _extract_leaf_modules(self, graph):
+        """
+        Extract leaf modules from the given graph. A leaf module is one that does not have submodules.
+        We extract leaf modules because only leaf modules can be replaced, and shape inference can be
+        done at the leaf-module level; other shape inference is done at a lower level, i.e., the
+        operation level.
+
+        Parameters
+        ----------
+        graph : jit trace graph
+            the graph generated from jit trace
+
+        Returns
+        -------
+        list
+            a list of scope names of all the leaf modules
+        """
+        pieces = []  # each element is a dict
+        for node in graph.nodes():
+            scope_name = node.scopeName()
+            if scope_name == '':
+                continue
+            segs = scope_name.split('/')
+            segs_len = len(segs)
+            # increase the length of `pieces` if it is not long enough
+            for _ in range(segs_len - len(pieces)):
+                pieces.append({})
+            # process internal segments of the scope name:
+            # 'L' means leaf segment, 'I' means internal segment;
+            # an internal segment can replace a leaf segment at the same position of `pieces`
+            for i, seg in enumerate(segs[:-1]):
+                seg_name_dict = pieces[i]
+                if seg in seg_name_dict:
+                    if seg_name_dict[seg][0] == 'L':
+                        seg_name_dict[seg] = ('I', node)
+                else:
+                    seg_name_dict[seg] = ('I', node)
+            # process the leaf segment of the scope name
+            last_segs_dict = pieces[len(segs) - 1]
+            if not segs[-1] in last_segs_dict:
+                last_segs_dict[segs[-1]] = ('L', node)
+        # traverse `pieces` to obtain all the leaf modules, which are labeled with 'L'
+        leaf_modules = []
+        for piece in pieces:
+            for _, value in piece.items():
+                if value[0] == 'L':
+                    assert value[1].scopeName() not in leaf_modules
+                    # if this is a leaf module, the last segment of its scope name
+                    # must match the pattern `xxx[xxx]`
+                    if value[1].scopeName()[-1] == ']':
+                        leaf_modules.append(value[1].scopeName())
+        return leaf_modules
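For intuition about the scope names this method walks: a jit-trace scope encodes the module path, and only segments matching the `Type[name]` pattern correspond to real submodules. A toy, standalone illustration with hypothetical scope strings, mirroring the `re.findall` logic used in `_build_graph` below:

```python
import re

# Hypothetical scope names in the form produced by torch.jit tracing.
scopes = ['LeNet/Conv2d[conv1]', 'LeNet/Sequential[features]/ReLU[relu]', 'LeNet']

for s in scopes:
    module_name = '.'.join(re.findall(r'\[(.*?)\]', s))
    looks_like_module = s.endswith(']')  # leaf scope names end with `xxx[xxx]`
    print(s, '->', module_name or '<not a module>', '| module-like:', looks_like_module)
```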
    def _build_graph(self):
        """
        Build graph using our defined format from jit trace.
@@ -231,7 +284,7 @@ class ModelSpeedup:
        """
        graph = self.trace_graph.graph
        # if torch 1.4.0 is used, consider running torch._C._jit_pass_inline(graph) here
-        #print(graph)
+        _logger.debug(graph)
        # build output mapping, from output debugName to its node
        output_to_node = dict()
        # build input mapping, from input debugName to its node
@@ -250,6 +303,9 @@ class ModelSpeedup:
        for output in graph.outputs():
            graph_outputs.append(output.debugName())
+        leaf_modules = self._extract_leaf_modules(graph)
+        _logger.debug(leaf_modules)
        for node in graph.nodes():
            # populate output_to_node and input_to_node
            for output in node.outputs():
@@ -259,10 +315,8 @@ class ModelSpeedup:
                input_name = _input.debugName()
                input_to_node[input_name] = node
            scope_name = node.scopeName()  # example: scope_name, 'MyCell/Linear[linear]'
-            module_name_slices = re.findall(r'\[(.*?)\]', scope_name)
-            module_name = '.'.join(module_name_slices)
            # if module_name is empty, it is not a module
-            if module_name == '':
+            if not scope_name in leaf_modules:
                if scope_name == '':
                    continue
                else:
@@ -271,6 +325,8 @@ class ModelSpeedup:
                    else:
                        func_to_nodes[scope_name] = [node]
            else:
+                module_name_slices = re.findall(r'\[(.*?)\]', scope_name)
+                module_name = '.'.join(module_name_slices)
                scope_slice = scope_name.split('/')[-1]
                module_type = scope_slice.split('[')[0]
                module_to_type[module_name] = module_type
@@ -301,10 +357,8 @@ class ModelSpeedup:
                        m_inputs.append(_input)
                    elif not output_to_node[_input] in nodes:
                        m_inputs.append(_input)
-                print("module node_name: ", module_name)
                if module_name == '':
-                    for n in nodes:
-                        print(n)
+                    _logger.warning("module_name is empty string")
                g_node = GNode(module_name, 'module', module_to_type[module_name], m_inputs, m_outputs, nodes)
                self.g_nodes.append(g_node)
@@ -345,10 +399,7 @@ class ModelSpeedup:
        predecessors = []
        for _input in self.name_to_gnode[module_name].inputs:
            if not _input in self.output_to_gnode:
-                print(_input)
-            if not _input in self.output_to_gnode:
-                # TODO: check _input which does not have node
-                print("output with no gnode: ", _input)
+                _logger.debug("cannot find gnode with %s as its output", _input)
            else:
                g_node = self.output_to_gnode[_input]
                predecessors.append(g_node.name)
@@ -379,7 +430,7 @@ class ModelSpeedup:
    def infer_module_mask(self, module_name, mask=None, in_shape=None, out_shape=None):
        """
        Infer input shape / output shape based on the module's weight mask / input shape / output shape.
        For a module:
            Infer its input and output shape from its weight mask
            Infer its output shape from its input shape
@@ -407,18 +458,20 @@ class ModelSpeedup:
        self.inferred_masks[module_name] = module_masks
        m_type = self.name_to_gnode[module_name].op_type
-        print("infer_module_mask: {}, module type: {}".format(module_name, m_type))
+        _logger.debug("infer mask of module %s with op_type %s", module_name, m_type)
        if mask is not None:
-            #print("mask is not None")
+            _logger.debug("mask is not None")
            if not m_type in infer_from_mask:
-                raise RuntimeError("Has not supported infering \
-                    input/output shape from mask for module/function: `{}`".format(m_type))
+                raise RuntimeError(
+                    "Has not supported infering input/output shape from mask for module/function: `{}`, {}"
+                    .format(m_type, module_name))
            input_cmask, output_cmask = infer_from_mask[m_type](module_masks, mask)
        if in_shape is not None:
-            #print("in_shape is not None")
+            _logger.debug("in_shape is not None")
            if not m_type in infer_from_inshape:
-                raise RuntimeError("Has not supported infering \
-                    output shape from input shape for module/function: `{}`".format(m_type))
+                raise RuntimeError(
+                    "Has not supported infering output shape from input shape for module/function: `{}`, {}"
+                    .format(m_type, module_name))
            if m_type == 'aten::view':
                output_cmask = infer_from_inshape[m_type](module_masks,
                                                          in_shape,
@@ -426,23 +479,20 @@ class ModelSpeedup:
            else:
                output_cmask = infer_from_inshape[m_type](module_masks, in_shape)
        if out_shape is not None:
-            #print("out_shape is not None")
+            _logger.debug("out_shape is not None")
            if not m_type in infer_from_outshape:
-                raise RuntimeError("Has not supported infering \
-                    input shape from output shape for module/function: `{}`".format(m_type))
+                raise RuntimeError(
+                    "Has not supported infering input shape from output shape for module/function: `{}`, {}"
+                    .format(m_type, module_name))
            input_cmask = infer_from_outshape[m_type](module_masks, out_shape)
        if input_cmask:
-            #print("input_cmask is not None")
            predecessors = self._find_predecessors(module_name)
            for _module_name in predecessors:
-                print("input_cmask, module_name: ", _module_name)
                self.infer_module_mask(_module_name, out_shape=input_cmask)
        if output_cmask:
-            #print("output_cmask is not None")
            successors = self._find_successors(module_name)
            for _module_name in successors:
-                print("output_cmask, module_name: ", _module_name)
                self.infer_module_mask(_module_name, in_shape=output_cmask)
    def infer_modules_masks(self):
@@ -463,16 +513,19 @@ class ModelSpeedup:
        """
        for module_name in self.inferred_masks:
            g_node = self.name_to_gnode[module_name]
-            print(module_name, g_node.op_type)
+            _logger.debug("replace %s, in %s type, with op_type %s",
+                          module_name, g_node.type, g_node.op_type)
            if g_node.type == 'module':
                super_module, leaf_module = get_module_by_name(self.bound_model, module_name)
                m_type = g_node.op_type
                if not m_type in replace_module:
                    raise RuntimeError("Has not supported replacing the module: `{}`".format(m_type))
+                _logger.info("replace module (name: %s, op_type: %s)", module_name, m_type)
                compressed_module = replace_module[m_type](leaf_module, self.inferred_masks[module_name])
                setattr(super_module, module_name.split('.')[-1], compressed_module)
            elif g_node.type == 'func':
-                print("Warning: Cannot replace func...")
+                _logger.info("Warning: cannot replace (name: %s, op_type: %s) which is func type",
+                             module_name, g_node.op_type)
            else:
                raise RuntimeError("Unsupported GNode type: {}".format(g_node.type))
@@ -482,10 +535,12 @@ class ModelSpeedup:
        first, do mask/shape inference,
        second, replace modules
        """
-        #print("start to compress")
+        _logger.info("start to speed up the model")
+        _logger.info("infer module masks...")
        self.infer_modules_masks()
+        _logger.info("replace compressed modules...")
        self.replace_compressed_modules()
-        #print("finished compressing")
+        _logger.info("speedup done")
        # resume the model mode to that before the model is sped up
        if self.is_training:
            self.bound_model.train()
...
@@ -56,7 +56,7 @@ class CoarseMask:
            s.add(num)
        for num in index_b:
            s.add(num)
-        return torch.tensor(sorted(s))
+        return torch.tensor(sorted(s))  # pylint: disable=not-callable

    def merge(self, cmask):
        """
@@ -98,7 +98,7 @@ class ModuleMasks:
        self.param_masks = dict()
        self.input_mask = None
        self.output_mask = None

    def set_param_masks(self, name, mask):
        """
        Parameters
@@ -217,7 +217,7 @@ def view_inshape(module_masks, mask, shape):
    TODO: consider replacing tensor.view with nn.Flatten, because tensor.view is not
    included in module and thus cannot be replaced by our framework.

    Parameters
    ----------
    module_masks : ModuleMasks
@@ -250,7 +250,7 @@ def view_inshape(module_masks, mask, shape):
    step_size = shape['in_shape'][2] * shape['in_shape'][3]
    for loc in mask.mask_index[1]:
        index.extend([loc * step_size + i for i in range(step_size)])
-    output_cmask.add_index_mask(dim=1, index=torch.tensor(index))
+    output_cmask.add_index_mask(dim=1, index=torch.tensor(index))  # pylint: disable=not-callable
    module_masks.set_output_mask(output_cmask)
    return output_cmask
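The index expansion in `view_inshape` above maps a channel mask on an (N, C, H, W) input to element positions after `view(N, -1)`: each kept channel `loc` keeps the `H * W` consecutive flattened positions starting at `loc * H * W`. A small worked check with made-up shapes:

```python
in_shape = (1, 4, 2, 2)   # N, C, H, W before .view(N, -1)
kept_channels = [1, 3]    # channel indices that survive pruning

step_size = in_shape[2] * in_shape[3]  # H * W = 4
index = []
for loc in kept_channels:
    index.extend([loc * step_size + i for i in range(step_size)])
print(index)  # [4, 5, 6, 7, 12, 13, 14, 15]: kept positions in the flattened tensor
```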
@@ -373,7 +373,6 @@ def conv2d_mask(module_masks, mask):
    """
    assert 'weight' in mask
    assert isinstance(mask['weight'], torch.Tensor)
-    cmask = None
    weight_mask = mask['weight']
    shape = weight_mask.size()
    ones = torch.ones(shape[1:]).to(weight_mask.device)
@@ -451,7 +450,7 @@ def conv2d_outshape(module_masks, mask):
        The ModuleMasks instance of the conv2d
    mask : CoarseMask
        The mask of its output tensor

    Returns
    -------
    CoarseMask
...
@@ -149,13 +149,24 @@ class Compressor:
        ret = None
        for config in self.config_list:
            config = config.copy()
-            config['op_types'] = self._expand_config_op_types(config)
-            if layer.type not in config['op_types']:
+            # expand config if the key `default` is in config['op_types']
+            if 'op_types' in config and 'default' in config['op_types']:
+                expanded_op_types = []
+                for op_type in config['op_types']:
+                    if op_type == 'default':
+                        expanded_op_types.extend(default_layers.weighted_modules)
+                    else:
+                        expanded_op_types.append(op_type)
+                config['op_types'] = expanded_op_types
+            # check if the condition is satisfied
+            if 'op_types' in config and layer.type not in config['op_types']:
                continue
-            if config.get('op_names') and layer.name not in config['op_names']:
+            if 'op_names' in config and layer.name not in config['op_names']:
                continue
            ret = config
-        if ret is None or ret.get('exclude'):
+        if ret is None or 'exclude' in ret:
            return None
        return ret
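For example, under the matching rules above, a config list like the following (a hypothetical sketch; the last matching entry wins, and an entry with `exclude` drops the layer) prunes every weighted layer at 50% sparsity except a layer named `fc`:

```python
# Hypothetical config list illustrating select_config's matching logic:
# 'default' expands to all weighted module types from default_layers,
# the last matching entry wins, and an entry with 'exclude' returns None.
config_list = [
    {'sparsity': 0.5, 'op_types': ['default']},
    {'op_names': ['fc'], 'exclude': True},
]
```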
@@ -188,16 +199,6 @@ class Compressor:
        """
        raise NotImplementedError()

-    def _expand_config_op_types(self, config):
-        if config is None:
-            return []
-        expanded_op_types = []
-        for op_type in config.get('op_types', []):
-            if op_type == 'default':
-                expanded_op_types.extend(default_layers.weighted_modules)
-            else:
-                expanded_op_types.append(op_type)
-        return expanded_op_types
class PrunerModuleWrapper(torch.nn.Module):
    def __init__(self, module, module_name, module_type, config, pruner):
@@ -225,23 +226,29 @@ class PrunerModuleWrapper(torch.nn.Module):
        # config and pruner
        self.config = config
        self.pruner = pruner
-        self.registered_buffers = {}
+        self.registered_buffers = []
        # register buffer for mask
        self.register_buffer("weight_mask", torch.ones(self.module.weight.shape))
-        self.registered_buffers['weight_mask'] = self.weight_mask
        if hasattr(self.module, 'bias') and self.module.bias is not None:
            self.register_buffer("bias_mask", torch.ones(self.module.bias.shape))
        else:
            self.register_buffer("bias_mask", None)
-        self.registered_buffers['bias_mask'] = self.bias_mask
+        self.registered_buffers.append('weight_mask')
+        self.registered_buffers.append('bias_mask')
        # register user specified buffer
        for name in self.pruner.buffers:
            self.register_buffer(name, self.pruner.buffers[name].clone())
-            self.registered_buffers[name] = getattr(self, name)
+            self.registered_buffers.append(name)
+
+    def get_registered_buffers(self):
+        buffers = {}
+        for name in self.registered_buffers:
+            buffers[name] = getattr(self, name)
+        return buffers

    def forward(self, *inputs):
-        mask = self.pruner.calc_mask(LayerInfo(self.name, self.module), self.config, **self.registered_buffers)
+        mask = self.pruner.calc_mask(LayerInfo(self.name, self.module), self.config, **self.get_registered_buffers())
        if mask is not None:
            self.weight_mask.copy_(mask['weight'])
        # apply mask to weight
@@ -297,7 +304,8 @@ class Pruner(Compressor):
        """
        _logger.info("compressing module %s.", layer.name)
        wrapper = PrunerModuleWrapper(layer.module, layer.name, layer.type, config, self)
-        assert hasattr(layer.module, 'weight')
+        assert hasattr(layer.module, 'weight'), "module %s does not have 'weight' attribute" % layer.name
+        # move newly registered buffers to the same device as the weight
        wrapper.to(layer.module.weight.device)
        return wrapper
@@ -396,6 +404,7 @@ class QuantizerModuleWrapper(torch.nn.Module):
        # config and pruner
        self.config = config
        self.quantizer = quantizer
+        self.registered_buffers = []
        # register buffer and parameter
        # old_weight is used to store the original weight and weight is used to store the quantized weight
@@ -410,10 +419,15 @@ class QuantizerModuleWrapper(torch.nn.Module):
            self.module.register_buffer('weight', self.module.old_weight)
        # register user specified buffer
-        self.registered_buffers = {}
        for name in self.quantizer.buffers:
            self.register_buffer(name, self.quantizer.buffers[name].clone())
-            self.registered_buffers[name] = getattr(self, name)
+            self.registered_buffers.append(name)
+
+    def get_registered_buffers(self):
+        buffers = {}
+        for name in self.registered_buffers:
+            buffers[name] = getattr(self, name)
+        return buffers

    def forward(self, *inputs):
        if 'input' in self.config['quant_types']:
@@ -423,7 +437,7 @@ class QuantizerModuleWrapper(torch.nn.Module):
                self.quantizer.quantize_input,
                self.config,
                LayerInfo(self.name, self.module),
-                **self.registered_buffers)
+                **self.get_registered_buffers())
        if 'weight' in self.config['quant_types'] and _check_weight(self.module):
            new_weight = self.quantizer.quant_grad.apply(
@@ -432,7 +446,7 @@ class QuantizerModuleWrapper(torch.nn.Module):
                self.quantizer.quantize_weight,
                self.config,
                LayerInfo(self.name, self.module),
-                **self.registered_buffers)
+                **self.get_registered_buffers())
            self.module.weight = new_weight
            result = self.module(*inputs)
        else:
@@ -445,7 +459,7 @@ class QuantizerModuleWrapper(torch.nn.Module):
                self.quantizer.quantize_output,
                self.config,
                LayerInfo(self.name, self.module),
-                **self.registered_buffers)
+                **self.get_registered_buffers())
        return result

class Quantizer(Compressor):
...
@@ -170,7 +170,7 @@ class AGP_Pruner(Pruner):
        if epoch > 0:
            self.now_epoch = epoch
            for wrapper in self.get_modules_wrapper():
-                wrapper.registered_buffers['if_calculated'].copy_(torch.tensor(0))  # pylint: disable=not-callable
+                wrapper.if_calculated.copy_(torch.tensor(0))  # pylint: disable=not-callable

class SlimPruner(Pruner):
    """
...