Unverified commit 588f299b authored by Louis-J, committed by GitHub

feat(speedup): automatically convert op asts to callables (#4996)

parent b2c31ca2
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from __future__ import annotations
from types import ModuleType
from typing import Any, Callable, Dict, List, Optional, Type, Union
from typing import TYPE_CHECKING
if TYPE_CHECKING:  # Only imports the below statements during type checking
    from nni.compression.pytorch.speedup import ModelSpeedup
    from nni.common.graph_utils import NodePyGroup

import re
import logging
from functools import partial, lru_cache
import copy

import torch
@@ -15,31 +24,24 @@ jitid_2_dtype = {4: torch.long, 6: torch.float32}
# to exclude partial
__all__ = [
    'getattr_python', 'jit_to_python_function', 'num2tensor_python', 'parse_constant', 'slice_python',
    'translate_list', 'tupleunpack_python', 'dtype_trans', 'memory_format_trans'
]
def translate_list(list_node: torch._C.Value, speedup: ModelSpeedup = None) -> List:
    """
    Get the list of values from the list construct node.

    Parameters
    ----------
    list_node
        The cpp node of the target list.
    speedup
        The ModelSpeedup object.

    Returns
    -------
    values
        The list of values in the target cpp list node.
    """
    # the node that creates the list
@@ -52,27 +54,26 @@ def translate_list(list_node, speedup=None):
        if speedup is not None and debugName in speedup.internal_result:
            # this value is the result of other nodes, such as
            # aten::size
            values.append(speedup.internal_result[debugName])
        else:
            # if the corresponding value is a constant
            values.append(_i.toIValue())
    return values
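
# Illustrative sketch (not part of this commit): for jit IR such as
#   %sizes : int[] = prim::ListConstruct(%d0, %d1)
# translate_list(%sizes, speedup) returns the two values as a python list,
# taking %d0/%d1 from speedup.internal_result when another node (e.g. aten::size)
# produced them, and from toIValue() when they are constants.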
def parse_constant(cvalue: torch._C.Value, speedup: ModelSpeedup) -> Any:
    """
    Parse the constant value from this node.

    Parameters
    ----------
    cvalue
        The cpp node of the target constant value.
    speedup
        The ModelSpeedup object.

    Returns
    -------
    value
        The constant value parsed from the node.
    """
    logger.debug('Try to parse the constant value: %s', cvalue.debugName())
@@ -85,245 +86,13 @@ def parse_constant(cvalue, speedup):
    inputs = op_node.inputs()
    input_values = [parse_constant(_i, speedup) for _i in inputs]
    if op_node.kind() not in trans_func_dict:
        raise RuntimeError('Unsupported function op node type: {}'.format(op_node.kind()))
    func = trans_func_dict[op_node.kind()](op_node, speedup)
    return func(*input_values)
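
# Illustrative sketch (assumption, not in the commit): a constant expression such as
#   %c : int = aten::mul(%a, %b)   # where %a and %b are prim::Constant
# is folded by recursively parsing %a and %b, translating aten::mul through
# trans_func_dict, and calling the resulting adapter with the parsed values.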
def slice_python(node: NodePyGroup, speedup: ModelSpeedup):
    class SliceMoudle(torch.nn.Module):
        def __init__(self, sliceobj):
            super(SliceMoudle, self).__init__()
@@ -368,102 +137,38 @@ def slice_python(node, speedup):
    else:
        return SliceMoudle(tuple(slice_list))
def cat_python(node: NodePyGroup, speedup: ModelSpeedup):
    class CatModule(torch.nn.Module):
        def __init__(self, cat_dim):
            super(CatModule, self).__init__()
            self.cat_dim = cat_dim

        def forward(self, *args):
            return torch.cat(args, dim=self.cat_dim)

    c_node = node.key_node
    inputs = list(c_node.inputs())
    dim = inputs[1].toIValue()
    return CatModule(dim)

def tupleunpack_python(_node: NodePyGroup, _speedup: ModelSpeedup) -> Optional[Callable]:
    # Note: tuple unpack should only exist at the end of the model,
    # so there is no need to replace/propagate the mask here.
    return None

def num2tensor_python(_node: NodePyGroup, _speedup: ModelSpeedup):
    return torch.nn.Identity()

def getattr_python(node: NodePyGroup, _speedup: ModelSpeedup):
    """
    Note: ops starting with prim:: are not taken as key nodes,
    so we directly pass the cpp node into this function.

    Parameters
    ----------
    node
        The cpp node of prim::GetAttr
    speedup
        The corresponding speedup object.
    """
    class GetModule(torch.nn.Module):
@@ -482,316 +187,332 @@ def getattr_python(node, speedup):
    assert len(key_words) == 1
    return GetModule(key_words[0])
class FuncAdapter:
    """
    A function adapter which can reorder arguments.
    It is initialized with the constant arguments and the positions of the
    non-constant arguments. When called, it puts the runtime arguments into
    the correct positions, then calls the function.

    Attributes
    ----------
    func
        The function or method to be called.
    positional
        Positional argument values. The placeholder is None if the argument is non-constant.
    keyword
        Keyword argument values. The placeholder is None if the argument is non-constant.
    undetermined
        A list of the right positions of arguments.
        A position is an int for a positional argument or a str for a keyword argument.
    special_treat
        A dict from positions to methods.
        The values at these positions should be treated by those methods.
    """

    def __init__(self, func: Callable, positional: List[Any], keyword: Dict[str, Any],
                 undetermined: List[Union[int, str]], special_treat: Dict[Union[int, str], Callable]):
        if not callable(func):
            raise TypeError('the "func" argument must be callable')

        self.func = func
        self.positional = positional
        self.keyword = keyword
        self.undetermined = undetermined
        self.special_treat = special_treat

    def __call__(self, /, *args):
        assert len(args) >= len(self.undetermined)
        if len(args) > len(self.undetermined):
            logger.warning('throw some args away when calling the function "%s"', self.func.__name__)

        for i, p in enumerate(self.undetermined):
            v = args[i]
            if isinstance(p, int):
                self.positional[p] = v
            else:
                self.keyword[p] = v

        for p, fs in self.special_treat.items():
            if isinstance(p, int):
                for f in fs:
                    self.positional[p] = f(self.positional[p])
            else:
                for f in fs:
                    self.keyword[p] = f(self.keyword[p])

        result = self.func(*self.positional, **self.keyword)
        if isinstance(result, int):  # turn the result of 'size' into a tensor
            result = torch.as_tensor([result], dtype=torch.long)
        return result
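
# Minimal usage sketch (hypothetical values, not part of the original commit):
#
#   adapter = FuncAdapter(
#       func=torch.add,
#       positional=[None, 1],   # slot 0 is non-constant, slot 1 holds the constant 1
#       keyword={},
#       undetermined=[0],       # the single runtime input fills positional slot 0
#       special_treat={},
#   )
#   adapter(torch.ones(2))      # equivalent to torch.add(torch.ones(2), 1)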
# There are some types that will be converted into enums after jit.
# So we should recover them back:
# device, dtype, layout, memory_format, qscheme, qengine, dispatchkey
enum_to_dtype_names = {
0: 'uint8',
1: 'int8',
2: 'int16',
3: 'int32',
4: 'int64',
5: 'float16',
6: 'float32',
7: 'float64',
8: 'complex32',
9: 'complex64',
10: 'complex128',
11: 'bool',
12: 'qint8',
13: 'quint8',
14: 'qint32',
15: 'bfloat16',
16: 'quint4x2',
17: 'quint2x4',
}
enum_to_dtype_dict = {}
for enum_value, dtype_name in enum_to_dtype_names.items():
    if hasattr(torch, dtype_name):
        enum_to_dtype_dict[enum_value] = getattr(torch, dtype_name)
def dtype_trans(ivalue: Union[int, torch.dtype]):
"""
Special process for dtype.
Torch will transform dtype to an enum in cpp, so the value of dtype we get in jit is an int.
This function is used to recover the int to torch.dtype in python.

    Parameters
    ----------
    ivalue
        The value of dtype to be recovered.
    """
    if ivalue is None or isinstance(ivalue, torch.dtype):
        return ivalue
    elif isinstance(ivalue, int):
        if ivalue in enum_to_dtype_dict:
            return enum_to_dtype_dict[ivalue]
    raise TypeError('No torch.dtype corresponding to the value "{}"'.format(ivalue))
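
# Example (illustration only): jit encodes torch.float32 as the enum value 6
# (see enum_to_dtype_names above), so dtype_trans(6) returns torch.float32,
# while dtype_trans(torch.float32) and dtype_trans(None) pass through unchanged.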
enum_to_memory_format_dict = {
0: torch.contiguous_format,
1: torch.preserve_format,
2: torch.channels_last,
3: torch.channels_last_3d,
}
def memory_format_trans(ivalue: Union[int, torch.memory_format]):
    """
    Special process for memory_format.
    Torch will transform memory_format to an enum in cpp, so the value of memory_format we get in jit is an int.
    This function is used to recover the int to torch.memory_format in python.

    Parameters
    ----------
    ivalue
        The value of memory_format to be recovered.
    """
    if ivalue is None or isinstance(ivalue, torch.memory_format):
        return ivalue
    elif isinstance(ivalue, int):
        if ivalue in enum_to_memory_format_dict:
            return enum_to_memory_format_dict[ivalue]
    raise TypeError('No torch.memory_format corresponding to the value "{}"'.format(ivalue))
special_treat_dict = {
'dtype': dtype_trans,
'memory_format': memory_format_trans,
}
schema_fix_dict = {
    # functions 'to', 'randint', and 'sparse_coo_tensor' have different schemas between python and c++.
    # https://pytorch.org/docs/stable/jit_unsupported.html#ops-with-divergent-schemas-between-torch-python
    'aten::to.device(Tensor(a) self, Device device, int dtype, bool non_blocking=False, bool copy=False, int? memory_format=None) -> (Tensor(a))':
        'aten::to.device(Tensor(a) self, Device device, int dtype, bool non_blocking=False, bool copy=False, *, int? memory_format=None) -> (Tensor(a))',
    'aten::to.dtype(Tensor(a) self, int dtype, bool non_blocking=False, bool copy=False, int? memory_format=None) -> (Tensor(a))':
        'aten::to.dtype(Tensor(a) self, int dtype, bool non_blocking=False, bool copy=False, *, int? memory_format=None) -> (Tensor(a))',
    'aten::to.other(Tensor(a) self, Tensor other, bool non_blocking=False, bool copy=False, int? memory_format=None) -> (Tensor(a))':
        'aten::to.other(Tensor(a) self, Tensor other, bool non_blocking=False, bool copy=False, *, int? memory_format=None) -> (Tensor(a))',
    # todo: are the arguments 'pin_memory' and 'requires_grad' related?
    # functions in python have only 'requires_grad' and functions in aten have only 'pin_memory'
    # 'aten::randint(int high, int[] size, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None) -> (Tensor)',
    # 'aten::randint.generator(int high, int[] size, *, Generator? generator, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None) -> (Tensor)',
    # 'aten::randint.low(int low, int high, int[] size, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None) -> (Tensor)',
    # 'aten::randint.low_generator(int low, int high, int[] size, *, Generator? generator, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None) -> (Tensor)',
    # 'aten::sparse_coo_tensor.size(int[] size, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=False) -> (Tensor)',
    # 'aten::sparse_coo_tensor.indices(Tensor indices, Tensor values, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None) -> (Tensor)',
    # 'aten::sparse_coo_tensor.indices_size(Tensor indices, Tensor values, int[] size, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None) -> (Tensor)',
}
@lru_cache(maxsize=256)
def parse_aten_schema(schema: str):
    """
    Parse the schema into positional_num and keyword_list, and detect whether an argument should be specially treated.
    """
    if schema in schema_fix_dict:
        schema = schema_fix_dict[schema]

    positional_num = 0
    keyword_list = list()
    special_treat = dict()  # for dtype and memory_format trans now
    for arg in torch._C.parse_schema(schema).arguments:
        if not arg.kwarg_only:
            key = positional_num
            positional_num += 1
        else:
            key = arg.name
            keyword_list.append(key)

        if arg.name in special_treat_dict:
            if key not in special_treat:
                special_treat[key] = [special_treat_dict[arg.name]]
            else:
                special_treat[key].append(special_treat_dict[arg.name])

    return positional_num, keyword_list, special_treat
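
# Sketch (the schema text below is an assumption): for
#   'aten::softmax.int(Tensor self, int dim, ScalarType? dtype=None) -> Tensor'
# no argument is kwarg-only, so this returns positional_num == 3 and
# keyword_list == [], and special_treat == {2: [dtype_trans]} because the
# argument named 'dtype' matches special_treat_dict.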
def parse_input_value(speedup: ModelSpeedup, input_nodes: List[torch._C.Value], positional_num: int, keyword_list: List[str]):
    """
    Translate inputs into constant positional arguments, constant keyword arguments, and undetermined positions.
    """
    positional = list()
    keyword = dict()
    undetermined = list()
    for ainput in input_nodes:
        if ainput.node().kind() == 'prim::ListConstruct':
            arg = translate_list(ainput, speedup)
        elif ainput.node().kind() == 'prim::Constant':
            arg = ainput.toIValue()
        else:
            assert 'aten::' in ainput.node().kind() or 'prim::' in ainput.node().kind()
            if len(positional) < positional_num:
                undetermined.append(len(positional))
            else:
                undetermined.append(keyword_list[len(keyword)])
            arg = None

        if len(positional) < positional_num:
            positional.append(arg)
        else:
            keyword[keyword_list[len(keyword)]] = arg
    return positional, keyword, undetermined
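
# Sketch (illustrative): for aten::flatten(%x, %start, %end) where %x is produced
# by another node and both dims are prim::Constant (say 1 and -1), this returns
# positional == [None, 1, -1], keyword == {} and undetermined == [0], so only
# the value of %x has to be supplied when the adapter is called.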
def special_treat_to_constant_value(positional: List, keyword: Dict[str, Any], undetermined: List[Union[int, str]],
                                    special_treat: Dict[Union[int, str], Callable]):
    """
    If an argument with special_treat is not undetermined, apply the treatment to its constant value now.
    """
    undetermined_special_treat = dict()
    for p, fs in special_treat.items():
        if p in undetermined:
            undetermined_special_treat[p] = fs
        elif isinstance(p, int):
            for f in fs:
                positional[p] = f(positional[p])
        else:
            for f in fs:
                keyword[p] = f(keyword[p])
    return undetermined_special_treat
def generate_aten_to_python(func: Callable, node: NodePyGroup, speedup: ModelSpeedup) -> FuncAdapter:
    """
    Return a callable object to infer the mask according to the node.op_type.

    Parameters
    ----------
    func
        The torch function in one-to-one correspondence with the node.
    node
        The target node to infer the mask for.
    speedup
        The speedup object of the target model.

    Returns
    -------
    func
        The translated function used to infer the mask;
        if the current op_type is not supported, None is returned.
    """
    c_node = node.key_node

    schema = c_node.schema()
    positional_num, keyword_list, special_treat = parse_aten_schema(schema)

    input_nodes = list(c_node.inputs())
    positional, keyword, undetermined = parse_input_value(speedup, input_nodes, positional_num, keyword_list)
    undetermined_special_treat = special_treat_to_constant_value(positional, keyword, undetermined, special_treat)

    return FuncAdapter(func, positional, keyword, undetermined, undetermined_special_treat)

trans_func_dict = {
    'aten::slice': slice_python,
    'aten::cat': cat_python,
    'aten::Int': partial(generate_aten_to_python, torch._C._TensorBase.int),
    'prim::TupleUnpack': tupleunpack_python,
    'prim::ListUnpack': tupleunpack_python,
    'prim::NumToTensor': num2tensor_python,
    'prim::GetAttr': getattr_python,
}
def init_add_functions(func_from: Union[ModuleType, Type[Any]]):
    """
    Add function/method attributes from a module/class to the trans_func_dict.

    Parameters
    ----------
    func_from
        The module/class that provides the needed functions.
    """
    global trans_func_dict
    new_trans_func_dict = dict()
    for name in dir(func_from):
        attr = getattr(func_from, name)
        if callable(attr) and not name.startswith('__'):
            new_trans_func_dict['aten::' + name] = partial(generate_aten_to_python, attr)
    trans_func_dict = {**new_trans_func_dict, **trans_func_dict}

init_add_functions(torch._C._VariableFunctions)
init_add_functions(torch._C._nn)
init_add_functions(torch._C._TensorBase)
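
# Illustration (assumption): after these three calls, an op without a hand-written
# rule, e.g. 'aten::view', resolves to partial(generate_aten_to_python, torch._C._TensorBase.view);
# the hand-written entries still win because trans_func_dict is unpacked last in the merge.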
def jit_to_python_function(node: NodePyGroup, speedup: ModelSpeedup) -> FuncAdapter:
    """
    Return a callable object to infer the mask according to the node.op_type.

    Parameters
    ----------
    node
        The target node to infer the mask for.
    speedup
        The speedup object of the target model.

    Returns
    -------
    func
        The translated function used to infer the mask;
        if the current op_type is not supported, None is returned.
    """
    logger.debug('Translate C function %s into its python version', node.op_type)
    if node.op_type not in trans_func_dict:
        logger.error('%s is not supported! Please report an issue at https://github.com/microsoft/nni. Thanks~', node.op_type)
        # return None to skip the mask inference for this node
        return None
    return trans_func_dict[node.op_type](node, speedup)
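
# Minimal usage sketch (hypothetical variable names): during speedup each graph
# node is translated once and the result is called with the runtime inputs:
#
#   func = jit_to_python_function(node, speedup)   # FuncAdapter / nn.Module, or None
#   if func is not None:
#       output = func(*runtime_inputs)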
@@ -61,7 +61,7 @@ class BackboneModel2(torch.nn.Module):
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.max_pool2d(x, 2, 2)
        x = x.reshape(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
...
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import unittest
import torch
import torch.nn.functional as F
from nni.compression.pytorch.pruning import L1NormPruner
from nni.compression.pytorch.speedup import ModelSpeedup
from nni.algorithms.compression.v2.pytorch.utils import (
compute_sparsity_compact2origin,
compute_sparsity_mask2compact
)
class CondModel(torch.nn.Module):
"""
test for:
prim::If
"""
the_cond: bool
def __init__(self):
super().__init__()
self.the_cond = True
def forward(self, x: torch.Tensor) -> torch.Tensor:
if self.the_cond:
x = x + 0.00001
else:
x = x - 0.00001
self.the_cond = not self.the_cond
return x
class ASubModel(torch.nn.Module):
"""
test for:
sub model
"""
def __init__(self):
super().__init__()
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = x + 0.00001
return x
class TorchModel1(torch.nn.Module):
"""
test for:
add, sub, mul, div, exp, matmul,
relu, gelu, tanh, silu, sigmoid, softmax,
size, unsqueeze, flatten, cat, slice, reshape, transpose, t, select, permute, constant_pad_nd,
mean, avg_pool2d, max_pool2d, sum, adaptive_avg_pool2d,
to, Int, view,
type_as, expand_as, contiguous,
notes:
'floor_divide' has no backward, so it is not tested
def __init__(self):
super().__init__()
self.conv1 = torch.nn.Conv2d(1, 6, 5, 1)
self.conv2 = torch.nn.Conv2d(6, 16, 5, 1)
self.fccond = torch.nn.Linear(16 * 4 * 4, 16 * 4 * 4)
self.fc1 = torch.nn.Linear(16 * 4 * 4, 120)
self.fc2 = torch.nn.Linear(120, 84)
self.fc3 = torch.nn.Linear(84, 10)
self.pool1 = torch.nn.MaxPool2d((2, 2))
self.pool2 = torch.nn.MaxPool2d((2, 2))
self.cond = torch.jit.script(CondModel())
self.asub = ASubModel()
def forward(self, x: torch.Tensor):
x = x.contiguous(memory_format=torch.channels_last)
x = torch._C._nn.upsample_bilinear2d(x, (28, 28), False)
x = torch._C._nn.upsample_nearest2d(x, (28, 28))
x = F.adaptive_avg_pool2d(x, (28, 28))
x = torch.exp(x)
x = torch.sigmoid(x)
x = torch.transpose(x, 1, 2)
x = torch.transpose(x, 1, 2)
x = F.avg_pool2d(x, 3, 1, padding=1)
x = F.max_pool2d(x, 3, 1, padding=1)
x = x.to(torch.float32)
x = self.conv1(x)
y1 = self.pool1(F.relu(x))
y2 = self.pool1(F.gelu(x))
x = y1 + y2
x = x + 0.00001
x = x * 1.00001
x = self.conv2(x)
y1 = self.pool2(F.silu(x))
y2 = self.pool2(torch.tanh(x))
x = y1 - y2
x = x - 0.00001
x = x / 1.00001
x = torch.permute(x, (0, 2, 3, 1))
x = torch.permute(x, (0, 2, 3, 1))
x = torch.permute(x, (0, 2, 3, 1))
x = torch.unsqueeze(x, dim=1)
x = torch.select(x, dim=1, index=0)
x = torch.unsqueeze(x, dim=1)
x = torch.mean(x, dim=1)
x = torch.unsqueeze(x, dim=1)
x = torch.sum(x, dim=1, dtype=torch.float32)
x = torch.unsqueeze(x, dim=1)
x = torch.squeeze(x, dim=1)
x = torch.flatten(x, 1)
x = x.reshape(x.shape)
x = x.view(-1, x.size(1))
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = F.softmax(self.fc3(x), dim=1)
y1 = x[:,0:int(x.size(1)/2)]
y2 = x[:,int(x.size(1)/2):x.size(1)]
x = torch.cat((y1, y2), dim=1)
x = x.type_as(x)
x = x.expand_as(x)
x = torch.matmul(x, x.t())
x = torch.cat([x, x], dim=1)
# x = self.cond(x)
x = self.asub(x)
x = torch.constant_pad_nd(x, (1,1,1,1), 3.14159)
return x
class AutoConvTestCase(unittest.TestCase):
def test_l1norm_pruner(self):
model = TorchModel1()
dummy_input = torch.rand(3, 1, 28, 28)
config_list = [{'op_types': ['Conv2d'], 'sparsity': 0.5}]
pruner = L1NormPruner(model=model, config_list=config_list)
pruned_model, masks = pruner.compress()
pruner._unwrap_model()
sparsity_list = compute_sparsity_mask2compact(pruned_model, masks, config_list)
ModelSpeedup(model, dummy_input, masks).speedup_model()
real_sparsity_list = compute_sparsity_compact2origin(TorchModel1(), model, config_list)
print('sparsity_list:', sparsity_list)
assert 0.45 < sparsity_list[0]['total_sparsity'] < 0.55
print('real_sparsity_list:', real_sparsity_list)
assert 0.45 < real_sparsity_list[0]['total_sparsity'] < 0.75
print('the shape of output of the infer:', model(dummy_input).shape)
assert model(dummy_input).shape == torch.Size((5, 8))
if __name__ == '__main__':
unittest.main()