Unverified Commit 588f299b authored by Louis-J's avatar Louis-J Committed by GitHub
Browse files

feat(speedup): automatically convert op asts to callables (#4996)

parent b2c31ca2
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from __future__ import annotations
from types import ModuleType
from typing import Any, Callable, Dict, List, Optional, Type, Union
from typing import TYPE_CHECKING
if TYPE_CHECKING: # Only imports the below statements during type checking
from nni.compression.pytorch.speedup import ModelSpeedup
from nni.common.graph_utils import NodePyGroup
import re import re
import logging import logging
from functools import partial from functools import partial, lru_cache
import copy import copy
import torch import torch
...@@ -15,31 +24,24 @@ jitid_2_dtype = {4: torch.long, 6:torch.float32} ...@@ -15,31 +24,24 @@ jitid_2_dtype = {4: torch.long, 6:torch.float32}
# to exclude partial # to exclude partial
__all__ = [ __all__ = [
'adaptive_avgpool_python', 'add_python', 'avgpool2d_python', 'cat_python', 'contiguous_python', 'getattr_python', 'jit_to_python_function', 'num2tensor_python', 'parse_constant', 'slice_python',
'div_python', 'dropout_python', 'exp_python', 'flatten_python', 'floor_div_python', 'gelu_python', 'translate_list', 'tupleunpack_python', 'dtype_trans', 'memory_format_trans'
'getattr_python', 'jit_to_python_function', 'matmul_python', 'mean_python',
'mul_python', 'num2tensor_python', 'parse_constant', 'permute_python', 'relu_inplace_python',
'relu_python', 'reshape_python', 'select_python', 'sigmoid_python', 'size_python', 'slice_python',
'softmax_python', 'squeeze_python', 'to_python', 'toint_python', 'torch', 'trans_from_jit_to_python',
'translate_list', 'transpose2_python', 'transpose_python', 'tupleunpack_python', 'typeas_python',
'unsqueeze_python', 'upsample_bilinear2d_python', 'view_python'
] ]
def translate_list(list_node: torch._C.Value, speedup: ModelSpeedup=None) -> List:
def translate_list(list_node, speedup=None):
""" """
Get the list of values from the list construct node. Get the list of values from the list construct node.
Parameters Parameters
---------- ----------
list_node: Torch.C.Value list_node
The cpp node of the target list. The cpp node of the target list.
speedup: ModuleSpeed speedup
The Module speedup module. The Module speedup module.
Returns Returns
------- -------
values: list values
The list of values in the target cpp list node. The list of values in the target cpp list node.
""" """
# the node that create the list # the node that create the list
...@@ -52,27 +54,26 @@ def translate_list(list_node, speedup=None): ...@@ -52,27 +54,26 @@ def translate_list(list_node, speedup=None):
if speedup is not None and debugName in speedup.internal_result: if speedup is not None and debugName in speedup.internal_result:
# this value is the result of the other nodes, such as # this value is the result of the other nodes, such as
# ate::size # ate::size
values.append(speedup.internal_result[debugName].item()) values.append(speedup.internal_result[debugName])
else: else:
# if the corresponding value is a constant # if the corresponding value is a constant
values.append(_i.toIValue()) values.append(_i.toIValue())
return values return values
def parse_constant(cvalue: torch._C.Value, speedup: ModelSpeedup) -> Any:
def parse_constant(cvalue, speedup):
""" """
Parse the constant values from this Node Parse the constant values from this Node
Parameters Parameters
---------- ----------
cvalue: Torch.C.Value cvalue
The cpp node of the target constant value. The cpp node of the target constant value.
speedup: ModelSpeedup speedup
The Model speedup module. The Model speedup module.
Returns Returns
------- -------
value: int/float/tensor value
The constant values parsed from the node. The constant values parsed from the node.
""" """
logger.debug('Try to parse the constant value: %s', cvalue.debugName()) logger.debug('Try to parse the constant value: %s', cvalue.debugName())
...@@ -85,245 +86,13 @@ def parse_constant(cvalue, speedup): ...@@ -85,245 +86,13 @@ def parse_constant(cvalue, speedup):
inputs = op_node.inputs() inputs = op_node.inputs()
input_values = [parse_constant(_i, speedup) for _i in inputs] input_values = [parse_constant(_i, speedup) for _i in inputs]
func = trans_from_jit_to_python[op_node.kind()](op_node, speedup) if op_node.kind() not in trans_func_dict:
return func(*input_values) raise RuntimeError('Unsupported function op node type: {}'.format(op_node.kind()))
def dropout_python(node, speedup):
    """Translate ``aten::dropout`` to its python counterpart."""
    return torch.dropout
def flatten_python(node, speedup):
    """Build a ``torch.flatten`` callable with start/end dims read from the jit node."""
    cpp_inputs = list(node.key_node.inputs())
    return partial(
        torch.flatten,
        start_dim=cpp_inputs[1].toIValue(),
        end_dim=cpp_inputs[2].toIValue(),
    )
def relu_inplace_python(node, speedup):
    """Translate ``aten::relu_`` to the in-place ``torch.relu_``."""
    return torch.relu_
def relu_python(node, speedup):
    """Translate ``aten::relu`` to ``torch.relu``."""
    return torch.relu
def sigmoid_python(node, speedup):
    """Translate ``aten::sigmoid`` to ``torch.sigmoid``."""
    return torch.sigmoid
def mean_python(node, speedup):
    """Build a ``torch.mean`` callable; the dim list and keepdim flag are parsed from the jit node."""
    cpp_inputs = list(node.key_node.inputs())
    dims = tuple(translate_list(cpp_inputs[1], speedup))
    keepdim_flag = cpp_inputs[2].toIValue()
    return partial(torch.mean, dim=dims, keepdim=keepdim_flag)
def add_python(node, speedup):
    """
    Translate ``aten::add``/``aten::add_``.

    If one of the two operands is a graph constant it is bound into the
    returned callable; otherwise plain ``torch.add`` is returned.
    """
    cpp_inputs = list(node.key_node.inputs())
    const_operand = None
    for candidate in cpp_inputs[:2]:
        if candidate.debugName() not in speedup.internal_result:
            # this input is a constant value
            # TODO: what if this input is a constant tensor
            if candidate.toIValue() is not None:
                const_operand = parse_constant(candidate, speedup)
                break
    if const_operand is None:
        return torch.add
    return partial(torch.add, const_operand)
def sub_python(node, speedup):
    """
    Translate ``aten::sub``/``aten::sub_``.

    One of the two operands may be a graph constant; in that case the
    constant is bound into the returned callable, so the caller only has
    to supply the remaining (non-constant) operand at call time.
    """
    c_node = node.key_node
    inputs = list(c_node.inputs())
    constant = [None, None]
    for i in range(2):
        input_i = inputs[i]
        debug_name = input_i.debugName()
        if debug_name not in speedup.internal_result:
            # this input is a constant value
            # TODO: what if this input is a constant tensor
            if input_i.toIValue() is not None:
                constant[i] = parse_constant(input_i, speedup)
                break
    if constant[0] is None and constant[1] is None:
        new_sub = torch.sub
    elif constant[0] is not None:
        # bug fix: bind the parsed element (constant[0]), not the whole
        # ``constant`` list, and bind it positionally — ``input=`` would
        # collide with the positional runtime operand at call time
        new_sub = partial(torch.sub, constant[0])
    else:
        # bug fix: bind constant[1] (the parsed element) as ``other``
        new_sub = partial(torch.sub, other=constant[1])
    return new_sub
def floor_div_python(node, speedup):
    """
    Translate ``aten::floor_divide``; a constant divisor is bound into
    the returned callable as ``other``.
    """
    divisor = list(node.key_node.inputs())[1]
    const_divisor = None
    if divisor.debugName() not in speedup.internal_result:
        # divisor is a constant value/tensor
        const_divisor = parse_constant(divisor, speedup)
    if const_divisor is None:
        return torch.floor_divide
    return partial(torch.floor_divide, other=const_divisor)
def mul_python(node, speedup):
    """Translate ``aten::mul``/``aten::mul_``; a constant operand is bound into the callable."""
    cpp_inputs = list(node.key_node.inputs())
    const_operand = None
    for candidate in cpp_inputs[:2]:
        if candidate.debugName() not in speedup.internal_result:
            const_operand = parse_constant(candidate, speedup)
            # both two inputs cannot be constants at the same time
            break
    if const_operand is None:
        return torch.mul
    return partial(torch.mul, const_operand)
def transpose_python(node, speedup):
    """Translate ``aten::t`` to ``torch.t``."""
    return torch.t
def transpose2_python(node, speedup):
    """Build a ``torch.transpose`` callable with both dims read from the jit node."""
    cpp_inputs = list(node.key_node.inputs())
    return partial(
        torch.transpose,
        dim0=cpp_inputs[1].toIValue(),
        dim1=cpp_inputs[2].toIValue(),
    )
def matmul_python(node, speedup):
    """Translate ``aten::matmul`` to ``torch.matmul``."""
    return torch.matmul
def div_python(node, speedup):
    """
    Translate ``aten::div``. The second operand is either produced by
    another node (return plain ``torch.div``) or a graph constant
    (bind it into the callable as ``other``).
    """
    second = list(node.key_node.inputs())[1]
    if second.debugName() in speedup.internal_result:
        # the second operand is the output of another node
        return torch.div
    return partial(torch.div, other=second.toIValue())
def softmax_python(node, speedup):
    """Build a ``torch.softmax`` callable with the dim read from the jit node."""
    cpp_inputs = list(node.key_node.inputs())
    return partial(torch.softmax, dim=cpp_inputs[1].toIValue())
def contiguous_python(node, speedup):
    """Return a module that makes its input contiguous and clones the result."""
    class _ContiguousModule(torch.nn.Module):
        def forward(self, tensor):
            return tensor.contiguous().clone()
    return _ContiguousModule()
def gelu_python(node, speedup):
    """Translate ``aten::gelu`` to a ``torch.nn.GELU`` module."""
    return torch.nn.GELU()
def silu_python(node, speedup):
    """Translate ``aten::silu`` to a ``torch.nn.SiLU`` module."""
    return torch.nn.SiLU()
def avgpool2d_python(node, speedup):
    """Build an ``avg_pool2d`` callable; kernel/stride/padding come from the jit node."""
    cpp_inputs = list(node.key_node.inputs())
    return partial(
        torch.nn.functional.avg_pool2d,
        kernel_size=translate_list(cpp_inputs[1], speedup),
        stride=translate_list(cpp_inputs[2], speedup),
        padding=translate_list(cpp_inputs[3], speedup),
    )
def adaptive_avgpool_python(node, speedup):
    """Build an ``AdaptiveAvgPool2d`` module with the output size parsed from the jit node."""
    cpp_inputs = list(node.key_node.inputs())
    target_size = translate_list(cpp_inputs[1], speedup)
    return torch.nn.AdaptiveAvgPool2d(target_size)
def tupleunpack_python(node, speedup):
    # tuple unpack should only exist at the end of the model,
    # so there is no need to replace/propagate the mask
    return None
def num2tensor_python(node, speedup):
    """``prim::NumToTensor`` is represented as an identity module."""
    return torch.nn.Identity()
def exp_python(node, speedup):
    """Translate ``aten::exp`` to ``torch.exp``."""
    return torch.exp
def squeeze_python(node, speedup):
    """Build a ``torch.squeeze`` callable; the dim argument is optional in the jit node."""
    cpp_inputs = list(node.key_node.inputs())
    dim = parse_constant(cpp_inputs[1], speedup) if len(cpp_inputs) > 1 else None
    return partial(torch.squeeze, dim=dim)
def unsqueeze_python(node, speedup):
    """Build a ``torch.unsqueeze`` callable with the dim parsed from the jit node."""
    cpp_inputs = list(node.key_node.inputs())
    return partial(torch.unsqueeze, dim=parse_constant(cpp_inputs[1], speedup))
def constant_pad_nd_python(node, speedup):
    """Build an ``F.pad`` callable; pad widths and the fill value come from the jit node."""
    cpp_inputs = list(node.key_node.inputs())
    return partial(
        torch.nn.functional.pad,
        pad=translate_list(cpp_inputs[1], speedup),
        value=parse_constant(cpp_inputs[2], speedup),
    )
##########################################################
# Split Line
# The following modules/functions cannot be translated into a
# single function, so we use torch.nn.Module to wrap the
# core function, and return the torch.nn.Module instead
##########################################################
func = trans_func_dict[op_node.kind()](op_node, speedup)
return func(*input_values)
def slice_python(node, speedup): def slice_python(node: NodePyGroup, speedup: ModelSpeedup):
class SliceMoudle(torch.nn.Module): class SliceMoudle(torch.nn.Module):
def __init__(self, sliceobj): def __init__(self, sliceobj):
super(SliceMoudle, self).__init__() super(SliceMoudle, self).__init__()
...@@ -368,102 +137,38 @@ def slice_python(node, speedup): ...@@ -368,102 +137,38 @@ def slice_python(node, speedup):
else: else:
return SliceMoudle(tuple(slice_list)) return SliceMoudle(tuple(slice_list))
def cat_python(node: NodePyGroup, speedup: ModelSpeedup):
def select_python(node, speedup): class CatModule(torch.nn.Module):
class SelectModule(torch.nn.Module): def __init__(self, cat_dim):
def __init__(self, dim, index): super(CatModule, self).__init__()
super(SelectModule, self).__init__() self.cat_dim = cat_dim
self.dim = copy.deepcopy(dim)
self.index = copy.deepcopy(index)
def forward(self, x):
return x.select(self.dim, self.index)
c_node = node.key_node
inputs = list(c_node.inputs())
dim = inputs[1].toIValue()
index = inputs[2].toIValue()
return SelectModule(dim, index)
def size_python(node, speedup):
    """Return a module wrapping ``x.size(dim)`` into a 1-element long tensor."""
    class _SizeModule(torch.nn.Module):
        def __init__(self, dim):
            super(_SizeModule, self).__init__()
            self.sizedim = dim

        def forward(self, x):
            # keep the result as a tensor so it can flow through the graph
            return torch.as_tensor([x.size(self.sizedim)], dtype=torch.long)

    dim = list(node.key_node.inputs())[1].toIValue()
    return _SizeModule(dim)
def toint_python(node, speedup):
    """Return a module casting its input to ``torch.int``."""
    class _ToIntModule(torch.nn.Module):
        def forward(self, x):
            return x.to(torch.int)
    return _ToIntModule()
def view_python(node, speedup):
class ViewModule(torch.nn.Module):
def __init__(self, shape):
super(ViewModule, self).__init__()
self.shape = shape
logger.info('View Module output size: %s', str(self.shape))
def forward(self, *args): def forward(self, *args):
return args[0].view(self.shape) return torch.cat(args, dim=self.cat_dim)
c_node = node.key_node
inputs = list(c_node.inputs())
shape = translate_list(inputs[1], speedup)
return ViewModule(shape)
def reshape_python(node, speedup):
class ReshapeModule(torch.nn.Module):
def __init__(self, shape):
super(ReshapeModule, self).__init__()
self.shape = shape
logger.info('Reshape Module output size: %s', str(self.shape))
def forward(self, *args):
return args[0].reshape(self.shape)
c_node = node.key_node c_node = node.key_node
inputs = list(c_node.inputs()) inputs = list(c_node.inputs())
shape = translate_list(inputs[1], speedup) dim = inputs[1].toIValue()
return ReshapeModule(shape) return CatModule(dim)
def permute_python(node, speedup):
class PermuteModule(torch.nn.Module):
def __init__(self, dimlist):
super(PermuteModule, self).__init__()
# deepcopy the values here, because the following randomize operation
# will change the value of the dimlist
self.dimlist = copy.deepcopy(dimlist)
def forward(self, x): def tupleunpack_python(_node: NodePyGroup, _speedup: ModelSpeedup) -> Optional[Callable]:
return x.permute(self.dimlist) # Note: tuple unpack should only exists at the
c_node = node.key_node # the end of the model, and is no need to replace/propagate mask
inputs = list(c_node.inputs()) return None
dim_list = translate_list(inputs[1], speedup)
return PermuteModule(dim_list)
def num2tensor_python(_node: NodePyGroup, _speedup: ModelSpeedup):
return torch.nn.Identity()
def getattr_python(node, speedup): def getattr_python(node: NodePyGroup, _speedup: ModelSpeedup):
""" """
Note: Ops started with Prim:: is not taken as the key node, Note: Ops started with Prim:: is not taken as the key node,
so we directly pass the Cpp node into this funciton. so we directly pass the Cpp node into this funciton.
Parameters Parameters
---------- ----------
node: torch._C.Node node
The cpp node of prim::Getattr The cpp node of prim::Getattr
speedup: ModelSpeedup speedup
The corresponding speedup object. The corresponding speedup object.
""" """
class GetModule(torch.nn.Module): class GetModule(torch.nn.Module):
...@@ -482,316 +187,332 @@ def getattr_python(node, speedup): ...@@ -482,316 +187,332 @@ def getattr_python(node, speedup):
assert len(key_words) == 1 assert len(key_words) == 1
return GetModule(key_words[0]) return GetModule(key_words[0])
def constant_python(node, speedup): class FuncAdapter:
""" """
get the constant value of constant operator node. A function adapter which can reorder arguments.
It can be initialate with constant argument, and positions of each non-constant
argument. When called, it can put arguments into correct position, then call the
function.
Parameters Attributes
---------- ----------
node: torch._C.Node func
The cpp node of prim::Getattr The function or method to be called.
speedup: ModelSpeedup positional
The corresponding speedup object. Positional arguments values. The placeholder is None if it's non-constant.
""" keyword
class ConstantModule(torch.nn.Module): Keyword arguments values. The placeholder is None if it's non-constant.
def __init__(self, constant): undetermined
super(ConstantModule, self).__init__() A list of the right positions of arguments.
self.constant = constant Position is an int in positional or a str in keyword.
def forward(self): special_treat
return self.constant A Dict of the positions and methods.
The values of these positions should be treat by those methods.
assert node.kind() == 'prim::Constant'
pattern = '\[value=(.*?)\]'
key_words = re.findall(pattern, str(node))
if len(key_words) == 0:
return ConstantModule(None)
assert len(key_words) == 1
# parse the constant value
value = key_words[0]
if value.startswith("\""):
value = torch.device(value[1:-1])
elif value.startswith('{'):
# TODO Support set values in the future
value = set()
elif '.' in value:
# float value
value = float(value)
else:
# integer value
value = int(value)
return ConstantModule(value)
def upsample_bilinear2d_python(node, speedup):
class UpsampleModule(torch.nn.Module):
def __init__(self, size_list, scale_list):
super(UpsampleModule, self).__init__()
self.size_list = size_list
self.scale_list = scale_list
def forward(self, *args):
"""
The first input of args is the target tensor to upsample
, the following parameters is useless, because we already
get the size_list and the scale_list by parsing the cpp_nodes.
"""
return torch.nn.functional.upsample_bilinear(args[0],
size=self.size_list, scale_factor=self.scale_list)
c_node = node.key_node
inputs = list(c_node.inputs())
size_list_node = inputs[1].node()
scale_list_node = inputs[3].node()
size_list = None
scale_list = None
if size_list_node.kind() == 'prim::ListConstruct': """
size_list = translate_list(inputs[1], speedup)
if scale_list_node.kind() == 'prim::ListConstruct':
scale_list = translate_list(inputs[3], speedup)
return UpsampleModule(size_list, scale_list)
def __init__(self, func: Callable, positional: List[Any], keyword: Dict[str, Any],
undetermined: List[Union[int, str]], special_treat: Dict[Union[int, str], Callable]):
if not callable(func):
raise TypeError('the "func" argument must be callable')
self.func = func
self.positional = positional
self.keyword = keyword
self.undetermined = undetermined
self.special_treat = special_treat
def __call__(self, /, *args):
assert len(args) >= len(self.undetermined)
if len(args) > len(self.undetermined):
logger.warning('throw some args away when calling the function "%s"', self.func.__name__)
for i, p in enumerate(self.undetermined):
v = args[i]
if isinstance(p, int):
self.positional[p] = v
else:
self.keyword[p] = v
for p, fs in self.special_treat.items():
if isinstance(p, int):
for f in fs:
self.positional[p] = f(self.positional[p])
else:
for f in fs:
self.keyword[p] = f(self.keyword[p])
result = self.func(*self.positional, **self.keyword)
if isinstance(result, int): # turn result of 'size' into tensor
result = torch.as_tensor([result], dtype=torch.long)
return result
# There are some types that will be convert into enums after jit.
# So we should recover them back:
# device, dtype, layout, memory_format, qscheme, qengine, dispatchkey
enum_to_dtype_names = {
0: 'uint8',
1: 'int8',
2: 'int16',
3: 'int32',
4: 'int64',
5: 'float16',
6: 'float32',
7: 'float64',
8: 'complex32',
9: 'complex64',
10: 'complex128',
11: 'bool',
12: 'qint8',
13: 'quint8',
14: 'qint32',
15: 'bfloat16',
16: 'quint4x2',
17: 'quint2x4',
}
def upsample_nearest2d_python(node, speedup): enum_to_dtype_dict = {}
class UpsampleModule(torch.nn.Module):
def __init__(self, size_list, scale_list):
super(UpsampleModule, self).__init__()
self.size_list = size_list
self.scale_list = scale_list
def forward(self, *args): for enum_value, dtype_name in enum_to_dtype_names.items():
""" if hasattr(torch, dtype_name):
The first input of args is the target tensor to upsample enum_to_dtype_dict[enum_value] = getattr(torch, dtype_name)
, the following parameters is useless, because we already
get the size_list and the scale_list by parsing the cpp_nodes.
"""
return torch.nn.functional.upsample_nearest(args[0],
size=self.size_list, scale_factor=self.scale_list)
c_node = node.key_node
inputs = list(c_node.inputs())
size_list_node = inputs[1].node()
scale_list_node = inputs[2].node()
size_list = None
scale_list = None
if size_list_node.kind() == 'prim::ListConstruct': def dtype_trans(ivalue: Union[int, torch.dtype]):
size_list = translate_list(inputs[1], speedup) """
if scale_list_node.kind() == 'prim::ListConstruct': Special process for dtype.
scale_list = translate_list(inputs[2], speedup) Torch will transform dtype to an enum in cpp, so the value of dtype we get in jit is an int.
return UpsampleModule(size_list, scale_list) This function is used to recover the int to torch.dtype in python.
Parameters
----------
ivalue
The value of dtype or method to be recovered.
def typeas_python(node, speedup):
"""
currently only support type_as float.
TODO: support more types in the type_as, need to figure out
how to get the scalar type from torch._C.TensorType.
""" """
class TypeasModule(torch.nn.Module): if ivalue is None or isinstance(ivalue, torch.dtype):
def __init__(self, dtype=torch.float): return ivalue
self.example = torch.zeros(1, dtype=dtype) elif isinstance(ivalue, int):
if ivalue in enum_to_dtype_dict:
return enum_to_dtype_dict[ivalue]
raise TypeError('No torch.dtype corresponding to the value "%s"', ivalue)
enum_to_memory_format_dict = {
0: torch.contiguous_format,
1: torch.preserve_format,
2: torch.channels_last,
3: torch.channels_last_3d,
}
def forward(self, x): def memory_format_trans(ivalue: Union[int, torch.memory_format]):
return x.type_as(self.example) """
return TypeasModule() Special process for memory_format.
Torch will transform memory_format to an enum in cpp, so the value of memory_format we get in jit is an int.
This function is used to recover the int to torch.memory_format in python.
Parameters
----------
ivalue
The value of memory_format or method to be recovered.
def to_python(node, speedup): """
# for the time being, only device parameters are supported if ivalue is None or isinstance(ivalue, torch.memory_format):
class ToModule(torch.nn.Module): return ivalue
def __init__(self, device, dtype): elif isinstance(ivalue, int):
super(ToModule, self).__init__() global enum_to_memory_format_dict
self.device = device if ivalue in enum_to_memory_format_dict:
self.dtype = dtype return enum_to_memory_format_dict[ivalue]
def forward(self, x): raise TypeError('No torch.memory_format corresponding to the value "%s"', ivalue)
return x.to(device, dtype=self.dtype)
special_treat_dict = {
'dtype': dtype_trans,
'memory_format': memory_format_trans,
}
c_node = node.key_node schema_fix_dict = {
inputs = list(c_node.inputs()) # functinon 'to', 'randint', and 'sparse_coo_tensor' has different schema between python and c++.
in_debugname = inputs[0].debugName() # https://pytorch.org/docs/stable/jit_unsupported.html#ops-with-divergent-schemas-between-torch-python
# device of the input tensor """aten::to.device(Tensor(a) self, Device device, int dtype, bool non_blocking=False, bool copy=False, int? memory_format=None) -> (Ten
device = speedup.internal_result[in_debugname].device sor(a))""":
"""aten::to.device(Tensor(a) self, Device device, int dtype, bool non_blocking=False, bool copy=False, *, int? memory_format=None)
-> (Tensor(a))""",
'aten::to.dtype(Tensor(a) self, int dtype, bool non_blocking=False, bool copy=False, int? memory_format=None) -> (Tensor(a))':
'aten::to.dtype(Tensor(a) self, int dtype, bool non_blocking=False, bool copy=False, *, int? memory_format=None) -> (Tensor(a))',
'aten::to.other(Tensor(a) self, Tensor other, bool non_blocking=False, bool copy=False, int? memory_format=None) -> (Tensor(a))':
'aten::to.other(Tensor(a) self, Tensor other, bool non_blocking=False, bool copy=False, *, int? memory_format=None) -> (Tensor(a))',
# todo: are the arguments 'pin_memory' and 'requires_grad' related?
# functions in the python have only 'requires_grad' and functions in the aten have only 'pin_memory'
# 'aten::randint(int high, int[] size, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None) -> (Tensor)',
# """aten::randint.generator(int high, int[] size, *, Generator? generator, int? dtype=None, int? layout=None, Device? device=None, boo
# l? pin_memory=None) -> (Tensor)""",
# """aten::randint.low(int low, int high, int[] size, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None)
# -> (Tensor)""",
# """aten::randint.low_generator(int low, int high, int[] size, *, Generator? generator, int? dtype=None, int? layout=None, Device? dev
# ice=None, bool? pin_memory=None) -> (Tensor)""",
# """aten::sparse_coo_tensor.size(int[] size, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=False) -> (Te
# nsor)""",
# """aten::sparse_coo_tensor.indices(Tensor indices, Tensor values, *, int? dtype=None, int? layout=None, Device? device=None, bool? pi
# n_memory=None) -> (Tensor)""",
# """aten::sparse_coo_tensor.indices_size(Tensor indices, Tensor values, int[] size, *, int? dtype=None, int? layout=None, Device? devi
# ce=None, bool? pin_memory=None) -> (Tensor"""'
}
@lru_cache(maxsize=256)
def parse_aten_schema(schema: str):
"""
Parse the schema, to positional_num and keyword_list, and detect if the argument should be specially treated.
"""
if schema in schema_fix_dict:
schema = schema_fix_dict[schema]
for _, _node in enumerate(inputs[1:]): positional_num = 0
val = parse_constant(_node, speedup) keyword_list = list()
if isinstance(val, torch.device): special_treat = dict() # for dtype and memory_format trans now
device = val
dtype = jitid_2_dtype[parse_constant(inputs[1], speedup)]
return ToModule(device, dtype)
for arg in torch._C.parse_schema(schema).arguments:
if not arg.kwarg_only:
key = positional_num
positional_num += 1
else:
key = arg.name
keyword_list.append(key)
def cat_python(node, speedup): if arg.name in special_treat_dict:
class CatModule(torch.nn.Module): if key not in special_treat:
def __init__(self, cat_dim): special_treat[key] = [special_treat_dict[arg.name]]
super(CatModule, self).__init__() else:
self.cat_dim = cat_dim special_treat[key].append(special_treat_dict[arg.name])
def forward(self, *args): return positional_num, keyword_list, special_treat
return torch.cat(args, dim=self.cat_dim)
c_node = node.key_node def parse_input_value(speedup: ModelSpeedup, input_nodes: List[torch._C.Node], positional_num: int, keyword_list: List[str]):
inputs = list(c_node.inputs()) """
dim = inputs[1].toIValue() translate inputs, to constant positional arguments, constant keyword arguments, and undetermined positions
return CatModule(dim) """
positional = list()
keyword = dict()
undetermined = list()
for ainput in input_nodes:
if ainput.node().kind() == 'prim::ListConstruct':
arg = translate_list(ainput, speedup)
elif ainput.node().kind() == 'prim::Constant':
arg = ainput.toIValue()
else:
assert 'aten::' in ainput.node().kind() or 'prim::' in ainput.node().kind()
if len(positional) < positional_num:
undetermined.append(len(positional))
else:
undetermined.append(keyword_list[positional_num - len(positional)])
arg = None
if len(positional) < positional_num:
positional.append(arg)
else:
keyword[keyword_list[positional_num - len(positional)]] = arg
return positional, keyword, undetermined
def special_treat_to_constant_value(positional: List, keyword: Dict[str], undetermined: List[Union[int, str]],
special_treat: Dict[Union[int, str], Callable]):
"""
if any argument with special_treat is not in undetermined, do the treat
"""
undetermined_special_treat = dict()
for p, fs in special_treat.items():
if p in undetermined:
undetermined_special_treat[p] = fs
elif isinstance(p, int):
for f in fs: positional[p] = f(positional[p])
else:
for f in fs: keyword[p] = f(keyword[p])
return undetermined_special_treat
def ones_python(node, speedup): def generate_aten_to_python(func: Callable, node: NodePyGroup, speedup: ModelSpeedup) -> FuncAdapter:
class OnesModule(torch.nn.Module): """
def __init__(self, out_size, dtype_id, device, require_grad): parse a Return a callable object to inference the mask according to the node.op_type.
super(OnesModule, self).__init__()
self.out_size = out_size
self.device = device
self.require_grad = require_grad
self.dtype = jitid_2_dtype[dtype_id]
def forward(self, *args): Parameters
return torch.ones(size=self.out_size, dtype=self.dtype, device=self.device, requires_grad=self.require_grad) ---------
func
The torch function one-to-one correspondence with the node.
node
The target node to inference the mask
speedup
The speedup object of the target model.
Returns
------
func
Return the translated function that used to inference the mask
, if current op_type is not supported, then we return None.
"""
c_node = node.key_node c_node = node.key_node
inputs = list(c_node.inputs())
output_shape = translate_list(inputs[0], speedup)
dtype_id = parse_constant(inputs[1], speedup)
# layout = parse_constant(inputs[2], speedup)
device = parse_constant(inputs[3], speedup)
require_grad = parse_constant(inputs[4], speedup)
return OnesModule(output_shape, dtype_id, device, require_grad)
def zeros_python(node, speedup):
class ZerosModule(torch.nn.Module):
def __init__(self, out_size, dtype_id, device, require_grad):
super(ZerosModule, self).__init__()
self.out_size = out_size
self.device = device
self.require_grad = require_grad
self.dtype = jitid_2_dtype[dtype_id]
def forward(self, *args): schema = c_node.schema()
return torch.zeros(size=self.out_size, dtype=self.dtype, device=self.device, requires_grad=self.require_grad) positional_num, keyword_list, special_treat = parse_aten_schema(schema)
c_node = node.key_node input_nodes = list(c_node.inputs())
inputs = list(c_node.inputs()) positional, keyword, undetermined = parse_input_value(speedup, input_nodes, positional_num, keyword_list)
output_shape = translate_list(inputs[0], speedup)
dtype_id = parse_constant(inputs[1], speedup)
# layout = parse_constant(inputs[2], speedup)
device = parse_constant(inputs[3], speedup)
require_grad = parse_constant(inputs[4], speedup)
return ZerosModule(output_shape, dtype_id, device, require_grad)
def rsub_python(node, speedup):
    """
    Translate ``aten::rsub``. If the second operand is a graph constant,
    bind it (together with ``alpha``) into the returned callable;
    otherwise return plain ``torch.sub``.
    """
    c_node = node.key_node
    inputs = list(c_node.inputs())
    constant = None
    other_name = inputs[1].debugName()
    alpha = parse_constant(inputs[2], speedup)
    if other_name not in speedup.internal_result:
        constant = parse_constant(inputs[1], speedup)
    if constant is None:
        # bug fix: the original code did ``return torch.sub()`` which
        # invokes torch.sub with no arguments (TypeError) instead of
        # returning the function itself
        return torch.sub
    else:
        new_sub = partial(torch.sub, other=constant, alpha=alpha)
        return new_sub
def expand_python(node, speedup): undetermined_special_treat = special_treat_to_constant_value(positional, keyword, undetermined, special_treat)
class ExpandModule(torch.nn.Module):
def __init__(self, new_size):
super(ExpandModule, self).__init__()
# need deepcopy when the input is size-related
self.new_size = copy.deepcopy(new_size)
def forward(self, *args): return FuncAdapter(func, positional, keyword, undetermined, undetermined_special_treat)
return args[0].expand(self.new_size).clone()
c_node = node.key_node trans_func_dict = {
inputs = list(c_node.inputs())
new_size = translate_list(inputs[1], speedup)
return ExpandModule(new_size)
def expandas_python(node, speedup):
    """Return a module computing ``x.expand_as(y).clone()``."""
    class _ExpandAsModule(torch.nn.Module):
        def forward(self, x, y):
            return x.expand_as(y).clone()
    return _ExpandAsModule()
trans_from_jit_to_python = {
'aten::add': add_python,
'aten::add_': add_python,
'aten::sub': sub_python,
'aten::sub_': sub_python,
'aten::mul': mul_python,
'aten::mul_': mul_python,
'aten::relu': relu_python,
'aten::relu_': relu_inplace_python,
'aten::sigmoid': sigmoid_python,
'aten::sigmoid_': sigmoid_python,
# tanh behaives like relu
'aten::tanh': relu_python,
'aten::tanh_': relu_python,
'aten::flatten': flatten_python,
'aten::mean': mean_python,
'aten::dropout': dropout_python,
'aten::slice': slice_python, 'aten::slice': slice_python,
'aten::select': select_python,
'aten::size': size_python,
'aten::t': transpose_python,
'aten::transpose': transpose2_python,
'aten::Int': toint_python,
'aten::view': view_python,
'aten::reshape': reshape_python,
'aten::permute': permute_python,
'aten::matmul': matmul_python,
'aten::div': div_python,
'aten::floor_divide': floor_div_python,
'aten::softmax': softmax_python,
'aten::contiguous': contiguous_python,
'aten::gelu': gelu_python,
'aten::cat': cat_python, 'aten::cat': cat_python,
'aten::avg_pool2d': avgpool2d_python, 'aten::Int': partial(generate_aten_to_python, torch._C._TensorBase.int),
'aten::max_pool2d': avgpool2d_python,
'aten::adaptive_avg_pool2d': adaptive_avgpool_python,
'aten::to': to_python,
'aten::type_as': typeas_python,
'aten::upsample_bilinear2d': upsample_bilinear2d_python,
'aten::upsample_nearest2d': upsample_nearest2d_python,
'aten::exp': exp_python,
'aten::squeeze': squeeze_python,
'aten::unsqueeze': unsqueeze_python,
'aten::ones': ones_python,
'aten::zeros': zeros_python,
'aten::rsub': rsub_python,
'aten::expand': expand_python,
'prim::TupleUnpack': tupleunpack_python, 'prim::TupleUnpack': tupleunpack_python,
'prim::ListUnpack': tupleunpack_python, 'prim::ListUnpack': tupleunpack_python,
'prim::NumToTensor': num2tensor_python, 'prim::NumToTensor': num2tensor_python,
'prim::GetAttr': getattr_python, 'prim::GetAttr': getattr_python,
'prim::Constant': constant_python,
'aten::constant_pad_nd': constant_pad_nd_python,
'aten::silu': silu_python,
'aten::expand_as': expandas_python
} }
def init_add_functions(func_from: Union[ModuleType, Type[Any]]):
    """
    Add function/method attributes from a module/class to the ``trans_func_dict``.

    Every public callable ``f`` found on ``func_from`` is registered under the
    key ``'aten::<name>'`` as ``partial(generate_aten_to_python, f)``.  Entries
    already present in ``trans_func_dict`` (the hand-written translations) take
    precedence over the auto-generated ones.

    Parameters
    ----------
    func_from
        The module/class that provides the needed functions.
    """
    global trans_func_dict
    new_trans_func_dict = dict()
    for name in dir(func_from):
        attr = getattr(func_from, name)
        # Skip dunder names; register every other callable as an aten op translation.
        if callable(attr) and not name.startswith('__'):
            new_trans_func_dict['aten::' + name] = partial(generate_aten_to_python, attr)
    # Merge with the existing dict last so explicit translations win on key clashes.
    trans_func_dict = {**new_trans_func_dict, **trans_func_dict}

# Register the bulk of the aten ops from torch's internal function namespaces.
init_add_functions(torch._C._VariableFunctions)
init_add_functions(torch._C._nn)
init_add_functions(torch._C._TensorBase)
def jit_to_python_function(node: NodePyGroup, speedup: ModelSpeedup) -> Optional[FuncAdapter]:
    """
    Return a callable object to inference the mask according to the node.op_type.

    Parameters
    ----------
    node
        The target node to inference the mask.
    speedup
        The speedup object of the target model.

    Returns
    -------
    func
        The translated function used to inference the mask; ``None`` if the
        current op_type has no entry in ``trans_func_dict`` (the caller is
        expected to skip mask inference for that node).
    """
    logger.debug(
        'Translate C function %s into its python version', node.op_type)
    if node.op_type not in trans_func_dict:
        logger.error(
            '%s is not Supported! Please report an issue at https://github.com/microsoft/nni. Thanks~', node.op_type)
        # return None to skip the mask inference for this node
        return None
    return trans_func_dict[node.op_type](node, speedup)
...@@ -61,7 +61,7 @@ class BackboneModel2(torch.nn.Module): ...@@ -61,7 +61,7 @@ class BackboneModel2(torch.nn.Module):
x = F.max_pool2d(x, 2, 2) x = F.max_pool2d(x, 2, 2)
x = F.relu(self.bn2(self.conv2(x))) x = F.relu(self.bn2(self.conv2(x)))
x = F.max_pool2d(x, 2, 2) x = F.max_pool2d(x, 2, 2)
x = x.view(x.size(0), -1) x = x.reshape(x.size(0), -1)
x = F.relu(self.fc1(x)) x = F.relu(self.fc1(x))
x = self.fc2(x) x = self.fc2(x)
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import unittest
import torch
import torch.nn.functional as F
from nni.compression.pytorch.pruning import L1NormPruner
from nni.compression.pytorch.speedup import ModelSpeedup
from nni.algorithms.compression.v2.pytorch.utils import (
compute_sparsity_compact2origin,
compute_sparsity_mask2compact
)
class CondModel(torch.nn.Module):
    """Minimal module whose forward alternates between two branches.

    Exercises ``prim::If`` in the scripted graph: the branch taken flips on
    every call via the mutable ``the_cond`` flag.
    """
    # Annotated so torch.jit.script treats it as a mutable bool attribute.
    the_cond: bool

    def __init__(self):
        super().__init__()
        self.the_cond = True

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.the_cond:
            out = x + 0.00001
        else:
            out = x - 0.00001
        # Flip so the next call exercises the other branch.
        self.the_cond = not self.the_cond
        return out
class ASubModel(torch.nn.Module):
    """Trivial submodule (adds a small constant); exercises sub-model traversal."""

    def __init__(self):
        super().__init__()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Same arithmetic as an inline `aten::add`, but routed through a child module.
        return x + 0.00001
class TorchModel1(torch.nn.Module):
    """
    Model exercising a wide range of aten ops so each one's speedup
    translation is covered:
        add, sub, mul, div, exp, matmul,
        relu, gelu, tanh, silu, sigmoid, softmax,
        size, unsqueeze, flatten, cat, slice, reshape, transpose, t, select, permute, constant_pad_nd,
        mean, avg_pool2d, max_pool2d, sum, adaptive_avg_pool2d,
        to, Int, view,
        type_as, expand_as, contiguous

    notes:
        'floor_divide' has no backward, so it is not tested
    """
    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(1, 6, 5, 1)
        self.conv2 = torch.nn.Conv2d(6, 16, 5, 1)
        self.fccond = torch.nn.Linear(16 * 4 * 4, 16 * 4 * 4)
        self.fc1 = torch.nn.Linear(16 * 4 * 4, 120)
        self.fc2 = torch.nn.Linear(120, 84)
        self.fc3 = torch.nn.Linear(84, 10)
        self.pool1 = torch.nn.MaxPool2d((2, 2))
        self.pool2 = torch.nn.MaxPool2d((2, 2))
        # Scripted submodule whose graph contains prim::If; its call in
        # forward() is currently commented out.
        self.cond = torch.jit.script(CondModel())
        self.asub = ASubModel()
    def forward(self, x: torch.Tensor):
        # Memory-format and resampling ops on the raw input.
        x = x.contiguous(memory_format=torch.channels_last)
        x = torch._C._nn.upsample_bilinear2d(x, (28, 28), False)
        x = torch._C._nn.upsample_nearest2d(x, (28, 28))
        x = F.adaptive_avg_pool2d(x, (28, 28))
        # Elementwise unary ops.
        x = torch.exp(x)
        x = torch.sigmoid(x)
        # The two transposes cancel each other; they only exercise aten::transpose.
        x = torch.transpose(x, 1, 2)
        x = torch.transpose(x, 1, 2)
        # Kernel 3, stride 1, padding 1 keeps the spatial size unchanged.
        x = F.avg_pool2d(x, 3, 1, padding=1)
        x = F.max_pool2d(x, 3, 1, padding=1)
        x = x.to(torch.float32)
        # Conv block 1: relu/gelu branches merged by add, then scalar add/mul.
        x = self.conv1(x)
        y1 = self.pool1(F.relu(x))
        y2 = self.pool1(F.gelu(x))
        x = y1 + y2
        x = x + 0.00001
        x = x * 1.00001
        # Conv block 2: silu/tanh branches merged by sub, then scalar sub/div.
        x = self.conv2(x)
        y1 = self.pool2(F.silu(x))
        y2 = self.pool2(torch.tanh(x))
        x = y1 - y2
        x = x - 0.00001
        x = x / 1.00001
        # (0, 2, 3, 1) is a 3-cycle on dims 1..3; three applications restore
        # the original dim order while exercising aten::permute repeatedly.
        x = torch.permute(x, (0, 2, 3, 1))
        x = torch.permute(x, (0, 2, 3, 1))
        x = torch.permute(x, (0, 2, 3, 1))
        # Each unsqueeze is paired with a rank-reducing op so the rank is stable.
        x = torch.unsqueeze(x, dim=1)
        x = torch.select(x, dim=1, index=0)
        x = torch.unsqueeze(x, dim=1)
        x = torch.mean(x, dim=1)
        x = torch.unsqueeze(x, dim=1)
        x = torch.sum(x, dim=1, dtype=torch.float32)
        x = torch.unsqueeze(x, dim=1)
        x = torch.squeeze(x, dim=1)
        # Flatten to (batch, features) for the fully-connected layers;
        # reshape/view here are shape-preserving no-ops that exercise the ops.
        x = torch.flatten(x, 1)
        x = x.reshape(x.shape)
        x = x.view(-1, x.size(1))
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.softmax(self.fc3(x), dim=1)
        # Split in half along dim 1 (slice + size + Int), then concatenate back.
        y1 = x[:,0:int(x.size(1)/2)]
        y2 = x[:,int(x.size(1)/2):x.size(1)]
        x = torch.cat((y1, y2), dim=1)
        # Self-referential no-ops exercising type_as / expand_as.
        x = x.type_as(x)
        x = x.expand_as(x)
        # (batch, 10) @ (10, batch) -> (batch, batch); cat doubles dim 1.
        x = torch.matmul(x, x.t())
        x = torch.cat([x, x], dim=1)
        # x = self.cond(x)
        x = self.asub(x)
        # Pad all four edges of the 2-D result by 1 with a constant value.
        x = torch.constant_pad_nd(x, (1,1,1,1), 3.14159)
        return x
class AutoConvTestCase(unittest.TestCase):
    """End-to-end check: prune TorchModel1 with L1NormPruner, speed it up,
    then verify the mask sparsity, the real (compact-vs-origin) sparsity,
    and the final output shape of the compacted model."""

    def test_l1norm_pruner(self):
        net = TorchModel1()
        sample = torch.rand(3, 1, 28, 28)
        cfg = [{'op_types': ['Conv2d'], 'sparsity': 0.5}]
        # Prune, then unwrap so the model can be speeded up in place.
        pruner = L1NormPruner(model=net, config_list=cfg)
        masked_net, masks = pruner.compress()
        pruner._unwrap_model()
        mask_sparsity = compute_sparsity_mask2compact(masked_net, masks, cfg)
        ModelSpeedup(net, sample, masks).speedup_model()
        real_sparsity = compute_sparsity_compact2origin(TorchModel1(), net, cfg)
        print('sparsity_list:', mask_sparsity)
        assert 0.45 < mask_sparsity[0]['total_sparsity'] < 0.55
        print('real_sparsity_list:', real_sparsity)
        assert 0.45 < real_sparsity[0]['total_sparsity'] < 0.75
        print('the shape of output of the infer:', net(sample).shape)
        # (3, 10) @ (10, 3) -> (3, 3), cat -> (3, 6), pad by 1 on each edge -> (5, 8).
        assert net(sample).shape == torch.Size((5, 8))
# Allow running this test module directly (outside a pytest/CI harness).
if __name__ == '__main__':
    unittest.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment