Unverified Commit 97d067e6 authored by Ningxin Zheng, committed by GitHub

Speedup enhancement (#4925)

parent 4ab85d3d
......@@ -45,6 +45,7 @@ replace_module = {
'Upsample': lambda module, masks: no_replace(module, masks),
'LayerNorm': lambda module, masks: replace_layernorm(module, masks),
'ConvTranspose2d': lambda module, masks: replace_convtranspose2d(module, masks),
'Embedding': lambda module, masks: replace_embedding(module, masks),
'PixelShuffle': lambda module, masks: replace_pixelshuffle(module, masks),
'Flatten': lambda module, masks: no_replace(module, masks)
}
......@@ -85,6 +86,30 @@ def convert_to_coarse_mask(t_mask, dim):
return indexes, remained_indexes
def convert_dense_shape(mask):
"""
Get the dense shape of the tensor after removing the sparsity
values.
Parameters
----------
mask: torch.Tensor
The mask tensor.
Returns
-------
dense_shape: tuple
The dense shape after removing the sparsity values.
"""
assert isinstance(mask, torch.Tensor)
n_dim = len(mask.size())
dense_shape = []
for dim in range(n_dim):
_, remained = convert_to_coarse_mask(mask, dim)
dense_shape.append(remained.size(0))
return tuple(dense_shape)
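# A minimal usage sketch of convert_dense_shape (assuming
# convert_to_coarse_mask returns (pruned_indexes, remained_indexes)
# along the given dim): a 2x4 mask whose last column is all-zero
# shrinks to a dense shape of (2, 3).
#
# >>> mask = torch.tensor([[1., 1., 1., 0.],
# ...                      [1., 1., 1., 0.]])
# >>> convert_dense_shape(mask)
# (2, 3)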
def no_replace(module, masks):
"""
No need to replace
......@@ -165,9 +190,12 @@ def replace_linear(linear, masks):
in_mask = in_masks[0]
weight_mask = weight_mask['weight']
# the input of the linear may have two dimensions (CV models) or three
# dimensions (BERT, for example)
n_dim = len(in_mask.size())
# N C K
pruned_in, remained_in = convert_to_coarse_mask(in_mask, 1)
pruned_out, remained_out = convert_to_coarse_mask(output_mask, 1)
pruned_in, remained_in = convert_to_coarse_mask(in_mask, n_dim-1)
pruned_out, remained_out = convert_to_coarse_mask(output_mask, n_dim-1)
n_remained_in = weight_mask.size(1) - pruned_in.size(0)
n_remained_out = weight_mask.size(0) - pruned_out.size(0)
remained_in, remained_out = remained_in.to(
......@@ -582,16 +610,20 @@ def replace_layernorm(layernorm, masks):
if len(in_masks) != 1:
raise InputsNumberError()
in_mask = in_masks[0]
dim_n = len(in_mask.size())
new_shape = []
for i in range(1, dim_n):
sum_dims = list(range(0, dim_n))
sum_dims.remove(i)
reduced = torch.sum(in_mask, sum_dims)
n_remained = torch.sum(reduced > 0)
new_shape.append(n_remained)
return nn.LayerNorm(tuple(new_shape), layernorm.eps, layernorm.elementwise_affine)
dense_shape = convert_dense_shape(in_mask)
norm_shape = layernorm.normalized_shape
dim_n = len(dense_shape) - len(norm_shape)
return nn.LayerNorm(dense_shape[dim_n:], layernorm.eps, layernorm.elementwise_affine)
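# A worked example of the reshaping above (shapes are illustrative):
# for a BERT-like input mask whose dense shape is (32, 128, 512) after
# pruning, and an original normalized_shape of length 1, only the
# trailing len(normalized_shape) dimensions are normalized, so
# dim_n = 3 - 1 = 2 and the new layer is nn.LayerNorm((512,)).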
def replace_embedding(embedding, masks):
"""
Replace the embedding layer according to the inferred masks.
We replace the embedding layer according to the weight masks.
"""
# currently we do not support replacing the embedding layer,
# because we do not have the corresponding pruner
return embedding
def replace_pixelshuffle(pixelshuffle, masks):
......
......@@ -2,6 +2,7 @@
# Licensed under the MIT license.
import copy
import logging
from pathlib import Path
import queue
......@@ -66,6 +67,7 @@ class ModelSpeedup:
self.bound_model = model
self.inferred_masks = dict() # key: module_name, value: ModuleMasks
self.batch_dim = batch_dim
self.confidence = confidence
self.dummy_input, self.device = self._random_model_input(
dummy_input, confidence, batch_dim)
self.torch_graph = build_module_graph(model, self.dummy_input)
......@@ -196,6 +198,7 @@ class ModelSpeedup:
# The detach operation here is for in-place operations. We cannot
# directly call backward on the output tensor of an in-place operator.
dummy_input.append(self.internal_result[_input].detach())
debugnames.append(_input)
return dummy_input, debugnames
......@@ -229,15 +232,15 @@ class ModelSpeedup:
return
# function doesn't have weights
_auto_infer = AutoMaskInference(
func, dummy_input, in_masks, in_constants=in_constants, batch_dim=self.batch_dim)
func, dummy_input, self, in_masks, in_constants=in_constants)
else:
weight_mask = None
if module_name in self.masks:
weight_mask = self.masks[module_name]
_, module = get_module_by_name(self.bound_model, module_name)
_auto_infer = AutoMaskInference(
module, dummy_input, in_masks, weight_mask, in_constants=in_constants,
state_dict=copy.deepcopy(module.state_dict()), batch_dim=self.batch_dim)
module, dummy_input, self, in_masks, weight_mask, in_constants=in_constants,
state_dict=copy.deepcopy(module.state_dict()))
self.auto_inferences[unique_name] = _auto_infer
_auto_infer.name = node.unique_name
......@@ -280,6 +283,7 @@ class ModelSpeedup:
The target node to update the indirect sparsity
"""
unique_name = node.unique_name
if unique_name in self.auto_inferences and self.auto_inferences[unique_name] is not None:
# if the auto inference object is already in self.auto_inferences, then
# directly update the previous one
......@@ -291,13 +295,18 @@ class ModelSpeedup:
# pass the gradient to the predecessor nodes
for in_id, tin in enumerate(auto_infer.dummy_input):
debug_name = auto_infer.input_debugname[in_id]
last_output = self.internal_result[debug_name]
# TODO: handle the case where last_output is a tuple/list of tensors
if last_output.grad is not None and tin.grad is not None:
last_output.grad.data += tin.grad.data
else:
elif last_output.grad is None:
last_output.grad = tin.grad
elif last_output.grad is not None and tin.grad is None:
# for example, in tin.view(batch, tin.size(1)//2, tin.size(2)*2),
# the size operations on tin have no gradient
continue
else:
_logger.warning(
'Note: %s does not have corresponding mask inference object', node.name)
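# A small sketch of the accumulation rule above (names are
# illustrative): if last_output feeds two successors whose dummy
# inputs carry gradients g_a and g_b, the first iteration assigns
# last_output.grad = g_a and the second adds g_b, so the final
# gradient on the producer's output is g_a + g_b.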
......@@ -388,6 +397,7 @@ class ModelSpeedup:
if out_degree[predecessor] == 0:
visit_queue.put(self.torch_graph.name_to_node[predecessor])
def replace_compressed_modules(self):
"""
Replace all the modules that have changed (weights/inputs/output) shape.
......@@ -401,6 +411,7 @@ class ModelSpeedup:
for unique_name in self.auto_inferences:
self.replace_submodule(unique_name)
def replace_submodule(self, unique_name, reindex_dim=None, reindex=None):
"""
Replace the submodule according to the inferred sparsity.
......@@ -443,7 +454,6 @@ class ModelSpeedup:
requires_grad=tmpout.requires_grad)
out[self.t_index] = tmpout
return out
assert unique_name in self.auto_inferences
g_node = self.torch_graph.name_to_node[unique_name]
_logger.debug("replace %s, in %s type, with op_type %s",
......@@ -483,12 +493,9 @@ class ModelSpeedup:
setattr(super_module, g_node.name.split(
'.')[-1], new_submodule)
return new_submodule
elif g_node.type == 'func':
_logger.info("Warning: cannot replace (name: %s, op_type: %s) which is func type",
unique_name, g_node.op_type)
return None
else:
raise RuntimeError("Unsupported node type: {}".format(g_node.type))
return None
def initialize_speedup(self):
"""
......
......@@ -12,8 +12,8 @@ STD_DELTA = 1e-6
class AutoMaskInference:
def __init__(self, module, dummy_input, in_masks=None, weight_mask=None, \
output_mask=None, name=None, in_constants=None, state_dict=None, batch_dim=0):
def __init__(self, module, dummy_input, speedup, in_masks=None, weight_mask=None,
output_mask=None, name=None, in_constants=None, state_dict=None):
"""
This class will infer the mask of the target module automatically.
Its update_direct_sparsity method will infer the output mask according
......@@ -28,6 +28,8 @@ class AutoMaskInference:
The target module to infer the mask for. Needs to be callable.
dummy_input: torch.Tensor/list of Tensor
The dummy_input of the target module.
speedup: ModelSpeedup
The reference of the ModelSpeedup object.
in_masks: list of torch.Tensor
The input masks of the target module, if in_masks is not None, then
update_direct_sparsity and update_indirect_sparsity will incrementally
......@@ -47,8 +49,6 @@ class AutoMaskInference:
The corresponding constant values of the in_masks.
state_dict: dict of torch.Tensor
The original values of the weights.
batch_dim: int
The index of the batch dimension of the input tensors.
"""
errmsg = '%s is not callable, should pass the nn.Module/function' % str(
......@@ -112,7 +112,8 @@ class AutoMaskInference:
self.weight_mask[name] = torch.ones_like(para.data)
self.state_dict = state_dict
# TODO: support other batch dimensions in the future
self.batch_dim = batch_dim
self.batch_dim = speedup.batch_dim
self.batch_size = speedup.confidence
def random_init(self, start=0.1, end=8.0):
"""
......@@ -125,13 +126,17 @@ class AutoMaskInference:
# rules for ReLU6 to break this range constraint.
with torch.no_grad():
for tensor in self.dummy_input:
if isinstance(tensor, torch.Tensor) and len(tensor.size()) > 0:
# if the tensor is a scalar, then skip this tensor
if isinstance(tensor, torch.Tensor) and len(tensor.size()) > self.batch_dim\
and tensor.size(self.batch_dim) == self.batch_size:
# if the input tensor only has one dimension, it has no batch
# dimension, so we don't randomize it, because our tensor
# scrambling works on the batch dimension. For example, if the
# tensor is a scalar (returned by the size operator), we skip it
randomize_tensor(tensor, start, end)
for para in self.weights:
randomize_tensor(self.weights[para].data, start, end)
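# Illustration of the batch check above (assuming batch_dim == 0 and
# batch_size == confidence == 8): a dummy input of shape (8, 64) is
# scrambled in place, while a 0-dim tensor returned by a size()
# operator fails the len(tensor.size()) > self.batch_dim test and is
# left untouched.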
def zero_grad(self):
"""
Set the gradient of the weight, input tensor to be zeros.
......@@ -240,7 +245,6 @@ class AutoMaskInference:
constant[:, mask_pos] = mean[mask_pos]
return out_mask, constant
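# A hedged sketch of how the constants above may be detected (the
# surrounding lines are elided in this hunk): positions whose values
# barely vary across the randomized batch are treated as constants.
#
# >>> std = out.std(dim=0)          # per-position std over the batch
# >>> mask_pos = std < STD_DELTA    # positions that never change
# >>> constant[:, mask_pos] = mean[mask_pos]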
def update_indirect_sparsity(self):
"""
This function will update the indirect sparsity. To explain what's
......@@ -379,4 +383,3 @@ class AutoMaskInference:
def get_masks(self):
return (self.in_masks, self.output_mask, self.weight_mask)
......@@ -4,11 +4,13 @@
import re
import logging
from functools import partial
import copy
import torch
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
jitid_2_dtype = {4: torch.long, 6: torch.float32}
# to exclude partial
......@@ -243,7 +245,7 @@ def softmax_python(node, speedup):
def contiguous_python(node, speedup):
class contiguousModule(torch.nn.Module):
def forward(self, x):
return x.contiguous()
return x.contiguous().clone()
return contiguousModule()
......@@ -297,6 +299,7 @@ def squeeze_python(node, speedup):
new_squeeze = partial(torch.squeeze, dim=dim)
return new_squeeze
def unsqueeze_python(node, speedup):
c_node = node.key_node
inputs = list(c_node.inputs())
......@@ -324,7 +327,10 @@ def slice_python(node, speedup):
class SliceMoudle(torch.nn.Module):
def __init__(self, sliceobj):
super(SliceMoudle, self).__init__()
self.sliceobj = sliceobj
# we need to deepcopy the value here because, in the
# following steps, we may randomize the input tensor,
# which would change the values of the sliceobj
self.sliceobj = copy.deepcopy(sliceobj)
def forward(self, x, *args):
# args is for the slice dimension and indexes, however,
......@@ -344,20 +350,31 @@ def slice_python(node, speedup):
slice_end = parse_constant(inputs[3], speedup)
slice_step = parse_constant(inputs[4], speedup)
slice_obj = slice(slice_start, slice_end, slice_step)
slice_list = []
for _ in range(slice_dim):
slice_list.append(slice(None, None))
logger.info('Slice dim:%s, Slice obj:%s', str(slice_dim), str(slice_obj))
slice_list.append(slice_obj)
return SliceMoudle(tuple(slice_list))
if inputs[0].debugName() not in speedup.internal_result:
# The input of the slice operator may be a constant
target_tensor = parse_constant(inputs[0], speedup)
slice_list = tuple(slice_list)
def constant_slice(*args):
return target_tensor[slice_list]
return constant_slice
else:
return SliceMoudle(tuple(slice_list))
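# A quick usage sketch of SliceMoudle (assuming its forward applies
# x[self.sliceobj]): slicing dimension 1 with slice(0, 2, 1) builds
# the index tuple (slice(None), slice(0, 2, 1)), i.e. x[:, 0:2:1].
#
# >>> m = SliceMoudle((slice(None, None), slice(0, 2, 1)))
# >>> m(torch.arange(12).view(3, 4)).shape
# torch.Size([3, 2])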
def select_python(node, speedup):
class SelectModule(torch.nn.Module):
def __init__(self, dim, index):
super(SelectModule, self).__init__()
self.dim = dim
self.index = index
self.dim = copy.deepcopy(dim)
self.index = copy.deepcopy(index)
def forward(self, x):
return x.select(self.dim, self.index)
......@@ -425,7 +442,9 @@ def permute_python(node, speedup):
class PermuteModule(torch.nn.Module):
def __init__(self, dimlist):
super(PermuteModule, self).__init__()
self.dimlist = dimlist
# deepcopy the values here, because the following randomize operation
# will change the value of the dimlist
self.dimlist = copy.deepcopy(dimlist)
def forward(self, x):
return x.permute(self.dimlist)
......@@ -439,6 +458,7 @@ def getattr_python(node, speedup):
"""
Note: ops starting with prim:: are not taken as key nodes,
so we directly pass the C++ node into this function.
Parameters
----------
node: torch._C.Node
......@@ -462,6 +482,44 @@ def getattr_python(node, speedup):
assert len(key_words) == 1
return GetModule(key_words[0])
def constant_python(node, speedup):
"""
Get the constant value of a constant operator node.
Parameters
----------
node: torch._C.Node
The C++ node of prim::Constant
speedup: ModelSpeedup
The corresponding speedup object.
"""
class ConstantModule(torch.nn.Module):
def __init__(self, constant):
super(ConstantModule, self).__init__()
self.constant = constant
def forward(self):
return self.constant
assert node.kind() == 'prim::Constant'
pattern = r'\[value=(.*?)\]'
key_words = re.findall(pattern, str(node))
if len(key_words) == 0:
return ConstantModule(None)
assert len(key_words) == 1
# parse the constant value
value = key_words[0]
if value.startswith("\""):
value = torch.device(value[1:-1])
elif value.startswith('{'):
# TODO Support set values in the future
value = set()
elif '.' in value:
# float value
value = float(value)
else:
# integer value
value = int(value)
return ConstantModule(value)
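# An illustrative parse (the node string is representative of
# TorchScript IR): a node printed as
#   %5 : float = prim::Constant[value=0.5]()
# matches the pattern with key_words == ['0.5']; the value contains a
# '.', so it is parsed as a float and ConstantModule(0.5) is returned.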
def upsample_bilinear2d_python(node, speedup):
class UpsampleModule(torch.nn.Module):
......@@ -539,16 +597,25 @@ def typeas_python(node, speedup):
def to_python(node, speedup):
# for the time being, only the device and dtype parameters are supported
class ToModule(torch.nn.Module):
def __init__(self, device):
def __init__(self, device, dtype):
super(ToModule, self).__init__()
self.device = device
self.dtype = dtype
def forward(self, x):
return x.to(device)
return x.to(self.device, dtype=self.dtype)
c_node = node.key_node
inputs = list(c_node.inputs())
device = inputs[3].toIValue()
return ToModule(device)
in_debugname = inputs[0].debugName()
# device of the input tensor
device = speedup.internal_result[in_debugname].device
for _node in inputs[1:]:
val = parse_constant(_node, speedup)
if isinstance(val, torch.device):
device = val
dtype = jitid_2_dtype[parse_constant(inputs[1], speedup)]
return ToModule(device, dtype)
def cat_python(node, speedup):
......@@ -566,6 +633,77 @@ def cat_python(node, speedup):
return CatModule(dim)
def ones_python(node, speedup):
class OnesModule(torch.nn.Module):
def __init__(self, out_size, dtype_id, device, require_grad):
super(OnesModule, self).__init__()
self.out_size = out_size
self.device = device
self.require_grad = require_grad
self.dtype = jitid_2_dtype[dtype_id]
def forward(self, *args):
return torch.ones(size=self.out_size, dtype=self.dtype, device=self.device, requires_grad=self.require_grad)
c_node = node.key_node
inputs = list(c_node.inputs())
output_shape = translate_list(inputs[0], speedup)
dtype_id = parse_constant(inputs[1], speedup)
# layout = parse_constant(inputs[2], speedup)
device = parse_constant(inputs[3], speedup)
require_grad = parse_constant(inputs[4], speedup)
return OnesModule(output_shape, dtype_id, device, require_grad)
def zeros_python(node, speedup):
class ZerosModule(torch.nn.Module):
def __init__(self, out_size, dtype_id, device, require_grad):
super(ZerosModule, self).__init__()
self.out_size = out_size
self.device = device
self.require_grad = require_grad
self.dtype = jitid_2_dtype[dtype_id]
def forward(self, *args):
return torch.zeros(size=self.out_size, dtype=self.dtype, device=self.device, requires_grad=self.require_grad)
c_node = node.key_node
inputs = list(c_node.inputs())
output_shape = translate_list(inputs[0], speedup)
dtype_id = parse_constant(inputs[1], speedup)
# layout = parse_constant(inputs[2], speedup)
device = parse_constant(inputs[3], speedup)
require_grad = parse_constant(inputs[4], speedup)
return ZerosModule(output_shape, dtype_id, device, require_grad)
def rsub_python(node, speedup):
c_node = node.key_node
inputs = list(c_node.inputs())
constant = None
other_name = inputs[1].debugName()
alpha = parse_constant(inputs[2], speedup)
if other_name not in speedup.internal_result:
constant = parse_constant(inputs[1], speedup)
if constant is None:
# the other operand is a tensor in the graph, so return torch.sub itself
return torch.sub
else:
new_sub = partial(torch.sub, other=constant, alpha=alpha)
return new_sub
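# What the returned partial computes (constant and alpha values are
# illustrative): torch.sub(x, other=c, alpha=a) evaluates x - a * c
# elementwise.
#
# >>> new_sub = partial(torch.sub, other=torch.tensor(2.0), alpha=3)
# >>> new_sub(torch.tensor([10.0]))
# tensor([4.])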
def expand_python(node, speedup):
class ExpandModule(torch.nn.Module):
def __init__(self, new_size):
super(ExpandModule, self).__init__()
# need deepcopy when the input is size-related
self.new_size = copy.deepcopy(new_size)
def forward(self, *args):
return args[0].expand(self.new_size).clone()
c_node = node.key_node
inputs = list(c_node.inputs())
new_size = translate_list(inputs[1], speedup)
return ExpandModule(new_size)
def expandas_python(node, speedup):
class ExpandasModule(torch.nn.Module):
def forward(self, x, y):
......@@ -616,13 +754,18 @@ trans_from_jit_to_python = {
'aten::exp': exp_python,
'aten::squeeze': squeeze_python,
'aten::unsqueeze': unsqueeze_python,
'aten::constant_pad_nd': constant_pad_nd_python,
'aten::silu': silu_python,
'aten::expand_as': expandas_python,
'aten::ones': ones_python,
'aten::zeros': zeros_python,
'aten::rsub': rsub_python,
'aten::expand': expand_python,
'prim::TupleUnpack': tupleunpack_python,
'prim::ListUnpack': tupleunpack_python,
'prim::NumToTensor': num2tensor_python,
'prim::GetAttr': getattr_python
'prim::GetAttr': getattr_python,
'prim::Constant': constant_python,
'aten::constant_pad_nd': constant_pad_nd_python,
'aten::silu': silu_python,
'aten::expand_as': expandas_python
}
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import re
import logging
import torch
torch_float_dtype = [torch.float, torch.float16, torch.float32, torch.float64, torch.half, torch.double]
torch_integer_dtype = [torch.uint8, torch.int16, torch.short, torch.int32, torch.long, torch.bool]
torch_float_dtype = [torch.float, torch.float16,
torch.float32, torch.float64, torch.half, torch.double]
torch_integer_dtype = [torch.uint8, torch.int16,
torch.short, torch.int32, torch.long, torch.bool]
_logger = logging.getLogger(__name__)
_logger.setLevel(logging.INFO)
def get_module_by_name(model, module_name):
"""
......@@ -46,11 +54,13 @@ def rand_like_with_shape(shape, ori_t):
require_grad = ori_t.requires_grad
lower_bound = torch.min(ori_t)
higher_bound = torch.max(ori_t)
if dtype in torch_integer_dtype:
return torch.randint(int(lower_bound), int(higher_bound)+1, shape, dtype=dtype, device=device)
else:
return torch.rand(shape, dtype=dtype, device=device, requires_grad=require_grad)
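# A minimal usage sketch: generate a tensor with a new shape but the
# dtype/device (and value range, for integer tensors) of a reference
# tensor.
#
# >>> ref = torch.zeros(2, 3)
# >>> rand_like_with_shape((4, 3), ref).shape
# torch.Size([4, 3])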
def randomize_tensor(tensor, start=1, end=100):
"""
Randomize the target tensor according to the given
......@@ -59,11 +69,79 @@ def randomize_tensor(tensor, start=1, end=100):
assert isinstance(tensor, torch.Tensor)
if tensor.dtype in torch_integer_dtype:
# integer tensors can only be randomized by torch.randint
# torch.randint(int(start), int(end), tensor.size(), out=tensor.data, dtype=tensor.dtype)
pass
torch.randint(int(start), int(end), tensor.size(),
out=tensor.data, dtype=tensor.dtype)
else:
# we can use nn.init.uniform_ to randomize this tensor
# Note: tensors with integer dtypes cannot be randomized
# with nn.init.uniform_
torch.nn.init.uniform_(tensor.data, start, end)
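# Usage sketch: integer tensors fall back to torch.randint, float
# tensors to nn.init.uniform_, both in place.
#
# >>> t_int = torch.zeros(4, dtype=torch.long)
# >>> randomize_tensor(t_int, 1, 10)     # values drawn from [1, 10)
# >>> t_float = torch.zeros(4)
# >>> randomize_tensor(t_float, 0.1, 8.0)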
jit_python_code_replacement = {
'torch.slice': lambda tmpstr: python_slice_replace(tmpstr)
}
def translate_jit_code(code):
pattern = r'torch\.(.*?)\('
func_names = re.findall(pattern, code)
modules = {'torch.': torch, 'torch.nn.functional.': torch.nn.functional,
'torch.Tensor.': torch.Tensor, 'torch._C._nn.': torch._C._nn}
replace = {}
# rebase the namespace to get the runnable python code
for full_name in func_names:
func = re.split(r'\.', full_name)[-1]
for module_name in modules:
torch_module = modules[module_name]
if hasattr(torch_module, func):
replace['torch.'+full_name] = module_name + func
break
# assert found == True, 'Cannot find the function call %s' % full_name
for key, value in replace.items():
code = code.replace(key, value)
# several functions cannot be found under the torch.Tensor and torch.
# namespaces (for example torch.slice), so we need to handle these
# functions manually
lines = code.split('\n')
for i, line in enumerate(lines):
for fname in jit_python_code_replacement:
if fname in line:
lines[i] = jit_python_code_replacement[fname](line)
code = '\n'.join(lines)
code = 'import torch\nfrom torch import Tensor, tensor\nfrom typing import *\n' + code
with open('nni_jit_tmp_forward.py', 'w') as f:
f.write(code)
from nni_jit_tmp_forward import forward # pylint: disable=import-error
return forward
def python_slice_replace(funcstr):
"""
Translate a torch.slice call into the equivalent python indexing
string that can replace it in the forward function source.
Parameters
----------
funcstr: str
the string that calls torch.slice, for example:
_8 = torch.slice(attention_mask, 0, 0, 9223372036854775807, 1)
Returns
-------
new_str: str
the string that should replace the original one
"""
# parse the input parameters
pattern = r'torch\.slice\((.*)\)'
parameter_str = re.findall(pattern, funcstr)
parameters = re.split(',', parameter_str[0])
target_tensor = parameters[0]
dim = int(parameters[1])
dim_str = ','.join([':']*(dim) + [':'.join(parameters[2:])])
_logger.debug('%s[%s]', target_tensor, dim_str)
new_str = funcstr.replace(
'torch.slice(%s)' % parameter_str[0], '%s[%s]' % (target_tensor, dim_str))
return new_str
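# Worked example with the call from the docstring:
#
# >>> python_slice_replace(
# ...     '_8 = torch.slice(attention_mask, 0, 0, 9223372036854775807, 1)')
# '_8 = attention_mask[ 0: 9223372036854775807: 1]'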