Unverified commit 588f299b authored by Louis-J, committed by GitHub

feat(speedup): automatically convert op asts to callables (#4996)

parent b2c31ca2
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from __future__ import annotations
from types import ModuleType
from typing import Any, Callable, Dict, List, Optional, Type, Union
from typing import TYPE_CHECKING
if TYPE_CHECKING: # only import the statements below during type checking
from nni.compression.pytorch.speedup import ModelSpeedup
from nni.common.graph_utils import NodePyGroup
import re
import logging
from functools import partial
from functools import partial, lru_cache
import copy
import torch
@@ -15,31 +24,24 @@ jitid_2_dtype = {4: torch.long, 6: torch.float32}
# to exclude partial
__all__ = [
'adaptive_avgpool_python', 'add_python', 'avgpool2d_python', 'cat_python', 'contiguous_python',
'div_python', 'dropout_python', 'exp_python', 'flatten_python', 'floor_div_python', 'gelu_python',
'getattr_python', 'jit_to_python_function', 'matmul_python', 'mean_python',
'mul_python', 'num2tensor_python', 'parse_constant', 'permute_python', 'relu_inplace_python',
'relu_python', 'reshape_python', 'select_python', 'sigmoid_python', 'size_python', 'slice_python',
'softmax_python', 'squeeze_python', 'to_python', 'toint_python', 'torch', 'trans_from_jit_to_python',
'translate_list', 'transpose2_python', 'transpose_python', 'tupleunpack_python', 'typeas_python',
'unsqueeze_python', 'upsample_bilinear2d_python', 'view_python'
'getattr_python', 'jit_to_python_function', 'num2tensor_python', 'parse_constant', 'slice_python',
'translate_list', 'tupleunpack_python', 'dtype_trans', 'memory_format_trans'
]
def translate_list(list_node, speedup=None):
def translate_list(list_node: torch._C.Value, speedup: ModelSpeedup=None) -> List:
"""
Get the list of values from the list construct node.
Parameters
----------
list_node: Torch.C.Value
list_node
The cpp node of the target list.
speedup: ModuleSpeed
speedup
The Module speedup module.
Returns
-------
values: list
values
The list of values in the target cpp list node.
"""
# the node that creates the list
@@ -52,27 +54,26 @@ def translate_list(list_node, speedup=None):
if speedup is not None and debugName in speedup.internal_result:
# this value is the result of other nodes, such as
# aten::size
values.append(speedup.internal_result[debugName].item())
values.append(speedup.internal_result[debugName])
else:
# if the corresponding value is a constant
values.append(_i.toIValue())
return values
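# Illustrative sketch (hypothetical helper, not part of the patch): given a
# traced graph that contains a prim::ListConstruct, e.g. the shape argument
# of `x.view(1, 2, 3)`, translate_list recovers the plain Python list.
def _translate_list_example(view_cnode: torch._C.Node) -> List:
    sizes_value = list(view_cnode.inputs())[1]  # torch._C.Value holding the list
    return translate_list(sizes_value)          # expected: [1, 2, 3]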
def parse_constant(cvalue, speedup):
def parse_constant(cvalue: torch._C.Value, speedup: ModelSpeedup) -> Any:
"""
Parse the constant values from this Node
Parameters
----------
cvalue: Torch.C.Value
cvalue
The cpp node of the target constant value.
speedup: ModelSpeedup
speedup
The Model speedup module.
Returns
-------
value: int/float/tensor
value
The constant values parsed from the node.
"""
logger.debug('Try to parse the constant value: %s', cvalue.debugName())
@@ -85,245 +86,13 @@ def parse_constant(cvalue, speedup):
inputs = op_node.inputs()
input_values = [parse_constant(_i, speedup) for _i in inputs]
func = trans_from_jit_to_python[op_node.kind()](op_node, speedup)
return func(*input_values)
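# A hedged sketch of the recursion above: for a graph fragment like
#     %4 : int = prim::Constant[value=2]()
#     %5 : int = aten::mul(%4, %4)
# parse_constant on %5 looks up aten::mul in the translation dict, parses
# both inputs down to the constant 2, and returns 4.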
def dropout_python(node, speedup):
return torch.dropout
def flatten_python(node, speedup):
c_node = node.key_node
inputs = list(c_node.inputs())
start_dim = inputs[1].toIValue()
end_dim = inputs[2].toIValue()
new_flatten = partial(torch.flatten, start_dim=start_dim, end_dim=end_dim)
return new_flatten
def relu_inplace_python(node, speedup):
return torch.relu_
def relu_python(node, speedup):
return torch.relu
def sigmoid_python(node, speedup):
return torch.sigmoid
def mean_python(node, speedup):
c_node = node.key_node
inputs = list(c_node.inputs())
dim_list = translate_list(inputs[1], speedup)
keep_dim = inputs[2].toIValue()
new_mean = partial(torch.mean, dim=tuple(dim_list), keepdim=keep_dim)
return new_mean
def add_python(node, speedup):
c_node = node.key_node
inputs = list(c_node.inputs())
constant = None
for i in range(2):
input_i = inputs[i]
debug_name = input_i.debugName()
if debug_name not in speedup.internal_result:
# this input is a constant value
# TODO: what if this input is a constant tensor
if input_i.toIValue() is not None:
constant = parse_constant(input_i, speedup)
break
if constant is None:
return torch.add
else:
new_add = partial(torch.add, constant)
return new_add
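# Minimal, self-contained sketch of the partial-binding pattern used above
# (illustrative only): binding the parsed constant as the first positional
# argument yields a callable that only needs the remaining tensor input.
def _bound_add_example() -> torch.Tensor:
    add_three = partial(torch.add, 3)   # same shape as `new_add` above
    return add_three(torch.ones(2))     # expected: tensor([4., 4.])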
def sub_python(node, speedup):
c_node = node.key_node
inputs = list(c_node.inputs())
constant = [None, None]
for i in range(2):
input_i = inputs[i]
debug_name = input_i.debugName()
if debug_name not in speedup.internal_result:
# this input is a constant value
# TODO: what if this input is a constant tensor
if input_i.toIValue() is not None:
constant[i] = parse_constant(input_i, speedup)
break
if constant[0] is None and constant[1] is None:
new_sub = torch.sub
elif constant[0] is not None:
new_sub = partial(torch.sub, constant[0])
else:
new_sub = partial(torch.sub, other=constant[1])
return new_sub
def floor_div_python(node, speedup):
c_node = node.key_node
inputs = list(c_node.inputs())
divisor = inputs[1]
constant = None
if divisor.debugName() not in speedup.internal_result:
# divisor is a constant value/tensor
constant = parse_constant(divisor, speedup)
if constant is None:
return torch.floor_divide
else:
new_op = partial(torch.floor_divide, other=constant)
return new_op
def mul_python(node, speedup):
c_node = node.key_node
inputs = list(c_node.inputs())
constant = None
for i in range(2):
input_i = inputs[i]
debug_name = input_i.debugName()
if debug_name not in speedup.internal_result:
constant = parse_constant(input_i, speedup)
# the two inputs cannot both be constants at the same time
break
if constant is None:
return torch.mul
else:
new_mul = partial(torch.mul, constant)
return new_mul
def transpose_python(node, speedup):
return torch.t
def transpose2_python(node, speedup):
c_node = node.key_node
inputs = list(c_node.inputs())
dim_1 = inputs[1].toIValue()
dim_2 = inputs[2].toIValue()
new_transpose = partial(torch.transpose, dim0=dim_1, dim1=dim_2)
return new_transpose
def matmul_python(node, speedup):
return torch.matmul
def div_python(node, speedup):
# The second input of torch.div can be a tensor or a
# constant. If it is a constant, we need to bind it
# into the returned callable.
c_node = node.key_node
inputs = list(c_node.inputs())
if inputs[1].debugName() in speedup.internal_result:
# the second input parameters is the output of the other
# nodes
return torch.div
else:
other = inputs[1].toIValue()
new_div = partial(torch.div, other=other)
return new_div
def softmax_python(node, speedup):
c_node = node.key_node
inputs = list(c_node.inputs())
dim = inputs[1].toIValue()
new_softmax = partial(torch.softmax, dim=dim)
return new_softmax
def contiguous_python(node, speedup):
class contiguousModule(torch.nn.Module):
def forward(self, x):
return x.contiguous().clone()
return contiguousModule()
def gelu_python(node, speedup):
return torch.nn.GELU()
def silu_python(node, speedup):
return torch.nn.SiLU()
def avgpool2d_python(node, speedup):
c_node = node.key_node
inputs = list(c_node.inputs())
kernel_size = translate_list(inputs[1], speedup)
stride = translate_list(inputs[2], speedup)
padding = translate_list(inputs[3], speedup)
new_avgpool = partial(torch.nn.functional.avg_pool2d,
kernel_size=kernel_size, stride=stride, padding=padding)
return new_avgpool
def adaptive_avgpool_python(node, speedup):
c_node = node.key_node
inputs = list(c_node.inputs())
output_size = translate_list(inputs[1], speedup)
new_avgpool = torch.nn.AdaptiveAvgPool2d(output_size)
return new_avgpool
def tupleunpack_python(node, speedup):
# Note: tuple unpack should only exist at the end of
# the model, so there is no need to replace it or propagate the mask
return None
def num2tensor_python(node, speedup):
return torch.nn.Identity()
def exp_python(node, speedup):
return torch.exp
def squeeze_python(node, speedup):
c_node = node.key_node
inputs = list(c_node.inputs())
dim = None
if len(inputs) > 1:
dim = parse_constant(inputs[1], speedup)
new_squeeze = partial(torch.squeeze, dim=dim)
return new_squeeze
def unsqueeze_python(node, speedup):
c_node = node.key_node
inputs = list(c_node.inputs())
dim = parse_constant(inputs[1], speedup)
new_unsqueeze = partial(torch.unsqueeze, dim=dim)
return new_unsqueeze
def constant_pad_nd_python(node, speedup):
c_node = node.key_node
inputs = list(c_node.inputs())
pad = translate_list(inputs[1], speedup)
value = parse_constant(inputs[2], speedup)
new_constant_pad_nd = partial(torch.nn.functional.pad, pad=pad, value=value)
return new_constant_pad_nd
##########################################################
# Split Line
# Following modules/functions cannot be translated into a
# single function, so we wrap the core function in a
# torch.nn.Module and return the module instead
##########################################################
if op_node.kind() not in trans_func_dict:
raise RuntimeError('Unsupported function op node type: {}'.format(op_node.kind()))
func = trans_func_dict[op_node.kind()](op_node, speedup)
return func(*input_values)
def slice_python(node, speedup):
def slice_python(node: NodePyGroup, speedup: ModelSpeedup):
class SliceMoudle(torch.nn.Module):
def __init__(self, sliceobj):
super(SliceMoudle, self).__init__()
@@ -368,102 +137,38 @@ def slice_python(node, speedup):
else:
return SliceMoudle(tuple(slice_list))
def select_python(node, speedup):
class SelectModule(torch.nn.Module):
def __init__(self, dim, index):
super(SelectModule, self).__init__()
self.dim = copy.deepcopy(dim)
self.index = copy.deepcopy(index)
def forward(self, x):
return x.select(self.dim, self.index)
c_node = node.key_node
inputs = list(c_node.inputs())
dim = inputs[1].toIValue()
index = inputs[2].toIValue()
return SelectModule(dim, index)
def size_python(node, speedup):
# return None
class SizeMoudle(torch.nn.Module):
def __init__(self, sizedim):
super(SizeMoudle, self).__init__()
self.sizedim = sizedim
def forward(self, x):
return torch.as_tensor([x.size(self.sizedim)], dtype=torch.long)
# return torch.tensor(x.size(self.sizedim))
c_node = node.key_node
inputs = list(c_node.inputs())
size_dim = inputs[1].toIValue()
return SizeMoudle(size_dim)
def toint_python(node, speedup):
class ToIntModule(torch.nn.Module):
def forward(self, x):
return x.to(torch.int)
return ToIntModule()
def view_python(node, speedup):
class ViewModule(torch.nn.Module):
def __init__(self, shape):
super(ViewModule, self).__init__()
self.shape = shape
logger.info('View Module output size: %s', str(self.shape))
def cat_python(node: NodePyGroup, speedup: ModelSpeedup):
class CatModule(torch.nn.Module):
def __init__(self, cat_dim):
super(CatModule, self).__init__()
self.cat_dim = cat_dim
def forward(self, *args):
return args[0].view(self.shape)
c_node = node.key_node
inputs = list(c_node.inputs())
shape = translate_list(inputs[1], speedup)
return ViewModule(shape)
def reshape_python(node, speedup):
class ReshapeModule(torch.nn.Module):
def __init__(self, shape):
super(ReshapeModule, self).__init__()
self.shape = shape
logger.info('Reshape Module output size: %s', str(self.shape))
return torch.cat(args, dim=self.cat_dim)
def forward(self, *args):
return args[0].reshape(self.shape)
c_node = node.key_node
inputs = list(c_node.inputs())
shape = translate_list(inputs[1], speedup)
return ReshapeModule(shape)
def permute_python(node, speedup):
class PermuteModule(torch.nn.Module):
def __init__(self, dimlist):
super(PermuteModule, self).__init__()
# deepcopy the values here, because the following randomize operation
# will change the value of the dimlist
self.dimlist = copy.deepcopy(dimlist)
dim = inputs[1].toIValue()
return CatModule(dim)
def forward(self, x):
return x.permute(self.dimlist)
c_node = node.key_node
inputs = list(c_node.inputs())
dim_list = translate_list(inputs[1], speedup)
return PermuteModule(dim_list)
def tupleunpack_python(_node: NodePyGroup, _speedup: ModelSpeedup) -> Optional[Callable]:
# Note: tuple unpack should only exist at the end of
# the model, so there is no need to replace it or propagate the mask
return None
def num2tensor_python(_node: NodePyGroup, _speedup: ModelSpeedup):
return torch.nn.Identity()
def getattr_python(node, speedup):
def getattr_python(node: NodePyGroup, _speedup: ModelSpeedup):
"""
Note: ops starting with prim:: are not taken as the key node,
so we directly pass the C++ node into this function.
Parameters
----------
node: torch._C.Node
node
The cpp node of prim::GetAttr
speedup: ModelSpeedup
speedup
The corresponding speedup object.
"""
class GetModule(torch.nn.Module):
@@ -482,316 +187,332 @@ def getattr_python(node, speedup):
assert len(key_words) == 1
return GetModule(key_words[0])
def constant_python(node, speedup):
class FuncAdapter:
"""
get the constant value of constant operator node.
A function adapter that can reorder arguments.
It is initialized with the constant arguments and the position of each
non-constant argument. When called, it puts the call-time arguments into
their correct positions and then calls the function.
Parameters
Attributes
----------
node: torch._C.Node
The cpp node of prim::Constant
speedup: ModelSpeedup
The corresponding speedup object.
"""
class ConstantModule(torch.nn.Module):
def __init__(self, constant):
super(ConstantModule, self).__init__()
self.constant = constant
def forward(self):
return self.constant
assert node.kind() == 'prim::Constant'
pattern = r'\[value=(.*?)\]'
key_words = re.findall(pattern, str(node))
if len(key_words) == 0:
return ConstantModule(None)
assert len(key_words) == 1
# parse the constant value
value = key_words[0]
if value.startswith("\""):
value = torch.device(value[1:-1])
elif value.startswith('{'):
# TODO Support set values in the future
value = set()
elif '.' in value:
# float value
value = float(value)
else:
# integer value
value = int(value)
return ConstantModule(value)
def upsample_bilinear2d_python(node, speedup):
class UpsampleModule(torch.nn.Module):
def __init__(self, size_list, scale_list):
super(UpsampleModule, self).__init__()
self.size_list = size_list
self.scale_list = scale_list
def forward(self, *args):
"""
The first element of args is the target tensor to upsample;
the following parameters are unused, because we already
got the size_list and the scale_list by parsing the cpp nodes.
"""
return torch.nn.functional.upsample_bilinear(args[0],
size=self.size_list, scale_factor=self.scale_list)
c_node = node.key_node
inputs = list(c_node.inputs())
size_list_node = inputs[1].node()
scale_list_node = inputs[3].node()
size_list = None
scale_list = None
func
The function or method to be called.
positional
Positional argument values; the placeholder is None for a non-constant argument.
keyword
Keyword argument values; the placeholder is None for a non-constant argument.
undetermined
A list of the target positions of the non-constant arguments.
A position is an int (index into positional) or a str (key into keyword).
special_treat
A dict mapping positions to treatment methods.
The values at these positions should be treated by those methods.
if size_list_node.kind() == 'prim::ListConstruct':
size_list = translate_list(inputs[1], speedup)
if scale_list_node.kind() == 'prim::ListConstruct':
scale_list = translate_list(inputs[3], speedup)
return UpsampleModule(size_list, scale_list)
"""
def __init__(self, func: Callable, positional: List[Any], keyword: Dict[str, Any],
undetermined: List[Union[int, str]], special_treat: Dict[Union[int, str], Callable]):
if not callable(func):
raise TypeError('the "func" argument must be callable')
self.func = func
self.positional = positional
self.keyword = keyword
self.undetermined = undetermined
self.special_treat = special_treat
def __call__(self, /, *args):
assert len(args) >= len(self.undetermined)
if len(args) > len(self.undetermined):
logger.warning('extra arguments are discarded when calling the function "%s"', self.func.__name__)
for i, p in enumerate(self.undetermined):
v = args[i]
if isinstance(p, int):
self.positional[p] = v
else:
self.keyword[p] = v
for p, fs in self.special_treat.items():
if isinstance(p, int):
for f in fs:
self.positional[p] = f(self.positional[p])
else:
for f in fs:
self.keyword[p] = f(self.keyword[p])
result = self.func(*self.positional, **self.keyword)
if isinstance(result, int): # turn result of 'size' into tensor
result = torch.as_tensor([result], dtype=torch.long)
return result
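# Illustrative use of FuncAdapter (hypothetical values): emulate
# `torch.add(x, 3)` where the first positional slot is filled at call time
# and the constant 3 is already bound.
def _func_adapter_example(x: torch.Tensor) -> torch.Tensor:
    adapter = FuncAdapter(torch.add, positional=[None, 3], keyword={},
                          undetermined=[0], special_treat={})
    return adapter(x)                   # expected: x + 3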
# Some types are converted into enums by the JIT.
# So we should recover them back:
# device, dtype, layout, memory_format, qscheme, qengine, dispatchkey
enum_to_dtype_names = {
0: 'uint8',
1: 'int8',
2: 'int16',
3: 'int32',
4: 'int64',
5: 'float16',
6: 'float32',
7: 'float64',
8: 'complex32',
9: 'complex64',
10: 'complex128',
11: 'bool',
12: 'qint8',
13: 'quint8',
14: 'qint32',
15: 'bfloat16',
16: 'quint4x2',
17: 'quint2x4',
}
def upsample_nearest2d_python(node, speedup):
class UpsampleModule(torch.nn.Module):
def __init__(self, size_list, scale_list):
super(UpsampleModule, self).__init__()
self.size_list = size_list
self.scale_list = scale_list
enum_to_dtype_dict = {}
def forward(self, *args):
"""
The first input of args is the target tensor to upsample
, the following parameters is useless, because we already
get the size_list and the scale_list by parsing the cpp_nodes.
"""
return torch.nn.functional.upsample_nearest(args[0],
size=self.size_list, scale_factor=self.scale_list)
c_node = node.key_node
inputs = list(c_node.inputs())
size_list_node = inputs[1].node()
scale_list_node = inputs[2].node()
size_list = None
scale_list = None
for enum_value, dtype_name in enum_to_dtype_names.items():
if hasattr(torch, dtype_name):
enum_to_dtype_dict[enum_value] = getattr(torch, dtype_name)
if size_list_node.kind() == 'prim::ListConstruct':
size_list = translate_list(inputs[1], speedup)
if scale_list_node.kind() == 'prim::ListConstruct':
scale_list = translate_list(inputs[2], speedup)
return UpsampleModule(size_list, scale_list)
def dtype_trans(ivalue: Union[int, torch.dtype]):
"""
Special process for dtype.
Torch transforms dtype into an enum in C++, so the dtype value we get from the JIT is an int.
This function recovers the int back to a torch.dtype in Python.
Parameters
----------
ivalue
The dtype value (or its JIT enum int) to be recovered.
def typeas_python(node, speedup):
"""
currently only supports type_as with float.
TODO: support more types in type_as; need to figure out
how to get the scalar type from torch._C.TensorType.
"""
class TypeasModule(torch.nn.Module):
def __init__(self, dtype=torch.float):
super().__init__()
self.example = torch.zeros(1, dtype=dtype)
if ivalue is None or isinstance(ivalue, torch.dtype):
return ivalue
elif isinstance(ivalue, int):
if ivalue in enum_to_dtype_dict:
return enum_to_dtype_dict[ivalue]
raise TypeError('No torch.dtype corresponding to the value "%s"' % ivalue)
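# Example: the JIT encodes torch.float32 as the enum value 6, so
# dtype_trans(6) should return torch.float32, while a value that is already
# a torch.dtype passes through unchanged: dtype_trans(torch.int64) is torch.int64.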
enum_to_memory_format_dict = {
0: torch.contiguous_format,
1: torch.preserve_format,
2: torch.channels_last,
3: torch.channels_last_3d,
}
def forward(self, x):
return x.type_as(self.example)
return TypeasModule()
def memory_format_trans(ivalue: Union[int, torch.memory_format]):
"""
Special process for memory_format.
Torch transforms memory_format into an enum in C++, so the memory_format value we get from the JIT is an int.
This function recovers the int back to a torch.memory_format in Python.
Parameters
----------
ivalue
The memory_format value (or its JIT enum int) to be recovered.
def to_python(node, speedup):
# for the time being, only device parameters are supported
class ToModule(torch.nn.Module):
def __init__(self, device, dtype):
super(ToModule, self).__init__()
self.device = device
self.dtype = dtype
def forward(self, x):
return x.to(self.device, dtype=self.dtype)
"""
if ivalue is None or isinstance(ivalue, torch.memory_format):
return ivalue
elif isinstance(ivalue, int):
if ivalue in enum_to_memory_format_dict:
return enum_to_memory_format_dict[ivalue]
raise TypeError('No torch.memory_format corresponding to the value "%s"' % ivalue)
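# Example: memory_format_trans(2) should return torch.channels_last,
# matching the enum the JIT uses for the memory_format argument of
# ops such as aten::contiguous.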
special_treat_dict = {
'dtype': dtype_trans,
'memory_format': memory_format_trans,
}
c_node = node.key_node
inputs = list(c_node.inputs())
in_debugname = inputs[0].debugName()
# device of the input tensor
device = speedup.internal_result[in_debugname].device
schema_fix_dict = {
# functions 'to', 'randint', and 'sparse_coo_tensor' have different schemas between Python and C++.
# https://pytorch.org/docs/stable/jit_unsupported.html#ops-with-divergent-schemas-between-torch-python
'aten::to.device(Tensor(a) self, Device device, int dtype, bool non_blocking=False, bool copy=False, int? memory_format=None) -> (Tensor(a))':
'aten::to.device(Tensor(a) self, Device device, int dtype, bool non_blocking=False, bool copy=False, *, int? memory_format=None) -> (Tensor(a))',
'aten::to.dtype(Tensor(a) self, int dtype, bool non_blocking=False, bool copy=False, int? memory_format=None) -> (Tensor(a))':
'aten::to.dtype(Tensor(a) self, int dtype, bool non_blocking=False, bool copy=False, *, int? memory_format=None) -> (Tensor(a))',
'aten::to.other(Tensor(a) self, Tensor other, bool non_blocking=False, bool copy=False, int? memory_format=None) -> (Tensor(a))':
'aten::to.other(Tensor(a) self, Tensor other, bool non_blocking=False, bool copy=False, *, int? memory_format=None) -> (Tensor(a))',
# TODO: are the arguments 'pin_memory' and 'requires_grad' related?
# Functions in Python have only 'requires_grad' while functions in aten have only 'pin_memory'.
# 'aten::randint(int high, int[] size, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None) -> (Tensor)',
# 'aten::randint.generator(int high, int[] size, *, Generator? generator, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None) -> (Tensor)',
# 'aten::randint.low(int low, int high, int[] size, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None) -> (Tensor)',
# 'aten::randint.low_generator(int low, int high, int[] size, *, Generator? generator, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None) -> (Tensor)',
# 'aten::sparse_coo_tensor.size(int[] size, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=False) -> (Tensor)',
# 'aten::sparse_coo_tensor.indices(Tensor indices, Tensor values, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None) -> (Tensor)',
# 'aten::sparse_coo_tensor.indices_size(Tensor indices, Tensor values, int[] size, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None) -> (Tensor)',
}
@lru_cache(maxsize=256)
def parse_aten_schema(schema: str):
"""
Parse the schema into positional_num and keyword_list, and detect whether any argument needs special treatment.
"""
if schema in schema_fix_dict:
schema = schema_fix_dict[schema]
for _node in inputs[1:]:
val = parse_constant(_node, speedup)
if isinstance(val, torch.device):
device = val
dtype = jitid_2_dtype[parse_constant(inputs[1], speedup)]
return ToModule(device, dtype)
positional_num = 0
keyword_list = list()
special_treat = dict() # for dtype and memory_format trans now
for arg in torch._C.parse_schema(schema).arguments:
if not arg.kwarg_only:
key = positional_num
positional_num += 1
else:
key = arg.name
keyword_list.append(key)
def cat_python(node, speedup):
class CatModule(torch.nn.Module):
def __init__(self, cat_dim):
super(CatModule, self).__init__()
self.cat_dim = cat_dim
if arg.name in special_treat_dict:
if key not in special_treat:
special_treat[key] = [special_treat_dict[arg.name]]
else:
special_treat[key].append(special_treat_dict[arg.name])
def forward(self, *args):
return torch.cat(args, dim=self.cat_dim)
return positional_num, keyword_list, special_treat
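# A hedged example: for the schema
#     'aten::softmax(Tensor self, int dim, int? dtype=None) -> Tensor'
# none of the arguments is keyword-only, so parse_aten_schema should yield
# positional_num=3, keyword_list=[] and special_treat={2: [dtype_trans]},
# because the dtype at position 2 must be recovered from its JIT enum.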
c_node = node.key_node
inputs = list(c_node.inputs())
dim = inputs[1].toIValue()
return CatModule(dim)
def parse_input_value(speedup: ModelSpeedup, input_nodes: List[torch._C.Node], positional_num: int, keyword_list: List[str]):
"""
translate inputs, to constant positional arguments, constant keyword arguments, and undetermined positions
"""
positional = list()
keyword = dict()
undetermined = list()
for ainput in input_nodes:
if ainput.node().kind() == 'prim::ListConstruct':
arg = translate_list(ainput, speedup)
elif ainput.node().kind() == 'prim::Constant':
arg = ainput.toIValue()
else:
assert 'aten::' in ainput.node().kind() or 'prim::' in ainput.node().kind()
if len(positional) < positional_num:
undetermined.append(len(positional))
else:
undetermined.append(keyword_list[len(keyword)])  # the next keyword name in schema order
arg = None
if len(positional) < positional_num:
positional.append(arg)
else:
keyword[keyword_list[len(keyword)]] = arg
return positional, keyword, undetermined
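# Illustrative result (hypothetical node): for an aten::add node whose schema
# is '(Tensor self, Tensor other, *, Scalar alpha=1)', with %self produced by
# an upstream op and %other, %alpha constants, this should return
#     positional=[None, other_value], keyword={'alpha': 1}, undetermined=[0]
# i.e. only the tensor slot is left to be filled in at call time.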
def special_treat_to_constant_value(positional: List, keyword: Dict[str, Any], undetermined: List[Union[int, str]],
special_treat: Dict[Union[int, str], Callable]):
"""
If any argument needing special treatment is not in undetermined, apply the treatment now.
"""
undetermined_special_treat = dict()
for p, fs in special_treat.items():
if p in undetermined:
undetermined_special_treat[p] = fs
elif isinstance(p, int):
for f in fs: positional[p] = f(positional[p])
else:
for f in fs: keyword[p] = f(keyword[p])
return undetermined_special_treat
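# Example: for aten::to.dtype with a constant dtype enum, the dtype slot is
# converted right here (e.g. 6 -> torch.float32) and nothing is deferred;
# if the dtype instead came from another node, its position would remain in
# the returned undetermined_special_treat and FuncAdapter.__call__ would
# apply dtype_trans at call time.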
def ones_python(node, speedup):
class OnesModule(torch.nn.Module):
def __init__(self, out_size, dtype_id, device, require_grad):
super(OnesModule, self).__init__()
self.out_size = out_size
self.device = device
self.require_grad = require_grad
self.dtype = jitid_2_dtype[dtype_id]
def generate_aten_to_python(func: Callable, node: NodePyGroup, speedup: ModelSpeedup) -> FuncAdapter:
"""
Parse the aten op node and return a callable FuncAdapter that infers the mask according to the node.op_type.
def forward(self, *args):
return torch.ones(size=self.out_size, dtype=self.dtype, device=self.device, requires_grad=self.require_grad)
Parameters
----------
func
The torch function one-to-one correspondence with the node.
node
The target node to inference the mask
speedup
The speedup object of the target model.
Returns
-------
func
The translated FuncAdapter used to infer the mask for this node.
"""
c_node = node.key_node
inputs = list(c_node.inputs())
output_shape = translate_list(inputs[0], speedup)
dtype_id = parse_constant(inputs[1], speedup)
# layout = parse_constant(inputs[2], speedup)
device = parse_constant(inputs[3], speedup)
require_grad = parse_constant(inputs[4], speedup)
return OnesModule(output_shape, dtype_id, device, require_grad)
def zeros_python(node, speedup):
class ZerosModule(torch.nn.Module):
def __init__(self, out_size, dtype_id, device, require_grad):
super(ZerosModule, self).__init__()
self.out_size = out_size
self.device = device
self.require_grad = require_grad
self.dtype = jitid_2_dtype[dtype_id]
def forward(self, *args):
return torch.zeros(size=self.out_size, dtype=self.dtype, device=self.device, requires_grad=self.require_grad)
schema = c_node.schema()
positional_num, keyword_list, special_treat = parse_aten_schema(schema)
c_node = node.key_node
inputs = list(c_node.inputs())
output_shape = translate_list(inputs[0], speedup)
dtype_id = parse_constant(inputs[1], speedup)
# layout = parse_constant(inputs[2], speedup)
device = parse_constant(inputs[3], speedup)
require_grad = parse_constant(inputs[4], speedup)
return ZerosModule(output_shape, dtype_id, device, require_grad)
def rsub_python(node, speedup):
c_node = node.key_node
inputs = list(c_node.inputs())
constant = None
other_name = inputs[1].debugName()
alpha = parse_constant(inputs[2], speedup)
if other_name not in speedup.internal_result:
constant = parse_constant(inputs[1], speedup)
if constant is None:
return torch.sub
else:
new_sub = partial(torch.sub, other=constant, alpha=alpha)
return new_sub
input_nodes = list(c_node.inputs())
positional, keyword, undetermined = parse_input_value(speedup, input_nodes, positional_num, keyword_list)
def expand_python(node, speedup):
class ExpandModule(torch.nn.Module):
def __init__(self, new_size):
super(ExpandModule, self).__init__()
# need deepcopy when the input is size-related
self.new_size = copy.deepcopy(new_size)
undetermined_special_treat = special_treat_to_constant_value(positional, keyword, undetermined, special_treat)
def forward(self, *args):
return args[0].expand(self.new_size).clone()
return FuncAdapter(func, positional, keyword, undetermined, undetermined_special_treat)
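# Hedged end-to-end sketch: for an aten::flatten node, `func` would be the
# flatten attribute registered from torch._C._VariableFunctions by
# init_add_functions below, and the returned adapter only needs the input:
#     adapter = generate_aten_to_python(torch.flatten, node, speedup)
#     out = adapter(masked_input)   # equivalent to the original aten call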
c_node = node.key_node
inputs = list(c_node.inputs())
new_size = translate_list(inputs[1], speedup)
return ExpandModule(new_size)
def expandas_python(node, speedup):
class ExpandasModule(torch.nn.Module):
def forward(self, x, y):
return x.expand_as(y).clone()
return ExpandasModule()
trans_from_jit_to_python = {
'aten::add': add_python,
'aten::add_': add_python,
'aten::sub': sub_python,
'aten::sub_': sub_python,
'aten::mul': mul_python,
'aten::mul_': mul_python,
'aten::relu': relu_python,
'aten::relu_': relu_inplace_python,
'aten::sigmoid': sigmoid_python,
'aten::sigmoid_': sigmoid_python,
# tanh behaves like relu
'aten::tanh': relu_python,
'aten::tanh_': relu_python,
'aten::flatten': flatten_python,
'aten::mean': mean_python,
'aten::dropout': dropout_python,
trans_func_dict = {
'aten::slice': slice_python,
'aten::select': select_python,
'aten::size': size_python,
'aten::t': transpose_python,
'aten::transpose': transpose2_python,
'aten::Int': toint_python,
'aten::view': view_python,
'aten::reshape': reshape_python,
'aten::permute': permute_python,
'aten::matmul': matmul_python,
'aten::div': div_python,
'aten::floor_divide': floor_div_python,
'aten::softmax': softmax_python,
'aten::contiguous': contiguous_python,
'aten::gelu': gelu_python,
'aten::cat': cat_python,
'aten::avg_pool2d': avgpool2d_python,
'aten::max_pool2d': avgpool2d_python,
'aten::adaptive_avg_pool2d': adaptive_avgpool_python,
'aten::to': to_python,
'aten::type_as': typeas_python,
'aten::upsample_bilinear2d': upsample_bilinear2d_python,
'aten::upsample_nearest2d': upsample_nearest2d_python,
'aten::exp': exp_python,
'aten::squeeze': squeeze_python,
'aten::unsqueeze': unsqueeze_python,
'aten::ones': ones_python,
'aten::zeros': zeros_python,
'aten::rsub': rsub_python,
'aten::expand': expand_python,
'aten::Int': partial(generate_aten_to_python, torch._C._TensorBase.int),
'prim::TupleUnpack': tupleunpack_python,
'prim::ListUnpack': tupleunpack_python,
'prim::NumToTensor': num2tensor_python,
'prim::GetAttr': getattr_python,
'prim::Constant': constant_python,
'aten::constant_pad_nd': constant_pad_nd_python,
'aten::silu': silu_python,
'aten::expand_as': expandas_python
}
def init_add_functions(func_from: Union[ModuleType, Type[Any]]):
"""
Add function/method attributes from a module/class to trans_func_dict
Parameters
----------
func_from
The module/class containing the needed functions
def jit_to_python_function(node, speedup):
"""
Return a callable object to infer the mask according to the
node.op_type.
global trans_func_dict
new_trans_func_dict = dict()
for name in dir(func_from):
attr = getattr(func_from, name)
if callable(attr) and not name.startswith('__'):
new_trans_func_dict['aten::' + name] = partial(generate_aten_to_python, attr)
trans_func_dict = {**new_trans_func_dict, **trans_func_dict}
init_add_functions(torch._C._VariableFunctions)
init_add_functions(torch._C._nn)
init_add_functions(torch._C._TensorBase)
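# After the three init_add_functions calls above, e.g. 'aten::softmax' maps to
# partial(generate_aten_to_python, torch._C._VariableFunctions.softmax); the
# handwritten entries in trans_func_dict still win on name clashes because
# they are merged in last.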
def jit_to_python_function(node: NodePyGroup, speedup: ModelSpeedup) -> FuncAdapter:
"""
Return a callable object to infer the mask according to the node.op_type.
Parameters
----------
node: NodeGroup
node
The target node to inference the mask
speedup: ModelSpeedup
speedup
The speedup object of the target model.
Returns
-------
func: callable object(nn.Module/function)
func
The translated function used to infer the mask;
if the current op_type is not supported, None is returned.
"""
logger.debug(
'Translate C function %s into its python version', node.op_type)
if node.op_type not in trans_from_jit_to_python:
if node.op_type not in trans_func_dict:
logger.error(
'%s is not Supported! Please report an issue at https://github.com/microsoft/nni. Thanks~', node.op_type)
# return None to skip the mask inference for this node
return None
return trans_from_jit_to_python[node.op_type](node, speedup)
return trans_func_dict[node.op_type](node, speedup)
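# Typical call site inside ModelSpeedup (a hedged sketch):
#     func = jit_to_python_function(node, speedup)
#     if func is not None:
#         out = func(*dummy_inputs)   # used to propagate masks through the op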
@@ -61,7 +61,7 @@ class BackboneModel2(torch.nn.Module):
x = F.max_pool2d(x, 2, 2)
x = F.relu(self.bn2(self.conv2(x)))
x = F.max_pool2d(x, 2, 2)
x = x.view(x.size(0), -1)
x = x.reshape(x.size(0), -1)
x = F.relu(self.fc1(x))
x = self.fc2(x)
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import unittest
import torch
import torch.nn.functional as F
from nni.compression.pytorch.pruning import L1NormPruner
from nni.compression.pytorch.speedup import ModelSpeedup
from nni.algorithms.compression.v2.pytorch.utils import (
compute_sparsity_compact2origin,
compute_sparsity_mask2compact
)
class CondModel(torch.nn.Module):
"""
test for:
prim::If
"""
the_cond: bool
def __init__(self):
super().__init__()
self.the_cond = True
def forward(self, x: torch.Tensor) -> torch.Tensor:
if self.the_cond:
x = x + 0.00001
else:
x = x - 0.00001
self.the_cond = not self.the_cond
return x
class ASubModel(torch.nn.Module):
"""
test for:
sub model
"""
def __init__(self):
super().__init__()
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = x + 0.00001
return x
class TorchModel1(torch.nn.Module):
"""
test for:
add, sub, mul, div, exp, matmul,
relu, gelu, tanh, silu, sigmoid, softmax,
size, unsqueeze, flatten, cat, slice, reshape, transpose, t, select, permute, constant_pad_nd,
mean, avg_pool2d, max_pool2d, sum, adaptive_avg_pool2d,
to, Int, view,
type_as, expand_as, contiguous,
notes:
'floor_divide' has no backward, so it is not tested
"""
def __init__(self):
super().__init__()
self.conv1 = torch.nn.Conv2d(1, 6, 5, 1)
self.conv2 = torch.nn.Conv2d(6, 16, 5, 1)
self.fccond = torch.nn.Linear(16 * 4 * 4, 16 * 4 * 4)
self.fc1 = torch.nn.Linear(16 * 4 * 4, 120)
self.fc2 = torch.nn.Linear(120, 84)
self.fc3 = torch.nn.Linear(84, 10)
self.pool1 = torch.nn.MaxPool2d((2, 2))
self.pool2 = torch.nn.MaxPool2d((2, 2))
self.cond = torch.jit.script(CondModel())
self.asub = ASubModel()
def forward(self, x: torch.Tensor):
x = x.contiguous(memory_format=torch.channels_last)
x = torch._C._nn.upsample_bilinear2d(x, (28, 28), False)
x = torch._C._nn.upsample_nearest2d(x, (28, 28))
x = F.adaptive_avg_pool2d(x, (28, 28))
x = torch.exp(x)
x = torch.sigmoid(x)
x = torch.transpose(x, 1, 2)
x = torch.transpose(x, 1, 2)
x = F.avg_pool2d(x, 3, 1, padding=1)
x = F.max_pool2d(x, 3, 1, padding=1)
x = x.to(torch.float32)
x = self.conv1(x)
y1 = self.pool1(F.relu(x))
y2 = self.pool1(F.gelu(x))
x = y1 + y2
x = x + 0.00001
x = x * 1.00001
x = self.conv2(x)
y1 = self.pool2(F.silu(x))
y2 = self.pool2(torch.tanh(x))
x = y1 - y2
x = x - 0.00001
x = x / 1.00001
x = torch.permute(x, (0, 2, 3, 1))
x = torch.permute(x, (0, 2, 3, 1))
x = torch.permute(x, (0, 2, 3, 1))
x = torch.unsqueeze(x, dim=1)
x = torch.select(x, dim=1, index=0)
x = torch.unsqueeze(x, dim=1)
x = torch.mean(x, dim=1)
x = torch.unsqueeze(x, dim=1)
x = torch.sum(x, dim=1, dtype=torch.float32)
x = torch.unsqueeze(x, dim=1)
x = torch.squeeze(x, dim=1)
x = torch.flatten(x, 1)
x = x.reshape(x.shape)
x = x.view(-1, x.size(1))
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = F.softmax(self.fc3(x), dim=1)
y1 = x[:,0:int(x.size(1)/2)]
y2 = x[:,int(x.size(1)/2):x.size(1)]
x = torch.cat((y1, y2), dim=1)
x = x.type_as(x)
x = x.expand_as(x)
x = torch.matmul(x, x.t())
x = torch.cat([x, x], dim=1)
# x = self.cond(x)
x = self.asub(x)
x = torch.constant_pad_nd(x, (1,1,1,1), 3.14159)
return x
class AutoConvTestCase(unittest.TestCase):
def test_l1norm_pruner(self):
model = TorchModel1()
dummy_input = torch.rand(3, 1, 28, 28)
config_list = [{'op_types': ['Conv2d'], 'sparsity': 0.5}]
pruner = L1NormPruner(model=model, config_list=config_list)
pruned_model, masks = pruner.compress()
pruner._unwrap_model()
sparsity_list = compute_sparsity_mask2compact(pruned_model, masks, config_list)
ModelSpeedup(model, dummy_input, masks).speedup_model()
real_sparsity_list = compute_sparsity_compact2origin(TorchModel1(), model, config_list)
print('sparsity_list:', sparsity_list)
assert 0.45 < sparsity_list[0]['total_sparsity'] < 0.55
print('real_sparsity_list:', real_sparsity_list)
assert 0.45 < real_sparsity_list[0]['total_sparsity'] < 0.75
print('the shape of output of the infer:', model(dummy_input).shape)
assert model(dummy_input).shape == torch.Size((5, 8))
if __name__ == '__main__':
unittest.main()