OpenDAS / nni / Commits / 588f299b

Unverified commit 588f299b, authored Aug 16, 2022 by Louis-J, committed via GitHub on Aug 16, 2022
parent b2c31ca2

feat(speedup): automatically convert op asts to callables (#4996)
parent
b2c31ca2
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
481 additions
and
595 deletions
+481
-595
nni/compression/pytorch/speedup/jit_translate.py
nni/compression/pytorch/speedup/jit_translate.py
+315
-594
test/algo/compression/v1/test_model_speedup.py
test/algo/compression/v1/test_model_speedup.py
+1
-1
test/algo/compression/v2/test_auto_conv.py
test/algo/compression/v2/test_auto_conv.py
+165
-0
No files found.
nni/compression/pytorch/speedup/jit_translate.py (view file @ 588f299b)
```diff
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.
+from __future__ import annotations
+from types import ModuleType
+from typing import Any, Callable, Dict, List, Optional, Type, Union
+from typing import TYPE_CHECKING
+if TYPE_CHECKING:
+    # Only imports the below statements during type checking
+    from nni.compression.pytorch.speedup import ModelSpeedup
+    from nni.common.graph_utils import NodePyGroup
 import re
 import logging
-from functools import partial
+from functools import partial, lru_cache
 import copy
 import torch
```
```diff
@@ -15,31 +24,24 @@ jitid_2_dtype = {4: torch.long, 6:torch.float32}
 # to exclude partial
-__all__ = [
-    'adaptive_avgpool_python', 'add_python', 'avgpool2d_python', 'cat_python', 'contiguous_python',
-    'div_python', 'dropout_python', 'exp_python', 'flatten_python', 'floor_div_python', 'gelu_python',
-    'getattr_python', 'jit_to_python_function', 'matmul_python', 'mean_python', 'mul_python',
-    'num2tensor_python', 'parse_constant', 'permute_python', 'relu_inplace_python', 'relu_python',
-    'reshape_python', 'select_python', 'sigmoid_python', 'size_python', 'slice_python', 'softmax_python',
-    'squeeze_python', 'to_python', 'toint_python', 'torch', 'trans_from_jit_to_python', 'translate_list',
-    'transpose2_python', 'transpose_python', 'tupleunpack_python', 'typeas_python', 'unsqueeze_python',
-    'upsample_bilinear2d_python', 'view_python'
-]
+__all__ = [
+    'getattr_python', 'jit_to_python_function', 'num2tensor_python', 'parse_constant', 'slice_python',
+    'translate_list', 'tupleunpack_python', 'dtype_trans', 'memory_format_trans'
+]
```
```diff
-def translate_list(list_node, speedup=None):
+def translate_list(list_node: torch._C.Value, speedup: ModelSpeedup = None) -> List:
     """
     Get the list of values from the list construct node.

     Parameters
     ----------
-    list_node: Torch.C.Value
+    list_node
         The cpp node of the target list.
-    speedup: ModuleSpeed
+    speedup
         The Module speedup module.

     Returns
     -------
-    values: list
+    values
         The list of values in the target cpp list node.
     """
     # the node that create the list
@@ -52,27 +54,26 @@ def translate_list(list_node, speedup=None):
```
```diff
     if speedup is not None and debugName in speedup.internal_result:
         # this value is the result of the other nodes, such as
         # ate::size
-        values.append(speedup.internal_result[debugName].item())
+        values.append(speedup.internal_result[debugName])
     else:
         # if the corresponding value is a constant
         values.append(_i.toIValue())
     return values
```
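`translate_list` reads the elements of a `prim::ListConstruct` node, the jit node that packs list-typed arguments such as shapes or kernel sizes. A minimal sketch of where such nodes come from (independent of NNI; the helper function and example shapes are illustrative only):

```python
import torch

def build_shape(x):
    # The literal [1, -1] shape becomes a prim::ListConstruct node in the jit graph.
    return x.view(1, -1)

traced = torch.jit.trace(build_shape, torch.randn(2, 3))
for node in traced.graph.nodes():
    print(node.kind())
# Among the printed kinds there is typically a prim::ListConstruct that packs the
# [1, -1] shape arguments before they are consumed by aten::view.
```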
```diff
-def parse_constant(cvalue, speedup):
+def parse_constant(cvalue: torch._C.Value, speedup: ModelSpeedup) -> Any:
     """
     Parse the constant values from this Node

     Parameters
     ----------
-    cvalue: Torch.C.Value
+    cvalue
         The cpp node of the target constant value.
-    speedup: ModelSpeedup
+    speedup
         The Model speedup module.

     Returns
     -------
-    value: int/float/tensor
+    value
         The constant values parsed from the node.
     """
     logger.debug('Try to parse the constant value: %s', cvalue.debugName())
@@ -85,245 +86,13 @@ def parse_constant(cvalue, speedup):
```
```diff
     inputs = op_node.inputs()
     input_values = [parse_constant(_i, speedup) for _i in inputs]
-    func = trans_from_jit_to_python[op_node.kind()](op_node, speedup)
-    return func(*input_values)
+    if op_node.kind() not in trans_func_dict:
+        raise RuntimeError('Unsupported function op node type: {}'.format(op_node.kind()))
+    func = trans_func_dict[op_node.kind()](op_node, speedup)
+    return func(*input_values)
```

Removed in this commit (hand-written translators that the new schema-based conversion further down replaces):

```python
def dropout_python(node, speedup):
    return torch.dropout

def flatten_python(node, speedup):
    c_node = node.key_node
    inputs = list(c_node.inputs())
    start_dim = inputs[1].toIValue()
    end_dim = inputs[2].toIValue()
    new_flatten = partial(torch.flatten, start_dim=start_dim, end_dim=end_dim)
    return new_flatten

def relu_inplace_python(node, speedup):
    return torch.relu_

def relu_python(node, speedup):
    return torch.relu

def sigmoid_python(node, speedup):
    return torch.sigmoid

def mean_python(node, speedup):
    c_node = node.key_node
    inputs = list(c_node.inputs())
    dim_list = translate_list(inputs[1], speedup)
    keep_dim = inputs[2].toIValue()
    new_mean = partial(torch.mean, dim=tuple(dim_list), keepdim=keep_dim)
    return new_mean

def add_python(node, speedup):
    c_node = node.key_node
    inputs = list(c_node.inputs())
    constant = None
    for i in range(2):
        input_i = inputs[i]
        debug_name = input_i.debugName()
        if debug_name not in speedup.internal_result:
            # this input is a constant value
            # TODO: what if this input is a constant tensor
            if input_i.toIValue() is not None:
                constant = parse_constant(input_i, speedup)
                break
    if constant is None:
        return torch.add
    else:
        new_add = partial(torch.add, constant)
        return new_add

def sub_python(node, speedup):
    c_node = node.key_node
    inputs = list(c_node.inputs())
    constant = [None, None]
    for i in range(2):
        input_i = inputs[i]
        debug_name = input_i.debugName()
        if debug_name not in speedup.internal_result:
            # this input is a constant value
            # TODO: what if this input is a constant tensor
            if input_i.toIValue() is not None:
                constant[i] = parse_constant(input_i, speedup)
                break
    if constant[0] is None and constant[1] is None:
        new_sub = torch.sub
    elif constant[0] is not None:
        new_sub = partial(torch.sub, input=constant)
    else:
        new_sub = partial(torch.sub, other=constant)
    return new_sub

def floor_div_python(node, speedup):
    c_node = node.key_node
    inputs = list(c_node.inputs())
    divisor = inputs[1]
    constant = None
    if divisor.debugName() not in speedup.internal_result:
        # divisor is a constant value/tensor
        constant = parse_constant(divisor, speedup)
    if constant is None:
        return torch.floor_divide
    else:
        new_op = partial(torch.floor_divide, other=constant)
        return new_op

def mul_python(node, speedup):
    c_node = node.key_node
    inputs = list(c_node.inputs())
    constant = None
    for i in range(2):
        input_i = inputs[i]
        debug_name = input_i.debugName()
        if debug_name not in speedup.internal_result:
            constant = parse_constant(input_i, speedup)
            # both two inputs cannot be constants at the same time
            break
    if constant is None:
        return torch.mul
    else:
        new_mul = partial(torch.mul, constant)
        return new_mul

def transpose_python(node, speedup):
    return torch.t

def transpose2_python(node, speedup):
    c_node = node.key_node
    inputs = list(c_node.inputs())
    dim_1 = inputs[1].toIValue()
    dim_2 = inputs[2].toIValue()
    new_transpose = partial(torch.transpose, dim0=dim_1, dim1=dim_2)
    return new_transpose

def matmul_python(node, speedup):
    return torch.matmul

def div_python(node, speedup):
    # The second input parameter of torch.div can be a
    # tensor or a constant, if it is a constant, we need
    # to return
    c_node = node.key_node
    inputs = list(c_node.inputs())
    if inputs[1].debugName() in speedup.internal_result:
        # the second input parameters is the output of the other
        # nodes
        return torch.div
    else:
        other = inputs[1].toIValue()
        new_div = partial(torch.div, other=other)
        return new_div

def softmax_python(node, speedup):
    c_node = node.key_node
    inputs = list(c_node.inputs())
    dim = inputs[1].toIValue()
    new_softmax = partial(torch.softmax, dim=dim)
    return new_softmax

def contiguous_python(node, speedup):
    class contiguousModule(torch.nn.Module):
        def forward(self, x):
            return x.contiguous().clone()
    return contiguousModule()

def gelu_python(node, speedup):
    return torch.nn.GELU()

def silu_python(node, speedup):
    return torch.nn.SiLU()

def avgpool2d_python(node, speedup):
    c_node = node.key_node
    inputs = list(c_node.inputs())
    kernel_size = translate_list(inputs[1], speedup)
    stride = translate_list(inputs[2], speedup)
    padding = translate_list(inputs[3], speedup)
    new_avgpool = partial(torch.nn.functional.avg_pool2d, kernel_size=kernel_size,
                          stride=stride, padding=padding)
    return new_avgpool

def adaptive_avgpool_python(node, speedup):
    c_node = node.key_node
    inputs = list(c_node.inputs())
    output_size = translate_list(inputs[1], speedup)
    new_avgpool = torch.nn.AdaptiveAvgPool2d(output_size)
    return new_avgpool

def tupleunpack_python(node, speedup):
    # Note: tuple unpack should only exists at the
    # the end of the model, and is no need to replace/propagate mask
    return None

def num2tensor_python(node, speedup):
    return torch.nn.Identity()

def exp_python(node, speedup):
    return torch.exp

def squeeze_python(node, speedup):
    c_node = node.key_node
    inputs = list(c_node.inputs())
    dim = None
    if len(inputs) > 1:
        dim = parse_constant(inputs[1], speedup)
    new_squeeze = partial(torch.squeeze, dim=dim)
    return new_squeeze

def unsqueeze_python(node, speedup):
    c_node = node.key_node
    inputs = list(c_node.inputs())
    dim = parse_constant(inputs[1], speedup)
    new_unsqueeze = partial(torch.unsqueeze, dim=dim)
    return new_unsqueeze

def constant_pad_nd_python(node, speedup):
    c_node = node.key_node
    inputs = list(c_node.inputs())
    pad = translate_list(inputs[1], speedup)
    value = parse_constant(inputs[2], speedup)
    new_constant_pad_nd = partial(torch.nn.functional.pad, pad=pad, value=value)
    return new_constant_pad_nd

##########################################################
# Split Line
# Following module/functions cannot be translated into a
# single function, so we use torch.nn.Module to wrap the
# the core function, and return the torch.nn.Module instead
##########################################################
```
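Most of the removed translators followed the same pattern: read the constant operands off the jit node, then bind them onto the matching torch function with `functools.partial`, leaving only the tensor argument open. A minimal sketch of that pattern, independent of NNI and with illustrative shapes:

```python
from functools import partial
import torch

# Bind the constant arguments once, leave the runtime tensor open,
# as flatten_python and softmax_python above did.
new_flatten = partial(torch.flatten, start_dim=1, end_dim=-1)
new_softmax = partial(torch.softmax, dim=1)

x = torch.randn(2, 3, 4)
print(new_flatten(x).shape)                   # torch.Size([2, 12])
print(new_softmax(torch.randn(2, 5)).shape)   # torch.Size([2, 5])
```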
```diff
-def slice_python(node, speedup):
+def slice_python(node: NodePyGroup, speedup: ModelSpeedup):
     class SliceMoudle(torch.nn.Module):
         def __init__(self, sliceobj):
             super(SliceMoudle, self).__init__()
@@ -368,102 +137,38 @@ def slice_python(node, speedup):
     else:
         return SliceMoudle(tuple(slice_list))
```
Added in this commit (cat_python is kept as a manual translator, now with type annotations):

```python
def cat_python(node: NodePyGroup, speedup: ModelSpeedup):
    class CatModule(torch.nn.Module):
        def __init__(self, cat_dim):
            super(CatModule, self).__init__()
            self.cat_dim = cat_dim

        def forward(self, *args):
            return torch.cat(args, dim=self.cat_dim)

    c_node = node.key_node
    inputs = list(c_node.inputs())
    dim = inputs[1].toIValue()
    return CatModule(dim)
```

Removed in this commit:

```python
def select_python(node, speedup):
    class SelectModule(torch.nn.Module):
        def __init__(self, dim, index):
            super(SelectModule, self).__init__()
            self.dim = copy.deepcopy(dim)
            self.index = copy.deepcopy(index)

        def forward(self, x):
            return x.select(self.dim, self.index)

    c_node = node.key_node
    inputs = list(c_node.inputs())
    dim = inputs[1].toIValue()
    index = inputs[2].toIValue()
    return SelectModule(dim, index)

def size_python(node, speedup):
    # return None
    class SizeMoudle(torch.nn.Module):
        def __init__(self, sizedim):
            super(SizeMoudle, self).__init__()
            self.sizedim = sizedim

        def forward(self, x):
            return torch.as_tensor([x.size(self.sizedim)], dtype=torch.long)
            # return torch.tensor(x.size(self.sizedim))

    c_node = node.key_node
    inputs = list(c_node.inputs())
    size_dim = inputs[1].toIValue()
    return SizeMoudle(size_dim)

def toint_python(node, speedup):
    class ToIntModule(torch.nn.Module):
        def forward(self, x):
            return x.to(torch.int)
    return ToIntModule()

def view_python(node, speedup):
    class ViewModule(torch.nn.Module):
        def __init__(self, shape):
            super(ViewModule, self).__init__()
            self.shape = shape
            logger.info('View Module output size: %s', str(self.shape))

        def forward(self, *args):
            return args[0].view(self.shape)

    c_node = node.key_node
    inputs = list(c_node.inputs())
    shape = translate_list(inputs[1], speedup)
    return ViewModule(shape)

def reshape_python(node, speedup):
    class ReshapeModule(torch.nn.Module):
        def __init__(self, shape):
            super(ReshapeModule, self).__init__()
            self.shape = shape
            logger.info('Reshape Module output size: %s', str(self.shape))

        def forward(self, *args):
            return args[0].reshape(self.shape)

    c_node = node.key_node
    inputs = list(c_node.inputs())
    shape = translate_list(inputs[1], speedup)
    return ReshapeModule(shape)
```
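Because `aten::cat` receives its tensors as a packed list, the translation wraps `torch.cat` in a tiny module rather than pre-binding it with `partial`; the module can then accept however many tensors the masked graph feeds it. A standalone sketch of the same idea (shapes are illustrative only):

```python
import torch

class CatModule(torch.nn.Module):
    """Minimal stand-in for the module returned by cat_python."""
    def __init__(self, cat_dim):
        super().__init__()
        self.cat_dim = cat_dim

    def forward(self, *args):
        # All positional inputs are gathered into one tuple and concatenated.
        return torch.cat(args, dim=self.cat_dim)

m = CatModule(1)
print(m(torch.zeros(2, 3), torch.ones(2, 2)).shape)  # torch.Size([2, 5])
```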
Added in this commit (tupleunpack_python and num2tensor_python are kept, now with annotations):

```python
def tupleunpack_python(_node: NodePyGroup, _speedup: ModelSpeedup) -> Optional[Callable]:
    # Note: tuple unpack should only exists at the
    # the end of the model, and is no need to replace/propagate mask
    return None

def num2tensor_python(_node: NodePyGroup, _speedup: ModelSpeedup):
    return torch.nn.Identity()
```

Removed in this commit:

```python
def permute_python(node, speedup):
    class PermuteModule(torch.nn.Module):
        def __init__(self, dimlist):
            super(PermuteModule, self).__init__()
            # deepcopy the values here, because the following randomize operation
            # will change the value of the dimlist
            self.dimlist = copy.deepcopy(dimlist)

        def forward(self, x):
            return x.permute(self.dimlist)

    c_node = node.key_node
    inputs = list(c_node.inputs())
    dim_list = translate_list(inputs[1], speedup)
    return PermuteModule(dim_list)
```
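`prim::NumToTensor` only re-types a value inside the jit graph, so mapping it to `torch.nn.Identity` keeps both the value and any mask on it untouched. A quick check of that behaviour:

```python
import torch

identity = torch.nn.Identity()
value = torch.tensor(5)
print(identity(value) is value)  # True: the input is returned unchanged
```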
```diff
-def getattr_python(node, speedup):
+def getattr_python(node: NodePyGroup, _speedup: ModelSpeedup):
     """
     Note: Ops started with Prim:: is not taken as the key node,
     so we directly pass the Cpp node into this funciton.

     Parameters
     ----------
-    node: torch._C.Node
+    node
         The cpp node of prim::Getattr
-    speedup: ModelSpeedup
+    speedup
         The corresponding speedup object.
     """
     class GetModule(torch.nn.Module):
@@ -482,316 +187,332 @@ def getattr_python(node, speedup):
     assert len(key_words) == 1
     return GetModule(key_words[0])
```
Added in this commit (the core of the change, a generic argument-reordering adapter):

```python
class FuncAdapter:
    """
    A function adapter which can reorder arguments.
    It can be initialate with constant argument, and positions of each non-constant
    argument. When called, it can put arguments into correct position, then call the
    function.

    Attributes
    ----------
    func
        The function or method to be called.
    positional
        Positional arguments values. The placeholder is None if it's non-constant.
    keyword
        Keyword arguments values. The placeholder is None if it's non-constant.
    undetermined
        A list of the right positions of arguments.
        Position is an int in positional or a str in keyword.
    special_treat
        A Dict of the positions and methods.
        The values of these positions should be treat by those methods.
    """
    def __init__(self, func: Callable, positional: List[Any], keyword: Dict[str, Any],
                 undetermined: List[Union[int, str]], special_treat: Dict[Union[int, str], Callable]):
        if not callable(func):
            raise TypeError('the "func" argument must be callable')
        self.func = func
        self.positional = positional
        self.keyword = keyword
        self.undetermined = undetermined
        self.special_treat = special_treat

    def __call__(self, /, *args):
        assert len(args) >= len(self.undetermined)
        if len(args) > len(self.undetermined):
            logger.warning('throw some args away when calling the function "%s"', self.func.__name__)

        for i, p in enumerate(self.undetermined):
            v = args[i]
            if isinstance(p, int):
                self.positional[p] = v
            else:
                self.keyword[p] = v

        for p, fs in self.special_treat.items():
            if isinstance(p, int):
                for f in fs:
                    self.positional[p] = f(self.positional[p])
            else:
                for f in fs:
                    self.keyword[p] = f(self.keyword[p])

        result = self.func(*self.positional, **self.keyword)
        if isinstance(result, int):
            # turn result of 'size' into tensor
            result = torch.as_tensor([result], dtype=torch.long)
        return result
```

Removed in this commit:

```python
def constant_python(node, speedup):
    """
    get the constant value of constant operator node.

    Parameters
    ----------
    node: torch._C.Node
        The cpp node of prim::Getattr
    speedup: ModelSpeedup
        The corresponding speedup object.
    """
    class ConstantModule(torch.nn.Module):
        def __init__(self, constant):
            super(ConstantModule, self).__init__()
            self.constant = constant

        def forward(self):
            return self.constant

    assert node.kind() == 'prim::Constant'
    pattern = '\[value=(.*?)\]'
    key_words = re.findall(pattern, str(node))
    if len(key_words) == 0:
        return ConstantModule(None)
    assert len(key_words) == 1
    # parse the constant value
    value = key_words[0]
    if value.startswith("\""):
        value = torch.device(value[1:-1])
    elif value.startswith('{'):
        # TODO Support set values in the future
        value = set()
    elif '.' in value:
        # float value
        value = float(value)
    else:
        # integer value
        value = int(value)
    return ConstantModule(value)

def upsample_bilinear2d_python(node, speedup):
    class UpsampleModule(torch.nn.Module):
        def __init__(self, size_list, scale_list):
            super(UpsampleModule, self).__init__()
            self.size_list = size_list
            self.scale_list = scale_list

        def forward(self, *args):
            """
            The first input of args is the target tensor to upsample
            , the following parameters is useless, because we already
            get the size_list and the scale_list by parsing the cpp_nodes.
            """
            return torch.nn.functional.upsample_bilinear(args[0], size=self.size_list,
                                                         scale_factor=self.scale_list)

    c_node = node.key_node
    inputs = list(c_node.inputs())
    size_list_node = inputs[1].node()
    scale_list_node = inputs[3].node()
    size_list = None
    scale_list = None
    if size_list_node.kind() == 'prim::ListConstruct':
        size_list = translate_list(inputs[1], speedup)
    if scale_list_node.kind() == 'prim::ListConstruct':
        scale_list = translate_list(inputs[3], speedup)
    return UpsampleModule(size_list, scale_list)
```
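A small usage sketch of the adapter: it holds the constant arguments recovered from the graph and drops each runtime tensor into its recorded slot before calling the wrapped torch function. The import path matches this file, but the concrete arguments below are illustrative only, not taken from a real traced graph:

```python
import torch
from nni.compression.pytorch.speedup.jit_translate import FuncAdapter  # assumed importable

# One undetermined positional slot (index 0) for the runtime tensor,
# a constant keyword argument, and no special post-processing.
adapter = FuncAdapter(torch.flatten,
                      positional=[None],
                      keyword={'start_dim': 1},
                      undetermined=[0],
                      special_treat={})

x = torch.randn(2, 3, 4)
print(adapter(x).shape)  # torch.Size([2, 12]) — torch.flatten(x, start_dim=1)
```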
Added in this commit (enum recovery helpers, schema parsing, and input translation):

```python
# There are some types that will be convert into enums after jit.
# So we should recover them back:
# device, dtype, layout, memory_format, qscheme, qengine, dispatchkey
enum_to_dtype_names = {
    0: 'uint8', 1: 'int8', 2: 'int16', 3: 'int32', 4: 'int64',
    5: 'float16', 6: 'float32', 7: 'float64', 8: 'complex32', 9: 'complex64',
    10: 'complex128', 11: 'bool', 12: 'qint8', 13: 'quint8', 14: 'qint32',
    15: 'bfloat16', 16: 'quint4x2', 17: 'quint2x4',
}
enum_to_dtype_dict = {}
for enum_value, dtype_name in enum_to_dtype_names.items():
    if hasattr(torch, dtype_name):
        enum_to_dtype_dict[enum_value] = getattr(torch, dtype_name)

def dtype_trans(ivalue: Union[int, torch.dtype]):
    """
    Special process for dtype.
    Torch will transform dtype to an enum in cpp, so the value of dtype we get in jit is an int.
    This function is used to recover the int to torch.dtype in python.

    Parameters
    ----------
    ivalue
        The value of dtype or method to be recovered.
    """
    if ivalue is None or isinstance(ivalue, torch.dtype):
        return ivalue
    elif isinstance(ivalue, int):
        if ivalue in enum_to_dtype_dict:
            return enum_to_dtype_dict[ivalue]
    raise TypeError('No torch.dtype corresponding to the value "%s"', ivalue)

enum_to_memory_format_dict = {
    0: torch.contiguous_format,
    1: torch.preserve_format,
    2: torch.channels_last,
    3: torch.channels_last_3d,
}

def memory_format_trans(ivalue: Union[int, torch.memory_format]):
    """
    Special process for memory_format.
    Torch will transform memory_format to an enum in cpp, so the value of memory_format we get in jit is an int.
    This function is used to recover the int to torch.memory_format in python.

    Parameters
    ----------
    ivalue
        The value of memory_format or method to be recovered.
    """
    if ivalue is None or isinstance(ivalue, torch.memory_format):
        return ivalue
    elif isinstance(ivalue, int):
        global enum_to_memory_format_dict
        if ivalue in enum_to_memory_format_dict:
            return enum_to_memory_format_dict[ivalue]
    raise TypeError('No torch.memory_format corresponding to the value "%s"', ivalue)

special_treat_dict = {
    'dtype': dtype_trans,
    'memory_format': memory_format_trans,
}

schema_fix_dict = {
    # functinon 'to', 'randint', and 'sparse_coo_tensor' has different schema between python and c++.
    # https://pytorch.org/docs/stable/jit_unsupported.html#ops-with-divergent-schemas-between-torch-python
    """aten::to.device(Tensor(a) self, Device device, int dtype, bool non_blocking=False, bool copy=False, int? memory_format=None) -> (Tensor(a))""":
        """aten::to.device(Tensor(a) self, Device device, int dtype, bool non_blocking=False, bool copy=False, *, int? memory_format=None) -> (Tensor(a))""",
    'aten::to.dtype(Tensor(a) self, int dtype, bool non_blocking=False, bool copy=False, int? memory_format=None) -> (Tensor(a))':
        'aten::to.dtype(Tensor(a) self, int dtype, bool non_blocking=False, bool copy=False, *, int? memory_format=None) -> (Tensor(a))',
    'aten::to.other(Tensor(a) self, Tensor other, bool non_blocking=False, bool copy=False, int? memory_format=None) -> (Tensor(a))':
        'aten::to.other(Tensor(a) self, Tensor other, bool non_blocking=False, bool copy=False, *, int? memory_format=None) -> (Tensor(a))',
    # todo: are the arguments 'pin_memory' and 'requires_grad' related?
    # functions in the python have only 'requires_grad' and functions in the aten have only 'pin_memory'
    # 'aten::randint(int high, int[] size, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None) -> (Tensor)',
    # """aten::randint.generator(int high, int[] size, *, Generator? generator, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None) -> (Tensor)""",
    # """aten::randint.low(int low, int high, int[] size, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None) -> (Tensor)""",
    # """aten::randint.low_generator(int low, int high, int[] size, *, Generator? generator, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None) -> (Tensor)""",
    # """aten::sparse_coo_tensor.size(int[] size, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=False) -> (Tensor)""",
    # """aten::sparse_coo_tensor.indices(Tensor indices, Tensor values, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None) -> (Tensor)""",
    # """aten::sparse_coo_tensor.indices_size(Tensor indices, Tensor values, int[] size, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None) -> (Tensor"""'
}

@lru_cache(maxsize=256)
def parse_aten_schema(schema: str):
    """
    Parse the schema, to positional_num and keyword_list, and detect if the argument should be specially treated.
    """
    if schema in schema_fix_dict:
        schema = schema_fix_dict[schema]

    positional_num = 0
    keyword_list = list()
    special_treat = dict()  # for dtype and memory_format trans now

    for arg in torch._C.parse_schema(schema).arguments:
        if not arg.kwarg_only:
            key = positional_num
            positional_num += 1
        else:
            key = arg.name
            keyword_list.append(key)

        if arg.name in special_treat_dict:
            if key not in special_treat:
                special_treat[key] = [special_treat_dict[arg.name]]
            else:
                special_treat[key].append(special_treat_dict[arg.name])

    return positional_num, keyword_list, special_treat

def parse_input_value(speedup: ModelSpeedup, input_nodes: List[torch._C.Node], positional_num: int, keyword_list: List[str]):
    """
    translate inputs, to constant positional arguments, constant keyword arguments, and undetermined positions
    """
    positional = list()
    keyword = dict()
    undetermined = list()

    for ainput in input_nodes:
        if ainput.node().kind() == 'prim::ListConstruct':
            arg = translate_list(ainput, speedup)
        elif ainput.node().kind() == 'prim::Constant':
            arg = ainput.toIValue()
        else:
            assert 'aten::' in ainput.node().kind() or 'prim::' in ainput.node().kind()
            if len(positional) < positional_num:
                undetermined.append(len(positional))
            else:
                undetermined.append(keyword_list[positional_num - len(positional)])
            arg = None

        if len(positional) < positional_num:
            positional.append(arg)
        else:
            keyword[keyword_list[positional_num - len(positional)]] = arg

    return positional, keyword, undetermined

def special_treat_to_constant_value(positional: List, keyword: Dict[str], undetermined: List[Union[int, str]],
                                    special_treat: Dict[Union[int, str], Callable]):
    """
    if any argument with special_treat is not in undetermined, do the treat
    """
    undetermined_special_treat = dict()
    for p, fs in special_treat.items():
        if p in undetermined:
            undetermined_special_treat[p] = fs
        elif isinstance(p, int):
            for f in fs:
                positional[p] = f(positional[p])
        else:
            for f in fs:
                keyword[p] = f(keyword[p])
    return undetermined_special_treat
```

Removed in this commit:

```python
def upsample_nearest2d_python(node, speedup):
    class UpsampleModule(torch.nn.Module):
        def __init__(self, size_list, scale_list):
            super(UpsampleModule, self).__init__()
            self.size_list = size_list
            self.scale_list = scale_list

        def forward(self, *args):
            """
            The first input of args is the target tensor to upsample
            , the following parameters is useless, because we already
            get the size_list and the scale_list by parsing the cpp_nodes.
            """
            return torch.nn.functional.upsample_nearest(args[0], size=self.size_list,
                                                        scale_factor=self.scale_list)

    c_node = node.key_node
    inputs = list(c_node.inputs())
    size_list_node = inputs[1].node()
    scale_list_node = inputs[2].node()
    size_list = None
    scale_list = None
    if size_list_node.kind() == 'prim::ListConstruct':
        size_list = translate_list(inputs[1], speedup)
    if scale_list_node.kind() == 'prim::ListConstruct':
        scale_list = translate_list(inputs[2], speedup)
    return UpsampleModule(size_list, scale_list)

def typeas_python(node, speedup):
    """
    currently only support type_as float.
    TODO: support more types in the type_as, need to figure out
    how to get the scalar type from torch._C.TensorType.
    """
    class TypeasModule(torch.nn.Module):
        def __init__(self, dtype=torch.float):
            self.example = torch.zeros(1, dtype=dtype)

        def forward(self, x):
            return x.type_as(self.example)
    return TypeasModule()

def to_python(node, speedup):
    # for the time being, only device parameters are supported
    class ToModule(torch.nn.Module):
        def __init__(self, device, dtype):
            super(ToModule, self).__init__()
            self.device = device
            self.dtype = dtype

        def forward(self, x):
            return x.to(device, dtype=self.dtype)

    c_node = node.key_node
    inputs = list(c_node.inputs())
    in_debugname = inputs[0].debugName()
    # device of the input tensor
    device = speedup.internal_result[in_debugname].device
    for _, _node in enumerate(inputs[1:]):
        val = parse_constant(_node, speedup)
        if isinstance(val, torch.device):
            device = val
    dtype = jitid_2_dtype[parse_constant(inputs[1], speedup)]
    return ToModule(device, dtype)

def cat_python(node, speedup):
    class CatModule(torch.nn.Module):
        def __init__(self, cat_dim):
            super(CatModule, self).__init__()
            self.cat_dim = cat_dim

        def forward(self, *args):
            return torch.cat(args, dim=self.cat_dim)

    c_node = node.key_node
    inputs = list(c_node.inputs())
    dim = inputs[1].toIValue()
    return CatModule(dim)
```
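The schema parsing relies on `torch._C.parse_schema`, the same entry point the new code uses; each parsed argument exposes `name` and `kwarg_only`, which is how `parse_aten_schema` splits positional slots from keyword slots. A small probe of that API (the schema string below mirrors the softmax overload and is only illustrative here):

```python
import torch

schema = torch._C.parse_schema(
    'aten::softmax.int(Tensor self, int dim, ScalarType? dtype=None) -> Tensor')

for arg in schema.arguments:
    print(arg.name, arg.kwarg_only)
# self False / dim False / dtype False -> all three become positional slots.
# Because 'dtype' appears in special_treat_dict, parse_aten_schema also records
# dtype_trans for that slot: jit passes dtype as an enum int (e.g. 6), and
# dtype_trans maps it back to torch.float32 via enum_to_dtype_dict.
```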
Added in this commit (the automatic node-to-callable conversion and the new, much smaller translator table):

```python
def generate_aten_to_python(func: Callable, node: NodePyGroup, speedup: ModelSpeedup) -> FuncAdapter:
    """
    parse a node, return a callable object to inference the mask according to the node.op_type.

    Parameters
    ---------
    func
        The torch function one-to-one correspondence with the node.
    node
        The target node to inference the mask
    speedup
        The speedup object of the target model.

    Returns
    ------
    func
        Return the translated function that used to inference the mask
        , if current op_type is not supported, then we return None.
    """
    c_node = node.key_node

    schema = c_node.schema()
    positional_num, keyword_list, special_treat = parse_aten_schema(schema)

    input_nodes = list(c_node.inputs())
    positional, keyword, undetermined = parse_input_value(speedup, input_nodes, positional_num, keyword_list)

    undetermined_special_treat = special_treat_to_constant_value(positional, keyword, undetermined, special_treat)

    return FuncAdapter(func, positional, keyword, undetermined, undetermined_special_treat)

trans_func_dict = {
    'aten::slice': slice_python,
    'aten::cat': cat_python,
    'aten::Int': partial(generate_aten_to_python, torch._C._TensorBase.int),
    'prim::TupleUnpack': tupleunpack_python,
    'prim::ListUnpack': tupleunpack_python,
    'prim::NumToTensor': num2tensor_python,
    'prim::GetAttr': getattr_python,
}
```

Removed in this commit:

```python
def ones_python(node, speedup):
    class OnesModule(torch.nn.Module):
        def __init__(self, out_size, dtype_id, device, require_grad):
            super(OnesModule, self).__init__()
            self.out_size = out_size
            self.device = device
            self.require_grad = require_grad
            self.dtype = jitid_2_dtype[dtype_id]

        def forward(self, *args):
            return torch.ones(size=self.out_size, dtype=self.dtype, device=self.device,
                              requires_grad=self.require_grad)

    c_node = node.key_node
    inputs = list(c_node.inputs())
    output_shape = translate_list(inputs[0], speedup)
    dtype_id = parse_constant(inputs[1], speedup)
    # layout = parse_constant(inputs[2], speedup)
    device = parse_constant(inputs[3], speedup)
    require_grad = parse_constant(inputs[4], speedup)
    return OnesModule(output_shape, dtype_id, device, require_grad)

def zeros_python(node, speedup):
    class ZerosModule(torch.nn.Module):
        def __init__(self, out_size, dtype_id, device, require_grad):
            super(ZerosModule, self).__init__()
            self.out_size = out_size
            self.device = device
            self.require_grad = require_grad
            self.dtype = jitid_2_dtype[dtype_id]

        def forward(self, *args):
            return torch.zeros(size=self.out_size, dtype=self.dtype, device=self.device,
                               requires_grad=self.require_grad)

    c_node = node.key_node
    inputs = list(c_node.inputs())
    output_shape = translate_list(inputs[0], speedup)
    dtype_id = parse_constant(inputs[1], speedup)
    # layout = parse_constant(inputs[2], speedup)
    device = parse_constant(inputs[3], speedup)
    require_grad = parse_constant(inputs[4], speedup)
    return ZerosModule(output_shape, dtype_id, device, require_grad)

def rsub_python(node, speedup):
    c_node = node.key_node
    inputs = list(c_node.inputs())
    constant = None
    other_name = inputs[1].debugName()
    alpha = parse_constant(inputs[2], speedup)
    if other_name not in speedup.internal_result:
        constant = parse_constant(inputs[1], speedup)
    if constant is None:
        return torch.sub()
    else:
        new_sub = partial(torch.sub, other=constant, alpha=alpha)
        return new_sub

def expand_python(node, speedup):
    class ExpandModule(torch.nn.Module):
        def __init__(self, new_size):
            super(ExpandModule, self).__init__()
            # need deepcopy when the input is size-related
            self.new_size = copy.deepcopy(new_size)

        def forward(self, *args):
            return args[0].expand(self.new_size).clone()

    c_node = node.key_node
    inputs = list(c_node.inputs())
    new_size = translate_list(inputs[1], speedup)
    return ExpandModule(new_size)

def expandas_python(node, speedup):
    class ExpandasModule(torch.nn.Module):
        def forward(self, x, y):
            return x.expand_as(y).clone()
    return ExpandasModule()

trans_from_jit_to_python = {
    'aten::add': add_python,
    'aten::add_': add_python,
    'aten::sub': sub_python,
    'aten::sub_': sub_python,
    'aten::mul': mul_python,
    'aten::mul_': mul_python,
    'aten::relu': relu_python,
    'aten::relu_': relu_inplace_python,
    'aten::sigmoid': sigmoid_python,
    'aten::sigmoid_': sigmoid_python,
    # tanh behaives like relu
    'aten::tanh': relu_python,
    'aten::tanh_': relu_python,
    'aten::flatten': flatten_python,
    'aten::mean': mean_python,
    'aten::dropout': dropout_python,
    'aten::slice': slice_python,
    'aten::select': select_python,
    'aten::size': size_python,
    'aten::t': transpose_python,
    'aten::transpose': transpose2_python,
    'aten::Int': toint_python,
    'aten::view': view_python,
    'aten::reshape': reshape_python,
    'aten::permute': permute_python,
    'aten::matmul': matmul_python,
    'aten::div': div_python,
    'aten::floor_divide': floor_div_python,
    'aten::softmax': softmax_python,
    'aten::contiguous': contiguous_python,
    'aten::gelu': gelu_python,
    'aten::cat': cat_python,
    'aten::avg_pool2d': avgpool2d_python,
    'aten::max_pool2d': avgpool2d_python,
    'aten::adaptive_avg_pool2d': adaptive_avgpool_python,
    'aten::to': to_python,
    'aten::type_as': typeas_python,
    'aten::upsample_bilinear2d': upsample_bilinear2d_python,
    'aten::upsample_nearest2d': upsample_nearest2d_python,
    'aten::exp': exp_python,
    'aten::squeeze': squeeze_python,
    'aten::unsqueeze': unsqueeze_python,
    'aten::ones': ones_python,
    'aten::zeros': zeros_python,
    'aten::rsub': rsub_python,
    'aten::expand': expand_python,
    'prim::TupleUnpack': tupleunpack_python,
    'prim::ListUnpack': tupleunpack_python,
    'prim::NumToTensor': num2tensor_python,
    'prim::GetAttr': getattr_python,
    'prim::Constant': constant_python,
    'aten::constant_pad_nd': constant_pad_nd_python,
    'aten::silu': silu_python,
    'aten::expand_as': expandas_python
}
```
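Entries such as `'aten::Int': partial(generate_aten_to_python, torch._C._TensorBase.int)` show the registration shape that `init_add_functions` (below) mass-produces: the torch callable is bound in advance, and the `(node, speedup)` pair is supplied later during the speedup pass. A hedged sketch of that shape; `node` and `speedup` are hypothetical placeholders, not constructed here:

```python
from functools import partial
import torch
from nni.compression.pytorch.speedup.jit_translate import generate_aten_to_python  # assumed importable

flatten_entry = partial(generate_aten_to_python, torch.flatten)
# Later, inside the speedup pass, for a matching aten::flatten node:
#     adapter = flatten_entry(node, speedup)   # -> FuncAdapter wrapping torch.flatten
#     out = adapter(activation_tensor)         # constants recovered from the graph are filled in
```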
Added in this commit (bulk registration of torch callables as automatic translators):

```python
def init_add_functions(func_from: Union[ModuleType, Type[Any]]):
    """
    Add function/method attributes from a module/class, to the trans_func_dict

    Parameters
    ---------
    func_from
        The module/class include needed functions
    """
    global trans_func_dict
    new_trans_func_dict = dict()
    for name in dir(func_from):
        attr = getattr(func_from, name)
        if callable(attr) and not name.startswith('__'):
            new_trans_func_dict['aten::' + name] = partial(generate_aten_to_python, attr)
    trans_func_dict = {**new_trans_func_dict, **trans_func_dict}

init_add_functions(torch._C._VariableFunctions)
init_add_functions(torch._C._nn)
init_add_functions(torch._C._TensorBase)
```

```diff
-def jit_to_python_function(node, speedup):
+def jit_to_python_function(node: NodePyGroup, speedup: ModelSpeedup) -> FuncAdapter:
     """
-    Return a callable object to inference the mask according to the
-    node.op_type.
+    Return a callable object to inference the mask according to the node.op_type.

     Parameters
     ---------
-    node: NodeGroup
+    node
         The target node to inference the mask
-    speedup: ModelSpeedup
+    speedup
         The speedup object of the target model.

     Returns
     ------
-    func: callable object(nn.Module/function)
+    func
         Return the translated function that used to inference the mask
         , if current op_type is not supported, then we return None.
     """
     logger.debug('Translate C function %s into its python version', node.op_type)
-    if node.op_type not in trans_from_jit_to_python:
+    if node.op_type not in trans_func_dict:
         logger.error('%s is not Supported! Please report an issue at https://github.com/microsoft/nni. Thanks~', node.op_type)
         # return None to skip the mask inference for this node
         return None
-    return trans_from_jit_to_python[node.op_type](node, speedup)
+    return trans_func_dict[node.op_type](node, speedup)
```
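`init_add_functions` merges the automatically generated entries with the hand-written ones, and because `trans_func_dict` is unpacked last, a manual translator always wins over an auto-generated one for the same op. A minimal illustration of that merge rule with placeholder values:

```python
auto_generated = {'aten::cat': 'auto entry', 'aten::flatten': 'auto entry'}
hand_written = {'aten::cat': 'manual entry'}

merged = {**auto_generated, **hand_written}
print(merged['aten::cat'])      # 'manual entry' — the later unpacking wins
print(merged['aten::flatten'])  # 'auto entry'
```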
test/algo/compression/v1/test_model_speedup.py (view file @ 588f299b)

```diff
@@ -61,7 +61,7 @@ class BackboneModel2(torch.nn.Module):
         x = F.max_pool2d(x, 2, 2)
         x = F.relu(self.bn2(self.conv2(x)))
         x = F.max_pool2d(x, 2, 2)
-        x = x.view(x.size(0), -1)
+        x = x.reshape(x.size(0), -1)
         x = F.relu(self.fc1(x))
         x = self.fc2(x)
```
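The likely motivation for switching this line from `view` to `reshape`: `view` requires a contiguous layout, while `reshape` falls back to a copy when needed, so it tolerates whatever memory format the rewritten graph produces. A quick illustration:

```python
import torch

t = torch.randn(2, 3).t()     # transposed, therefore non-contiguous
print(t.reshape(-1).shape)    # works: reshape copies when it has to
# t.view(-1)                  # would raise a RuntimeError on this non-contiguous tensor
```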
test/algo/compression/v2/test_auto_conv.py (new file, 0 → 100644, view file @ 588f299b)

```python
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import unittest

import torch
import torch.nn.functional as F

from nni.compression.pytorch.pruning import L1NormPruner
from nni.compression.pytorch.speedup import ModelSpeedup
from nni.algorithms.compression.v2.pytorch.utils import (compute_sparsity_compact2origin,
                                                         compute_sparsity_mask2compact)


class CondModel(torch.nn.Module):
    """
    test for:
        prim::If
    """
    the_cond: bool

    def __init__(self):
        super().__init__()
        self.the_cond = True

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.the_cond:
            x = x + 0.00001
        else:
            x = x - 0.00001
        self.the_cond = not self.the_cond
        return x


class ASubModel(torch.nn.Module):
    """
    test for:
        sub model
    """
    def __init__(self):
        super().__init__()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + 0.00001
        return x


class TorchModel1(torch.nn.Module):
    """
    test for:
        add, sub, mul, div, exp, matmul,
        relu, gelu, tanh, silu, sigmod, softmax,
        size, unsqueeze, flatten, cat, slice, reshape, transpose, t, select, permute, constant_pad_nd,
        mean, avg_pool2d, max_pool2d, sum, adaptive_avg_pool2d,
        to, Int, view,
        type_as, expand_as, contiguous,
    notes:
        'floor_divide' have no backward, then not be tested
    """
    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(1, 6, 5, 1)
        self.conv2 = torch.nn.Conv2d(6, 16, 5, 1)
        self.fccond = torch.nn.Linear(16 * 4 * 4, 16 * 4 * 4)
        self.fc1 = torch.nn.Linear(16 * 4 * 4, 120)
        self.fc2 = torch.nn.Linear(120, 84)
        self.fc3 = torch.nn.Linear(84, 10)
        self.pool1 = torch.nn.MaxPool2d((2, 2))
        self.pool2 = torch.nn.MaxPool2d((2, 2))
        self.cond = torch.jit.script(CondModel())
        self.asub = ASubModel()

    def forward(self, x: torch.Tensor):
        x = x.contiguous(memory_format=torch.channels_last)
        x = torch._C._nn.upsample_bilinear2d(x, (28, 28), False)
        x = torch._C._nn.upsample_nearest2d(x, (28, 28))
        x = F.adaptive_avg_pool2d(x, (28, 28))
        x = torch.exp(x)
        x = torch.sigmoid(x)
        x = torch.transpose(x, 1, 2)
        x = torch.transpose(x, 1, 2)
        x = F.avg_pool2d(x, 3, 1, padding=1)
        x = F.max_pool2d(x, 3, 1, padding=1)
        x = x.to(torch.float32)
        x = self.conv1(x)
        y1 = self.pool1(F.relu(x))
        y2 = self.pool1(F.gelu(x))
        x = y1 + y2
        x = x + 0.00001
        x = x * 1.00001
        x = self.conv2(x)
        y1 = self.pool2(F.silu(x))
        y2 = self.pool2(torch.tanh(x))
        x = y1 - y2
        x = x - 0.00001
        x = x / 1.00001
        x = torch.permute(x, (0, 2, 3, 1))
        x = torch.permute(x, (0, 2, 3, 1))
        x = torch.permute(x, (0, 2, 3, 1))
        x = torch.unsqueeze(x, dim=1)
        x = torch.select(x, dim=1, index=0)
        x = torch.unsqueeze(x, dim=1)
        x = torch.mean(x, dim=1)
        x = torch.unsqueeze(x, dim=1)
        x = torch.sum(x, dim=1, dtype=torch.float32)
        x = torch.unsqueeze(x, dim=1)
        x = torch.squeeze(x, dim=1)
        x = torch.flatten(x, 1)
        x = x.reshape(x.shape)
        x = x.view(-1, x.size(1))
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.softmax(self.fc3(x), dim=1)
        y1 = x[:, 0:int(x.size(1) / 2)]
        y2 = x[:, int(x.size(1) / 2):x.size(1)]
        x = torch.cat((y1, y2), dim=1)
        x = x.type_as(x)
        x = x.expand_as(x)
        x = torch.matmul(x, x.t())
        x = torch.cat([x, x], dim=1)
        # x = self.cond(x)
        x = self.asub(x)
        x = torch.constant_pad_nd(x, (1, 1, 1, 1), 3.14159)
        return x


class AutoConvTestCase(unittest.TestCase):
    def test_l1norm_pruner(self):
        model = TorchModel1()
        dummy_input = torch.rand(3, 1, 28, 28)
        config_list = [{'op_types': ['Conv2d'], 'sparsity': 0.5}]
        pruner = L1NormPruner(model=model, config_list=config_list)
        pruned_model, masks = pruner.compress()
        pruner._unwrap_model()
        sparsity_list = compute_sparsity_mask2compact(pruned_model, masks, config_list)
        ModelSpeedup(model, dummy_input, masks).speedup_model()
        real_sparsity_list = compute_sparsity_compact2origin(TorchModel1(), model, config_list)

        print('sparsity_list:', sparsity_list)
        assert 0.45 < sparsity_list[0]['total_sparsity'] < 0.55
        print('real_sparsity_list:', real_sparsity_list)
        assert 0.45 < real_sparsity_list[0]['total_sparsity'] < 0.75
        print('the shape of output of the infer:', model(dummy_input).shape)
        assert model(dummy_input).shape == torch.Size((5, 8))


if __name__ == '__main__':
    unittest.main()
```
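The new test exercises the full prune-then-speedup flow on a model that covers most of the ops now handled automatically. A condensed, hedged sketch of the same flow on a made-up model (the Sequential below is illustrative, not the test's model):

```python
import torch
from nni.compression.pytorch.pruning import L1NormPruner
from nni.compression.pytorch.speedup import ModelSpeedup

model = torch.nn.Sequential(torch.nn.Conv2d(1, 8, 3), torch.nn.ReLU(), torch.nn.Conv2d(8, 4, 3))
dummy_input = torch.rand(1, 1, 28, 28)

pruner = L1NormPruner(model, [{'op_types': ['Conv2d'], 'sparsity': 0.5}])
_, masks = pruner.compress()
pruner._unwrap_model()

# speedup_model() is where jit_translate.py converts each jit op into a callable
# so masks can be propagated through the graph.
ModelSpeedup(model, dummy_input, masks).speedup_model()
print(model(dummy_input).shape)
```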