Unverified commit aa316742 authored by SparkSnail, committed by GitHub

Merge pull request #233 from microsoft/master

merge master
parents 3fe117f0 24fa4619
@@ -25,7 +25,10 @@ trial:
   memoryMB: 8196
   image: msranni/nni:latest
   virtualCluster: nni
+  nniManagerNFSMountPath: /home/user/mnt
+  containerNFSMountPath: /mnt/data/user
+  paiStoragePlugin: team_wise
 paiConfig:
   userName: your_account
-  passWord: your_pwd
+  token: your_token
   host: 0.0.0.0
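The same schema change repeats in each example config below: `passWord` is replaced by token-based authentication, and three storage fields (`nniManagerNFSMountPath`, `containerNFSMountPath`, `paiStoragePlugin`) are added to the `trial` section. A minimal sketch of checking a parsed config for the new keys (the helper itself is hypothetical; only the field names come from this diff):

```python
# Hypothetical helper: verify a parsed experiment config carries the
# pai fields introduced by this commit.
REQUIRED_TRIAL_KEYS = {'nniManagerNFSMountPath', 'containerNFSMountPath', 'paiStoragePlugin'}
REQUIRED_PAI_KEYS = {'userName', 'token', 'host'}  # passWord is replaced by token

def check_pai_config(config: dict) -> None:
    missing = REQUIRED_TRIAL_KEYS - set(config.get('trial', {}))
    missing |= REQUIRED_PAI_KEYS - set(config.get('paiConfig', {}))
    if missing:
        raise ValueError('missing pai config fields: {}'.format(sorted(missing)))
```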
@@ -30,10 +30,13 @@ trial:
   memoryMB: 8196
   #The docker image to run nni job on pai
   image: msranni/nni:latest
+  nniManagerNFSMountPath: /home/user/mnt
+  containerNFSMountPath: /mnt/data/user
+  paiStoragePlugin: team_wise
 paiConfig:
   #The username to login pai
   userName: username
-  #The password to login pai
-  passWord: password
+  #The token to login pai
+  token: token
   #The host of restful server of pai
   host: 10.10.10.10
\ No newline at end of file
authorName: default
experimentName: example_FashionMNIST-network-morphism
trialConcurrency: 1
maxExecDuration: 24h
maxTrialNum: 10
#choice: local, remote, pai, paiYarn
trainingServicePlatform: paiYarn
#choice: true, false
useAnnotation: false
tuner:
#choice: TPE, Random, Anneal, Evolution, BatchTuner, NetworkMorphism
#SMAC (SMAC should be installed through nnictl)
builtinTunerName: NetworkMorphism
classArgs:
#choice: maximize, minimize
optimize_mode: maximize
# for now, this tuner only supports cv domain
task: cv
#input image width
input_width: 28
#input image channel
input_channel: 1
#number of classes
n_output_node: 10
trial:
command: python3 FashionMNIST_keras.py
codeDir: .
gpuNum: 1
cpuNum: 1
memoryMB: 8196
#The docker image to run nni job on pai
image: msranni/nni:latest
paiYarnConfig:
#The username to login pai
userName: username
#The password to login pai
passWord: password
#The host of restful server of pai
host: 10.10.10.10
\ No newline at end of file
@@ -30,10 +30,13 @@ trial:
   memoryMB: 8196
   #The docker image to run nni job on pai
   image: msranni/nni:latest
+  nniManagerNFSMountPath: /home/user/mnt
+  containerNFSMountPath: /mnt/data/user
+  paiStoragePlugin: team_wise
 paiConfig:
   #The username to login pai
   userName: username
-  #The password to login pai
-  passWord: password
+  #The token to login pai
+  token: token
   #The host of restful server of pai
   host: 10.10.10.10
\ No newline at end of file
authorName: default
experimentName: example_cifar10-network-morphism
trialConcurrency: 1
maxExecDuration: 24h
maxTrialNum: 10
#choice: local, remote, pai, paiYarn
trainingServicePlatform: paiYarn
#choice: true, false
useAnnotation: false
tuner:
#choice: TPE, Random, Anneal, Evolution, BatchTuner, NetworkMorphism
#SMAC (SMAC should be installed through nnictl)
builtinTunerName: NetworkMorphism
classArgs:
#choice: maximize, minimize
optimize_mode: maximize
# for now, this tuner only supports cv domain
task: cv
#input image width
input_width: 32
#input image channel
input_channel: 3
#number of classes
n_output_node: 10
trial:
command: python3 cifar10_keras.py
codeDir: .
gpuNum: 1
cpuNum: 1
memoryMB: 8196
#The docker image to run nni job on pai
image: msranni/nni:latest
paiYarnConfig:
#The username to login pai
userName: username
#The password to login pai
passWord: password
#The host of restful server of pai
host: 10.10.10.10
\ No newline at end of file
@@ -23,10 +23,13 @@ trial:
   memoryMB: 8196
   #The docker image to run nni job on pai
   image: msranni/nni:latest
+  nniManagerNFSMountPath: /home/user/mnt
+  containerNFSMountPath: /mnt/data/user
+  paiStoragePlugin: team_wise
 paiConfig:
   #The username to login pai
   userName: username
-  #The password to login pai
-  passWord: password
+  #The token to login pai
+  token: token
   #The host of restful server of pai
   host: 10.10.10.10
\ No newline at end of file
authorName: default
experimentName: example_sklearn
trialConcurrency: 1
maxExecDuration: 1h
maxTrialNum: 100
#choice: local, remote, pai, paiYarn
trainingServicePlatform: paiYarn
searchSpacePath: search_space.json
#choice: true, false
useAnnotation: false
tuner:
#choice: TPE, Random, Anneal, Evolution, BatchTuner, MetisTuner
#SMAC (SMAC should be installed through nnictl)
builtinTunerName: TPE
classArgs:
#choice: maximize, minimize
optimize_mode: maximize
trial:
command: python3 main.py
codeDir: .
gpuNum: 0
cpuNum: 1
memoryMB: 8196
#The docker image to run nni job on pai
image: msranni/nni:latest
paiYarnConfig:
#The username to login pai
userName: username
#The password to login pai
passWord: password
#The host of restful server of pai
host: 10.10.10.10
\ No newline at end of file
@@ -23,10 +23,13 @@ trial:
   memoryMB: 8196
   #The docker image to run nni job on pai
   image: msranni/nni:latest
+  nniManagerNFSMountPath: /home/user/mnt
+  containerNFSMountPath: /mnt/data/user
+  paiStoragePlugin: team_wise
 paiConfig:
   #The username to login pai
   userName: username
-  #The password to login pai
-  passWord: password
+  #The token to login pai
+  token: token
   #The host of restful server of pai
   host: 10.10.10.10
\ No newline at end of file
authorName: default
experimentName: example_sklearn
trialConcurrency: 1
maxExecDuration: 1h
maxTrialNum: 100
#choice: local, remote, pai, paiYarn
trainingServicePlatform: paiYarn
searchSpacePath: search_space.json
#choice: true, false
useAnnotation: false
tuner:
#choice: TPE, Random, Anneal, Evolution, BatchTuner, MetisTuner
#SMAC (SMAC should be installed through nnictl)
builtinTunerName: TPE
classArgs:
#choice: maximize, minimize
optimize_mode: maximize
trial:
command: python3 main.py
codeDir: .
gpuNum: 0
cpuNum: 1
memoryMB: 8196
#The docker image to run nni job on pai
image: msranni/nni:latest
paiYarnConfig:
#The username to login pai
userName: username
#The password to login pai
passWord: password
#The host of restful server of pai
host: 10.10.10.10
\ No newline at end of file
@@ -4,7 +4,6 @@
 'use strict';

 import * as assert from 'assert';
-import * as JSON5 from 'json5';
 import { Deferred } from 'ts-deferred';

 import * as component from '../common/component';
@@ -132,7 +131,7 @@ class NNIDataStore implements DataStore {
     }

     public async storeMetricData(trialJobId: string, data: string): Promise<void> {
-        const metrics: MetricData = JSON5.parse(data);
+        const metrics: MetricData = JSON.parse(data);
         // REQUEST_PARAMETER is used to request new parameters for multiphase trial job,
         // it is not metrics, so it is skipped here.
         if (metrics.type === 'REQUEST_PARAMETER') {
@@ -141,7 +140,7 @@ class NNIDataStore implements DataStore {
         }
         assert(trialJobId === metrics.trial_job_id);
         try {
-            await this.db.storeMetricData(trialJobId, JSON5.stringify({
+            await this.db.storeMetricData(trialJobId, JSON.stringify({
                 trialJobId: metrics.trial_job_id,
                 parameterId: metrics.parameter_id,
                 type: metrics.type,
...
@@ -5,7 +5,6 @@

 import * as assert from 'assert';
 import * as fs from 'fs';
-import * as JSON5 from 'json5';
 import * as path from 'path';
 import * as sqlite3 from 'sqlite3';
 import { Deferred } from 'ts-deferred';
@@ -203,10 +202,10 @@ class SqlDB implements Database {

     public storeMetricData(trialJobId: string, data: string): Promise<void> {
         const sql: string = 'insert into MetricData values (?,?,?,?,?,?)';
-        const json: MetricDataRecord = JSON5.parse(data);
-        const args: any[] = [Date.now(), json.trialJobId, json.parameterId, json.type, json.sequence, JSON5.stringify(json.data)];
-        this.log.trace(`storeMetricData: SQL: ${sql}, args: ${JSON5.stringify(args)}`);
+        const json: MetricDataRecord = JSON.parse(data);
+        const args: any[] = [Date.now(), json.trialJobId, json.parameterId, json.type, json.sequence, JSON.stringify(json.data)];
+        this.log.trace(`storeMetricData: SQL: ${sql}, args: ${JSON.stringify(args)}`);
         const deferred: Deferred<void> = new Deferred<void>();
         this.db.run(sql, args, (err: Error | null) => { this.resolve(deferred, err); });
...
@@ -17,7 +17,6 @@
     "express": "^4.16.3",
     "express-joi-validator": "^2.0.0",
     "js-base64": "^2.4.9",
-    "json5": "^2.1.1",
     "kubernetes-client": "^6.5.0",
     "rx": "^4.1.0",
     "sqlite3": "^4.0.2",
@@ -36,7 +35,6 @@
     "@types/express": "^4.16.0",
     "@types/glob": "^7.1.1",
     "@types/js-base64": "^2.3.1",
-    "@types/json5": "^0.0.30",
     "@types/mocha": "^5.2.5",
     "@types/node": "10.12.18",
     "@types/request": "^2.47.1",
...
@@ -34,7 +34,7 @@ export class NNIRestServer extends RestServer {
      */
     protected registerRestHandler(): void {
         this.app.use(express.static('static'));
-        this.app.use(bodyParser.json());
+        this.app.use(bodyParser.json({limit: '50mb'}));
         this.app.use(this.API_ROOT_URL, createRestHandler(this));
         this.app.use(this.LOGS_ROOT_URL, express.static(getLogDir()));
         this.app.get('*', (req: express.Request, res: express.Response) => {
...
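Express's body-parser defaults to a 100kb JSON limit; raising it to 50mb presumably accommodates large request bodies such as batches of imported trial data. A hedged illustration (the endpoint path and port are assumptions, not confirmed by this diff):

```python
# Sketch only: post a large import payload to a locally running NNI manager.
import json
import requests

trials = [{'parameter': {'lr': 0.01}, 'value': json.dumps(0.9)}] * 100000
requests.post('http://localhost:8080/api/v1/nni/experiment/import-data',
              json=trials, timeout=60)
```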
@@ -157,10 +157,6 @@
   version "7.0.3"
   resolved "https://registry.yarnpkg.com/@types/json-schema/-/json-schema-7.0.3.tgz#bdfd69d61e464dcc81b25159c270d75a73c1a636"

-"@types/json5@^0.0.30":
-  version "0.0.30"
-  resolved "https://registry.yarnpkg.com/@types/json5/-/json5-0.0.30.tgz#44cb52f32a809734ca562e685c6473b5754a7818"
-
 "@types/mime@*":
   version "2.0.0"
   resolved "https://registry.yarnpkg.com/@types/mime/-/mime-2.0.0.tgz#5a7306e367c539b9f6543499de8dd519fac37a8b"
@@ -2380,12 +2376,6 @@ json-stringify-safe@~5.0.1:
   version "5.0.1"
   resolved "https://registry.yarnpkg.com/json-stringify-safe/-/json-stringify-safe-5.0.1.tgz#1296a2d58fd45f19a0f6ce01d65701e2c735b6eb"

-json5@^2.1.1:
-  version "2.1.1"
-  resolved "https://registry.yarnpkg.com/json5/-/json5-2.1.1.tgz#81b6cb04e9ba496f1c7005d07b4368a2638f90b6"
-  dependencies:
-    minimist "^1.2.0"
-
 jsonparse@^1.2.0:
   version "1.3.1"
   resolved "https://registry.yarnpkg.com/jsonparse/-/jsonparse-1.3.1.tgz#3f4dae4a91fac315f71062f8521cc239f1366280"
...
@@ -557,7 +557,8 @@ class BOHB(MsgDispatcherBase):
             Data type not supported
         """
         logger.debug('handle report metric data = %s', data)
+        if 'value' in data:
+            data['value'] = json_tricks.loads(data['value'])
         if data['type'] == MetricType.REQUEST_PARAMETER:
             assert multi_phase_enabled()
             assert data['trial_job_id'] is not None
@@ -627,6 +628,8 @@ class BOHB(MsgDispatcherBase):
         AssertionError
             data doesn't have required key 'parameter' and 'value'
         """
+        for entry in data:
+            entry['value'] = json_tricks.loads(entry['value'])
         _completed_num = 0
         for trial_info in data:
             logger.info("Importing data, current processing progress %s / %s", _completed_num, len(data))
...
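This pairs with the JSON5 removal above: metric values presumably now travel as strings pre-serialized with `json_tricks` on the trial side, so the TypeScript layer can handle the payload with plain `JSON` and the advisor decodes the `value` field on arrival. A small sketch with a hypothetical payload:

```python
# Hypothetical metric payload: 'value' arrives as a JSON-encoded string
# and is decoded into a Python object by json_tricks.
import json_tricks

data = {'trial_job_id': 'abc', 'type': 'FINAL', 'value': '0.98'}
if 'value' in data:
    data['value'] = json_tricks.loads(data['value'])
assert isinstance(data['value'], float)
```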
from .compressor import ModelSpeedup
\ No newline at end of file
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import logging
import torch
from .infer_shape import ModuleMasks
_logger = logging.getLogger(__name__)
replace_module = {
'BatchNorm2d': lambda module, mask: replace_batchnorm2d(module, mask),
'Conv2d': lambda module, mask: replace_conv2d(module, mask),
'MaxPool2d': lambda module, mask: no_replace(module, mask),
'AvgPool2d': lambda module, mask: no_replace(module, mask),
'ReLU': lambda module, mask: no_replace(module, mask),
'Linear': lambda module, mask: replace_linear(module, mask)
}
def no_replace(module, mask):
"""
No need to replace
"""
_logger.debug("no need to replace")
return module
def replace_linear(linear, mask):
"""
Parameters
----------
linear : torch.nn.Linear
The linear module to be replaced
mask : ModuleMasks
The masks of this module
Returns
-------
torch.nn.Linear
The new linear module
"""
assert isinstance(mask, ModuleMasks)
assert mask.input_mask is not None
assert mask.output_mask is None
assert not mask.param_masks
index = mask.input_mask.mask_index[-1]
in_features = index.size()[0]
_logger.debug("replace linear with new in_features: %d", in_features)
new_linear = torch.nn.Linear(in_features=in_features,
out_features=linear.out_features,
bias=linear.bias is not None)
new_linear.to(linear.weight.device)
new_linear.weight.data = torch.index_select(linear.weight.data, -1, index.to(linear.weight.device))
if linear.bias is not None:
new_linear.bias.data.copy_(linear.bias.data)
return new_linear
def replace_batchnorm2d(norm, mask):
"""
Parameters
----------
norm : torch.nn.BatchNorm2d
The batchnorm module to be replaced
mask : ModuleMasks
The masks of this module
Returns
-------
torch.nn.BatchNorm2d
The new batchnorm module
"""
assert isinstance(mask, ModuleMasks)
assert 'weight' in mask.param_masks and 'bias' in mask.param_masks
index = mask.param_masks['weight'].mask_index[0]
num_features = index.size()[0]
_logger.debug("replace batchnorm2d with num_features: %d", num_features)
new_norm = torch.nn.BatchNorm2d(num_features=num_features,
eps=norm.eps,
momentum=norm.momentum,
affine=norm.affine,
track_running_stats=norm.track_running_stats)
# assign weights
new_norm.weight.data = torch.index_select(norm.weight.data, 0, index)
new_norm.bias.data = torch.index_select(norm.bias.data, 0, index)
if norm.track_running_stats:
new_norm.running_mean.data = torch.index_select(norm.running_mean.data, 0, index)
new_norm.running_var.data = torch.index_select(norm.running_var.data, 0, index)
return new_norm
def replace_conv2d(conv, mask):
"""
Parameters
----------
conv : torch.nn.Conv2d
The conv2d module to be replaced
mask : ModuleMasks
The masks of this module
Returns
-------
torch.nn.Conv2d
The new conv2d module
"""
assert isinstance(mask, ModuleMasks)
if mask.input_mask is None:
in_channels = conv.in_channels
else:
in_channels_index = mask.input_mask.mask_index[1]
in_channels = in_channels_index.size()[0]
if mask.output_mask is None:
out_channels = conv.out_channels
else:
out_channels_index = mask.output_mask.mask_index[1]
out_channels = out_channels_index.size()[0]
_logger.debug("replace conv2d with in_channels: %d, out_channels: %d", in_channels, out_channels)
new_conv = torch.nn.Conv2d(in_channels=in_channels,
out_channels=out_channels,
kernel_size=conv.kernel_size,
stride=conv.stride,
padding=conv.padding,
dilation=conv.dilation,
groups=1, # currently only groups == 1 is supported
bias=conv.bias is not None,
padding_mode=conv.padding_mode)
new_conv.to(conv.weight.device)
tmp_weight_data = tmp_bias_data = None
if mask.output_mask is not None:
tmp_weight_data = torch.index_select(conv.weight.data, 0, out_channels_index)
if conv.bias is not None:
tmp_bias_data = torch.index_select(conv.bias.data, 0, out_channels_index)
# NOTE: does not support group
if mask.input_mask is not None:
tmp_weight_data = torch.index_select(conv.weight.data if tmp_weight_data is None else tmp_weight_data,
1, in_channels_index)
assert tmp_weight_data is not None, "Conv2d weight should be updated based on masks"
new_conv.weight.data.copy_(tmp_weight_data)
if conv.bias is not None:
new_conv.bias.data.copy_(conv.bias.data if tmp_bias_data is None else tmp_bias_data)
return new_conv
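To make the index-selection trick above concrete, here is a small self-contained check (the kept indices are made up) that a `Linear` layer narrowed with `torch.index_select`, as in `replace_linear`, matches the original layer applied to a masked input:

```python
import torch

linear = torch.nn.Linear(8, 4)
keep = torch.tensor([0, 2, 5])  # hypothetical surviving input features

new_linear = torch.nn.Linear(in_features=keep.size(0),
                             out_features=linear.out_features,
                             bias=linear.bias is not None)
new_linear.weight.data = torch.index_select(linear.weight.data, -1, keep)
new_linear.bias.data.copy_(linear.bias.data)

x = torch.randn(1, 8)
x_masked = torch.zeros_like(x)
x_masked[:, keep] = x[:, keep]  # zero out the pruned features
assert torch.allclose(linear(x_masked), new_linear(x[:, keep]))
```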
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import logging
import queue
import re
import torch
from .compress_modules import replace_module
from .infer_shape import ModuleMasks, infer_from_mask, infer_from_inshape, infer_from_outshape
_logger = logging.getLogger(__name__)
def get_module_by_name(model, module_name):
"""
Get a module specified by its module name
Parameters
----------
model : pytorch model
the pytorch model from which to get its module
module_name : str
the name of the required module
Returns
-------
module, module
the parent module of the required module, the required module
"""
name_list = module_name.split(".")
for name in name_list[:-1]:
model = getattr(model, name)
leaf_module = getattr(model, name_list[-1])
return model, leaf_module
class GNode:
"""
Represents a node in the model graph. In this graph a module is a node,
and a function outside of any module (called in ```forward```) can also be a node.
"""
def __init__(self, node_name, node_type, op_type, inputs, outputs, nodes):
"""
Parameters
----------
node_name : str
The module name if the node is a module; ```scope_name.node_kind.seq``` if it is a func
node_type : str
It only has two options: `module` or `func`
op_type : str
The operation type of the module or func
inputs : list of str
All the inputs of this node, each element is debugName of one input
outputs : list of str
All the outputs of this node, each element is debugName of one output
nodes : list of node
All the trace graph nodes included in this module or func
"""
self.name = node_name
self.type = node_type
self.op_type = op_type
self.inputs = inputs
self.outputs = outputs
self.nodes = nodes
# store supplementary information for different op types
# for example, for ```view``` it stores the shape of its input and output
self.auxiliary = None
class ModelSpeedup:
"""
This class speeds up the model with the provided weight masks
"""
def __init__(self, model, dummy_input, masks_file):
"""
Parameters
----------
model : pytorch model
The model the user wants to speed up
dummy_input : pytorch tensor
The dummy input for ```jit.trace```; users should put it on the right device before passing it in
masks_file : str
The path of user provided mask file
"""
self.bound_model = model
self.dummy_input = dummy_input
self.masks = torch.load(masks_file)
self.is_training = model.training
# to obtain forward graph, model should be in ```eval``` mode
if self.is_training:
model.eval()
self.trace_graph = torch.jit.trace(model, dummy_input)
if self.is_training:
model.train()
self.inferred_masks = dict() # key: module_name, value: ModuleMasks
self.g_nodes = list()
self.global_count = 0
self.name_to_gnode, self.input_to_gnode, self.output_to_gnode = self._build_graph()
def _build_index_for_gnodes(self, g_nodes):
"""
Build indexes for quick search
Parameters
----------
g_nodes : list of GNode
All the g_node in processed model graph
Returns
-------
dict
use name to index g_nodes, key: node name, value: g_node
dict
use input (its name) to index g_nodes,
key: input, value: list of g_nodes that take this input
dict
use output (its name) to index g_nodes,
key: output, value: g_node that generates this output
"""
name_to_gnode = dict()
input_to_gnode = dict()
output_to_gnode = dict()
for node in g_nodes:
name_to_gnode[node.name] = node
for _input in node.inputs:
if _input in input_to_gnode:
input_to_gnode[_input].append(node)
else:
input_to_gnode[_input] = [node]
for output in node.outputs:
assert not output in output_to_gnode, \
"One output cannot be generated by multiple nodes"
output_to_gnode[output] = node
return name_to_gnode, input_to_gnode, output_to_gnode
def _expand_non_prim_node(self, node, nodes, input_to_node, output_to_node):
"""
Some trace graph nodes are not contained in modules; they are usually generated by
functions called directly in a module's ```forward```. Among such nodes, the trivial
ops are labeled with ```prim::```, while the remaining ones are called non-prim ops.
This function merges neighboring prim ops into a non-prim op to construct a GNode.
Parameters
----------
node : trace graph node
The non-prim node to expand
nodes : list of trace graph node
All the trace graph nodes within the same scope as the non-prim node
input_to_node : dict
key: input name, value: a node that uses this input
output_to_node : dict
key: output name, value: a node that generates this output
Returns
-------
GNode
the expanded non-prim node in GNode format
"""
# TODO: scope name could be empty
node_name = '.'.join([node.scopeName(), node.kind(), str(self.global_count)])
_logger.debug("expand non-prim node, node name: %s", node_name)
self.global_count += 1
op_type = node.kind()
node_group = [node]
inputs = list()
outputs = list()
node_queue = queue.Queue()
node_queue.put(node)
while not node_queue.empty():
curr_node = node_queue.get()
for _input in curr_node.inputs():
input_name = _input.debugName()
if input_name in output_to_node and output_to_node[input_name] in nodes:
predecessor_node = output_to_node[input_name]
if predecessor_node.kind().startswith('prim::'):
node_group.append(predecessor_node)
node_queue.put(predecessor_node)
else:
inputs.append(input_name)
else:
inputs.append(input_name)
for output in node.outputs():
outputs.append(output.debugName())
g_node = GNode(node_name, 'func', op_type, inputs, outputs, node_group)
return g_node
def _extract_shape_info(self, node):
"""
Extract the shape information of ```aten::view``` node
Parameters
----------
node : trace graph node
It should be ```aten::view``` node
Returns
-------
dict
Include shape of input tensor and shape of output tensor
"""
t_input = None
for _input in node.inputs():
t_input = _input
break
t_output = node.output()
assert isinstance(t_input.type(), torch._C.TensorType)
assert isinstance(t_output.type(), torch._C.TensorType)
in_shape = t_input.type().sizes()
out_shape = t_output.type().sizes()
return {'in_shape': in_shape, 'out_shape': out_shape}
def _extract_leaf_modules(self, graph):
"""
Extract leaf modules from the given graph. A leaf module is one that has no submodules.
We extract leaf modules because only leaf modules can be replaced, and shape inference
can be done at the leaf-module level; other shape inference is done at a lower level,
i.e., the operation level.
Parameters
----------
graph : jit trace graph
the graph generated from jit trace
Returns
-------
list
a list of scope name of all the leaf modules
"""
pieces = [] # each element is a dict
for node in graph.nodes():
scope_name = node.scopeName()
if scope_name == '':
continue
segs = scope_name.split('/')
segs_len = len(segs)
# increase the length of `pieces` if not enough
for _ in range(segs_len - len(pieces)):
pieces.append({})
# process internal segments of the scope name
# 'L' means leaf segment
# 'I' means internal segment
# internal segment can replace leaf segment at the same position of `pieces`
for i, seg in enumerate(segs[:-1]):
seg_name_dict = pieces[i]
if seg in seg_name_dict:
if seg_name_dict[seg][0] == 'L':
seg_name_dict[seg] = ('I', node)
else:
seg_name_dict[seg] = ('I', node)
# process the leaf segment of the scope name
last_segs_dict = pieces[len(segs) - 1]
if not segs[-1] in last_segs_dict:
last_segs_dict[segs[-1]] = ('L', node)
# traverse `pieces` to obtain all the leaf modules which are labeled with 'L'
leaf_modules = []
for piece in pieces:
for _, value in piece.items():
if value[0] == 'L':
assert value[1].scopeName() not in leaf_modules
# if this is a leaf module, the last segment of its scope name
# must be in pattern `xxx[xxx]`
if value[1].scopeName()[-1] == ']':
leaf_modules.append(value[1].scopeName())
return leaf_modules
def _build_graph(self):
"""
Build graph using our defined format from jit trace.
There are basically three steps: first, construct the necessary information (data structures);
second, extract all the modules and convert them to GNodes; third, extract all the functions
and convert them to GNodes.
Returns
-------
dict
use name to index g_nodes, key: node name, value: g_node
dict
use input (its name) to index g_nodes,
key: input, value: list of g_nodes that take this input
dict
use output (its name) to index g_nodes,
key: output, value: g_node that generates this output
"""
graph = self.trace_graph.graph
# if torch 1.4.0 is used, consider running torch._C._jit_pass_inline(graph) here
_logger.debug(graph)
# build output mapping, from output debugName to its node
output_to_node = dict()
# build input mapping, from input debugName to its node
input_to_node = dict()
# build module mapping, from module name to all nodes (as list) under this module scope
module_to_nodes = dict()
# module name to its type
module_to_type = dict()
# the mapping of function (non-module in forward) to nodes, key is scope name
func_to_nodes = dict()
graph_inputs = list()
graph_outputs = list()
for _input in graph.inputs():
graph_inputs.append(_input.debugName())
for output in graph.outputs():
graph_outputs.append(output.debugName())
leaf_modules = self._extract_leaf_modules(graph)
_logger.debug(leaf_modules)
for node in graph.nodes():
# populate output_to_node and input_to_node
for output in node.outputs():
output_name = output.debugName()
output_to_node[output_name] = node
for _input in node.inputs():
input_name = _input.debugName()
input_to_node[input_name] = node
scope_name = node.scopeName() # example: scope_name, 'MyCell/Linear[linear]'
# if module_name is empty, it is not a module
if not scope_name in leaf_modules:
if scope_name == '':
continue
else:
if scope_name in func_to_nodes:
func_to_nodes[scope_name].append(node)
else:
func_to_nodes[scope_name] = [node]
else:
module_name_slices = re.findall(r'\[(.*?)\]', scope_name)
module_name = '.'.join(module_name_slices)
scope_slice = scope_name.split('/')[-1]
module_type = scope_slice.split('[')[0]
module_to_type[module_name] = module_type
if module_name in module_to_nodes:
module_to_nodes[module_name].append(node)
else:
module_to_nodes[module_name] = [node]
# construct GNode from module
for module_name, nodes in module_to_nodes.items():
inputs = set()
outputs = set()
for node in nodes:
for output in node.outputs():
outputs.add(output.debugName())
for _input in node.inputs():
inputs.add(_input.debugName())
m_inputs = list()
m_outputs = list()
for output in outputs:
# TODO: one input could be the input of multiple nodes
if not output in input_to_node and output in graph_outputs:
m_outputs.append(output)
elif not input_to_node[output] in nodes:
m_outputs.append(output)
for _input in inputs:
if not _input in output_to_node and _input in graph_inputs:
m_inputs.append(_input)
elif not output_to_node[_input] in nodes:
m_inputs.append(_input)
if module_name == '':
_logger.warning("module_name is empty string")
g_node = GNode(module_name, 'module', module_to_type[module_name], m_inputs, m_outputs, nodes)
self.g_nodes.append(g_node)
# each scope_name may have multiple funcs, we split them and create GNode for each of them
for scope_name, nodes in func_to_nodes.items():
# extract non prim:: nodes
non_prim_nodes = list()
for node in nodes:
if not node.kind().startswith('prim::'):
non_prim_nodes.append(node)
# for each non-prim node, expand it as a GNode
for node in non_prim_nodes:
g_node = self._expand_non_prim_node(node, nodes, input_to_node, output_to_node)
self.g_nodes.append(g_node)
# get shape info for the view (aten::view) func
if g_node.op_type == 'aten::view':
g_node.auxiliary = self._extract_shape_info(node)
# build index for g_nodes
name_to_gnode, input_to_gnode, output_to_gnode = self._build_index_for_gnodes(self.g_nodes)
return name_to_gnode, input_to_gnode, output_to_gnode
def _find_predecessors(self, module_name):
"""
Find predecessor GNode of the given GNode
Parameters
----------
module_name : str
The name of the GNode
Returns
-------
list
a list of GNodes who are the given GNode's predecessor
"""
predecessors = []
for _input in self.name_to_gnode[module_name].inputs:
if not _input in self.output_to_gnode:
_logger.debug("cannot find gnode with %s as its output", _input)
else:
g_node = self.output_to_gnode[_input]
predecessors.append(g_node.name)
return predecessors
def _find_successors(self, module_name):
"""
Find successor GNodes of the given GNode
Parameters
----------
module_name : str
The name of the GNode
Returns
-------
list
a list of GNodes who are the given GNode's successor
"""
successors = []
for output in self.name_to_gnode[module_name].outputs:
assert output in self.input_to_gnode, "No gnode with input {}".format(output)
g_nodes = self.input_to_gnode[output]
for g_node in g_nodes:
successors.append(g_node.name)
return successors
def infer_module_mask(self, module_name, mask=None, in_shape=None, out_shape=None):
"""
Infer input shape / output shape based on the module's weight mask / input shape / output shape.
For a module:
Infer its input and output shape from its weight mask
Infer its output shape from its input shape
Infer its input shape from its output shape
If its input shape is changed, continue inferring its predecessors
If its output shape is changed, continue inferring its successors
Parameters
----------
module_name : str
The name of the GNode
mask : tensor of mask or ModuleMasks
Mask of the weights in this GNode (i.e., module)
in_shape : ModuleMasks
Input shape of this GNode
out_shape : ModuleMasks
Output shape of this GNode
"""
input_cmask = output_cmask = None
if module_name in self.inferred_masks:
module_masks = self.inferred_masks[module_name]
else:
module_masks = ModuleMasks(module_name)
self.inferred_masks[module_name] = module_masks
m_type = self.name_to_gnode[module_name].op_type
_logger.debug("infer mask of module %s with op_type %s", module_name, m_type)
if mask is not None:
_logger.debug("mask is not None")
if not m_type in infer_from_mask:
raise RuntimeError(
"Has not supported infering input/output shape from mask for module/function: `{}`, {}"
.format(m_type, module_name))
input_cmask, output_cmask = infer_from_mask[m_type](module_masks, mask)
if in_shape is not None:
_logger.debug("in_shape is not None")
if not m_type in infer_from_inshape:
raise RuntimeError(
"Has not supported infering output shape from input shape for module/function: `{}`, {}"
.format(m_type, module_name))
if m_type == 'aten::view':
output_cmask = infer_from_inshape[m_type](module_masks,
in_shape,
self.name_to_gnode[module_name].auxiliary)
else:
output_cmask = infer_from_inshape[m_type](module_masks, in_shape)
if out_shape is not None:
_logger.debug("out_shape is not None")
if not m_type in infer_from_outshape:
raise RuntimeError(
"Has not supported infering input shape from output shape for module/function: `{}`, {}"
.format(m_type, module_name))
input_cmask = infer_from_outshape[m_type](module_masks, out_shape)
if input_cmask:
predecessors = self._find_predecessors(module_name)
for _module_name in predecessors:
self.infer_module_mask(_module_name, out_shape=input_cmask)
if output_cmask:
successors = self._find_successors(module_name)
for _module_name in successors:
self.infer_module_mask(_module_name, in_shape=output_cmask)
def infer_modules_masks(self):
"""
Do shape inference of involved modules, including the shape of weights, inputs, output
"""
for module_name, mask in self.masks.items():
self.infer_module_mask(module_name, mask=mask)
def replace_compressed_modules(self):
"""
Replace all the modules that have changed (weights/inputs/output) shape.
The new module is created using the same arguments of the to-be-replaced module,
and correctly inherits its weights.
NOTE: a ```func``` type cannot be replaced as it is not a module; thus, one limitation
is that a ```func``` must not require replacement.
"""
for module_name in self.inferred_masks:
g_node = self.name_to_gnode[module_name]
_logger.debug("replace %s, in %s type, with op_type %s",
module_name, g_node.type, g_node.op_type)
if g_node.type == 'module':
super_module, leaf_module = get_module_by_name(self.bound_model, module_name)
m_type = g_node.op_type
if not m_type in replace_module:
raise RuntimeError("Has not supported replacing the module: `{}`".format(m_type))
_logger.info("replace module (name: %s, op_type: %s)", module_name, m_type)
compressed_module = replace_module[m_type](leaf_module, self.inferred_masks[module_name])
setattr(super_module, module_name.split('.')[-1], compressed_module)
elif g_node.type == 'func':
_logger.info("Warning: cannot replace (name: %s, op_type: %s) which is func type",
module_name, g_node.op_type)
else:
raise RuntimeError("Unsupported GNode type: {}".format(g_node.type))
def speedup_model(self):
"""
There are basically two steps:
first, do mask/shape inference,
second, replace modules
"""
_logger.info("start to speed up the model")
_logger.info("infer module masks...")
self.infer_modules_masks()
_logger.info("replace compressed modules...")
self.replace_compressed_modules()
_logger.info("speedup done")
# restore the model to the mode it was in before it was sped up
if self.is_training:
self.bound_model.train()
else:
self.bound_model.eval()
\ No newline at end of file
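For orientation, a hedged usage sketch of the `ModelSpeedup` class above; the toy model and mask file are assumptions, and the mask file would in practice come from an NNI pruner's export:

```python
import torch

class Net(torch.nn.Module):  # toy model, an assumption for illustration
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 16, 3)
        self.bn = torch.nn.BatchNorm2d(16)
        self.relu = torch.nn.ReLU()
        self.fc = torch.nn.Linear(16 * 30 * 30, 10)

    def forward(self, x):
        x = self.relu(self.bn(self.conv(x)))
        return self.fc(x.view(x.size(0), -1))

model = Net()
dummy_input = torch.randn(1, 3, 32, 32)  # on the same device as the model
# 'masks.pth' is hypothetical; it would be saved by an NNI pruner.
m_speedup = ModelSpeedup(model, dummy_input, 'masks.pth')
m_speedup.speedup_model()  # infer masks, then replace compressed modules
```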
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
"""
For each operation or module, there are two functions.
One, given the output shape, infers the input shape and initialization parameters (e.g., the weight's shape).
The other, given the input shape, infers the output shape and initialization parameters (e.g., the weight's shape).
"""
import torch
class CoarseMask:
"""
Coarse grained mask for a given tensor, here tensor could be weights,
input tensor, or output tensor
"""
def __init__(self, num_dim):
"""
Parameters
----------
num_dim : int
The number of dimensions of the tensor that will be masked
"""
self.mask_index = [None for _ in range(num_dim)]
def add_index_mask(self, dim, index):
"""
Add mask for the specified dimension
Parameters
----------
dim : int
The dimension to add mask
index : tensor
The mask for this dimension; it is a 1-dimensional tensor which specifies
the indices of the elements that are not pruned
"""
self.mask_index[dim] = index
@staticmethod
def merge_index(index_a, index_b):
"""
Parameters
----------
index_a : tensor
One index (1-dimension) tensor
index_b : tensor
The other index (1-dimension) tensor
Returns
-------
tensor
The merged index (1-dimension) tensor
"""
s = set()
for num in index_a:
s.add(num)
for num in index_b:
s.add(num)
return torch.tensor(sorted(s)) # pylint: disable=not-callable
def merge(self, cmask):
"""
Merge another CoarseMask
Parameters
----------
cmask : CoarseMask
Another CoarseMask to merge
Returns
-------
list
The member variable ```mask_index```
"""
assert isinstance(cmask, CoarseMask)
assert len(self.mask_index) == len(cmask.mask_index), \
"Only masks with the same number of dimensions can be merged"
for i, index in enumerate(self.mask_index):
if index is None:
self.mask_index[i] = cmask.mask_index[i]
elif cmask.mask_index[i] is not None:
self.mask_index[i] = CoarseMask.merge_index(self.mask_index[i],
cmask.mask_index[i])
return self.mask_index
class ModuleMasks:
"""
The masks of a module, including the masks for weights, inputs, output
"""
def __init__(self, module_name):
"""
Parameters
----------
module_name : str
The name of the module or function
"""
self.module_name = module_name
self.param_masks = dict()
self.input_mask = None
self.output_mask = None
def set_param_masks(self, name, mask):
"""
Parameters
----------
name : str
The name of the weight
mask : CoarseMask
The mask for this weight
"""
self.param_masks[name] = mask
def set_input_mask(self, mask):
"""
Parameters
----------
mask : CoarseMask
The mask for input
"""
self.input_mask = mask
def set_output_mask(self, mask):
"""
Parameters
----------
mask : CoarseMask
The mask for output
"""
self.output_mask = mask
"""
Infer input and output shape of a module/function from its weight mask
"""
infer_from_mask = {
'BatchNorm2d': lambda module_masks, mask: batchnorm2d_mask(module_masks, mask),
'Conv2d': lambda module_masks, mask: conv2d_mask(module_masks, mask)
}
"""
Infer output and weight shape of a module/function from its input shape
"""
infer_from_inshape = {
'ReLU': lambda module_masks, mask: relu_inshape(module_masks, mask),
'aten::relu': lambda module_masks, mask: relu_inshape(module_masks, mask),
'Conv2d': lambda module_masks, mask: conv2d_inshape(module_masks, mask),
'MaxPool2d': lambda module_masks, mask: maxpool2d_inshape(module_masks, mask),
'aten::max_pool2d': lambda module_masks, mask: maxpool2d_inshape(module_masks, mask),
'aten::avg_pool2d': lambda module_masks, mask: maxpool2d_inshape(module_masks, mask),
'AvgPool2d': lambda module_masks, mask: maxpool2d_inshape(module_masks, mask),
'aten::size': lambda module_masks, mask: size_inshape(module_masks, mask),
'aten::view': lambda module_masks, mask, shape: view_inshape(module_masks, mask, shape),
'Linear': lambda module_masks, mask: linear_inshape(module_masks, mask),
'BatchNorm2d': lambda module_masks, mask: batchnorm2d_inshape(module_masks, mask)
}
"""
Infer input and weight shape of a module/function from its output shape
"""
infer_from_outshape = {
'Conv2d': lambda module_masks, mask: conv2d_outshape(module_masks, mask)
}
def batchnorm2d_inshape(module_masks, mask):
"""
We assume only the second dimension has a coarse-grained mask
Parameters
----------
module_masks : ModuleMasks
The ModuleMasks instance of the batchnorm2d
mask : CoarseMask
The mask of its input tensor
Returns
-------
CoarseMask
The mask of its output tensor
"""
assert isinstance(mask, CoarseMask)
assert mask.mask_index[1] is not None
assert mask.mask_index[0] is None
assert mask.mask_index[2] is None
assert mask.mask_index[3] is None
module_masks.set_input_mask(mask)
module_masks.set_output_mask(mask)
weight_cmask = CoarseMask(num_dim=1)
weight_cmask.add_index_mask(dim=0, index=mask.mask_index[1])
module_masks.set_param_masks('weight', weight_cmask)
module_masks.set_param_masks('bias', weight_cmask)
return mask
def linear_inshape(module_masks, mask):
"""
A coarse-grained input mask does not change the shape of the weights or the output tensor
Parameters
----------
module_masks : ModuleMasks
The ModuleMasks instance of the linear
mask : CoarseMask
The mask of its input tensor
Returns
-------
CoarseMask
The mask of its output tensor; ```None``` means the shape of the output tensor is not changed
"""
assert isinstance(mask, CoarseMask)
assert mask.mask_index[0] is None
assert module_masks.input_mask is None
module_masks.set_input_mask(mask)
return None
def view_inshape(module_masks, mask, shape):
"""
Support here is limited.
TODO: consider replacing tensor.view with nn.Flatten, because tensor.view is not
a module and thus cannot be replaced by our framework.
Parameters
----------
module_masks : ModuleMasks
The ModuleMasks instance of the ```view``` op
mask : CoarseMask
The mask of its input tensor
shape : dict
Original shape of its input and output tensors
Returns
-------
CoarseMask
The mask of its output tensor
"""
# NOTE: this case is constrained by the following four asserts
assert shape['in_shape'][0] == shape['out_shape'][0]
assert len(shape['in_shape']) == 4
assert len(shape['out_shape']) == 2
assert shape['out_shape'][1] == shape['in_shape'][1]*shape['in_shape'][2]*shape['in_shape'][3]
assert isinstance(mask, CoarseMask)
assert mask.mask_index[1] is not None
assert mask.mask_index[0] is None
assert mask.mask_index[2] is None
assert mask.mask_index[3] is None
assert module_masks.input_mask is None
module_masks.set_input_mask(mask)
output_cmask = CoarseMask(num_dim=2)
index = []
step_size = shape['in_shape'][2] * shape['in_shape'][3]
for loc in mask.mask_index[1]:
index.extend([loc * step_size + i for i in range(step_size)])
output_cmask.add_index_mask(dim=1, index=torch.tensor(index)) # pylint: disable=not-callable
module_masks.set_output_mask(output_cmask)
return output_cmask
def size_inshape(module_masks, mask):
"""
No need to do anything for this ```size``` op
"""
return None
def maxpool2d_inshape(module_masks, mask):
"""
Assume only the second dimension is masked
Parameters
----------
module_masks : ModuleMasks
The ModuleMasks instance of the maxpool2d
mask : CoarseMask
The mask of its input tensor
Returns
-------
CoarseMask
The mask of its output tensor
"""
assert isinstance(mask, CoarseMask)
assert mask.mask_index[1] is not None
assert mask.mask_index[0] is None
assert mask.mask_index[2] is None
assert mask.mask_index[3] is None
assert module_masks.input_mask is None
module_masks.set_input_mask(mask)
module_masks.set_output_mask(mask)
return mask
def relu_inshape(module_masks, mask):
"""
Parameters
----------
module_masks : ModuleMasks
The ModuleMasks instance of the relu
mask : CoarseMask
The mask of its input tensor
Returns
-------
CoarseMask
The mask of its output tensor
"""
assert isinstance(mask, CoarseMask)
# TODO: double check this assert, is it possible that a module is passed twice
assert module_masks.input_mask is None, "A relu op can only be processed once"
module_masks.set_input_mask(mask)
module_masks.set_output_mask(mask)
return mask
def batchnorm2d_mask(module_masks, mask):
"""
Infer input and output shape from weight mask
Parameters
----------
module_masks : ModuleMasks
The ModuleMasks instance of the batchnorm2d
mask : dict
The mask of its weights, from the user provided mask file
Returns
-------
CoarseMask, CoarseMask
The mask of its input tensor, the mask of its output tensor
"""
assert 'weight' in mask and 'bias' in mask
sum_mask = mask['weight'] + mask['bias']
nonzero_index = torch.nonzero(sum_mask, as_tuple=True)[0]
# infer shape of parameters
param_cmask = CoarseMask(num_dim=1)
param_cmask.add_index_mask(dim=0, index=nonzero_index)
module_masks.set_param_masks('weight', param_cmask)
module_masks.set_param_masks('bias', param_cmask)
# infer shape of input tensor
input_cmask = CoarseMask(num_dim=4)
input_cmask.add_index_mask(dim=1,
index=torch.nonzero(mask['weight'], as_tuple=True)[0])
module_masks.set_input_mask(input_cmask)
# infer shape of output tensor
output_cmask = CoarseMask(num_dim=4)
output_cmask.add_index_mask(dim=1, index=nonzero_index)
module_masks.set_output_mask(output_cmask)
return input_cmask, output_cmask
def conv2d_mask(module_masks, mask):
"""
Infer input and output shape from weight mask
Parameters
----------
module_masks : ModuleMasks
The ModuleMasks instance of the conv2d
mask : dict
The mask of its weights, from the user provided mask file
Returns
-------
CoarseMask, CoarseMask
The mask of its input tensor, the mask of its output tensor
"""
def convert_to_coarse_mask(mask):
"""
Parameters
----------
mask : dict
Weight mask from user provided mask file
Returns
-------
LongTensor, CoarseMask, CoarseMask
Index of the masked dimension, weight mask, bias mask
"""
assert 'weight' in mask
assert isinstance(mask['weight'], torch.Tensor)
weight_mask = mask['weight']
shape = weight_mask.size()
ones = torch.ones(shape[1:]).to(weight_mask.device)
zeros = torch.zeros(shape[1:]).to(weight_mask.device)
index = []
for i in range(shape[0]):
if torch.all(torch.eq(weight_mask[i], ones)):
index.append(i)
elif torch.all(torch.eq(weight_mask[i], zeros)):
continue
else:
index = None
break
if index is None:
return None, None, None
else:
index = torch.LongTensor(index).to(weight_mask.device)
weight_cmask = CoarseMask(num_dim=4)
weight_cmask.add_index_mask(dim=0, index=index)
bias_cmask = None
if 'bias' in mask and mask['bias'] is not None:
bias_index = torch.nonzero(mask['bias'], as_tuple=True)[0]
assert torch.all(torch.eq(index, bias_index)), \
"bias mask should be consistent with weight mask"
bias_cmask = CoarseMask(num_dim=1)
bias_cmask.add_index_mask(dim=0, index=bias_index)
return index, weight_cmask, bias_cmask
index, weight_cmask, bias_cmask = convert_to_coarse_mask(mask)
if index is None:
# TODO: fine grained mask speedup
return None, None
# deal with coarse grain mask
if 'weight' in module_masks.param_masks:
module_masks.param_masks['weight'].merge(weight_cmask)
module_masks.param_masks['bias'].merge(bias_cmask)
else:
module_masks.set_param_masks('weight', weight_cmask)
module_masks.set_param_masks('bias', bias_cmask)
output_cmask = CoarseMask(num_dim=4)
output_cmask.add_index_mask(dim=1, index=index)
if module_masks.output_mask is None:
module_masks.set_output_mask(output_cmask)
else:
module_masks.output_mask.merge(output_cmask)
return None, module_masks.output_mask
def conv2d_inshape(module_masks, mask):
"""
Shape change of input tensor does not affect the shape of its output tensor
Parameters
----------
module_masks : ModuleMasks
The ModuleMasks instance of the conv2d
mask : CoarseMask
The mask of its input tensor
Returns
-------
CoarseMask
The mask of its output tensor
"""
assert isinstance(mask, CoarseMask)
assert module_masks.input_mask is None
module_masks.set_input_mask(mask)
return None
def conv2d_outshape(module_masks, mask):
"""
Assume only the second dimension is masked
Parameters
----------
module_masks : ModuleMasks
The ModuleMasks instance of the conv2d
mask : CoarseMask
The mask of its output tensor
Returns
-------
CoarseMask
The mask of its input tensor
"""
assert isinstance(mask, CoarseMask)
assert mask.mask_index[1] is not None
assert mask.mask_index[0] is None
assert mask.mask_index[2] is None
assert mask.mask_index[3] is None
if module_masks.output_mask is not None:
assert isinstance(module_masks.output_mask, CoarseMask)
# set shape of output
mask = module_masks.output_mask.merge(mask)
else:
module_masks.output_mask = mask
# infer shape of parameters
weight_cmask = CoarseMask(num_dim=4)
weight_cmask.add_index_mask(dim=0, index=mask.mask_index[1])
bias_cmask = CoarseMask(num_dim=1)
bias_cmask.add_index_mask(dim=0, index=mask.mask_index[1])
module_masks.set_param_masks('weight', weight_cmask)
module_masks.set_param_masks('bias', bias_cmask)
# input shape is not changed
return None
\ No newline at end of file
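As a worked example of the channel-detection logic in `convert_to_coarse_mask` above (the mask tensor is made up): a conv filter is kept only when its mask slice is all ones, and the surviving indices become a dim-0 `CoarseMask` on the weight:

```python
import torch

weight_mask = torch.ones(4, 2, 3, 3)  # hypothetical conv2d weight mask
weight_mask[1] = 0                    # filter 1 is fully pruned

index = [i for i in range(weight_mask.size(0))
         if torch.all(torch.eq(weight_mask[i], torch.ones(2, 3, 3)))]
# index == [0, 2, 3]; these would go into
# CoarseMask(num_dim=4).add_index_mask(dim=0, index=torch.LongTensor(index))
print(index)
```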