OpenDAS / nni / Commits / 588f299b

Unverified commit 588f299b, authored Aug 16, 2022 by Louis-J, committed by GitHub on Aug 16, 2022

feat(speedup): automatically convert op asts to callables (#4996)

Parent: b2c31ca2
Changes: 3
Showing 3 changed files with 481 additions and 595 deletions (+481 -595):

    nni/compression/pytorch/speedup/jit_translate.py    +315  -594
    test/algo/compression/v1/test_model_speedup.py         +1    -1
    test/algo/compression/v2/test_auto_conv.py           +165    -0
nni/compression/pytorch/speedup/jit_translate.py  (view file @ 588f299b)
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from __future__ import annotations
from types import ModuleType
from typing import Any, Callable, Dict, List, Optional, Type, Union
from typing import TYPE_CHECKING
if TYPE_CHECKING:
    # Only imports the below statements during type checking
    from nni.compression.pytorch.speedup import ModelSpeedup
    from nni.common.graph_utils import NodePyGroup

import re
import logging
from functools import partial
from functools import partial, lru_cache
import copy

import torch

...
@@ -15,31 +24,24 @@ jitid_2_dtype = {4: torch.long, 6:torch.float32}

# to exclude partial
__all__ = [
    'adaptive_avgpool_python', 'add_python', 'avgpool2d_python', 'cat_python', 'contiguous_python',
    'div_python', 'dropout_python', 'exp_python', 'flatten_python', 'floor_div_python', 'gelu_python',
    'getattr_python', 'jit_to_python_function', 'matmul_python', 'mean_python', 'mul_python',
    'num2tensor_python', 'parse_constant', 'permute_python', 'relu_inplace_python', 'relu_python',
    'reshape_python', 'select_python', 'sigmoid_python', 'size_python', 'slice_python', 'softmax_python',
    'squeeze_python', 'to_python', 'toint_python', 'torch', 'trans_from_jit_to_python', 'translate_list',
    'transpose2_python', 'transpose_python', 'tupleunpack_python', 'typeas_python', 'unsqueeze_python',
    'upsample_bilinear2d_python', 'view_python'
]
__all__ = [
    'getattr_python', 'jit_to_python_function', 'num2tensor_python', 'parse_constant', 'slice_python',
    'translate_list', 'tupleunpack_python', 'dtype_trans', 'memory_format_trans'
]
def translate_list(list_node, speedup=None):
def translate_list(list_node: torch._C.Value, speedup: ModelSpeedup = None) -> List:
    """
    Get the list of values from the list construct node.

    Parameters
    ----------
    list_node
        The cpp node of the target list.
    speedup
        The Module speedup module.

    Returns
    -------
    values
        The list of values in the target cpp list node.
    """
    # the node that create the list
    ...

@@ -52,27 +54,26 @@ def translate_list(list_node, speedup=None):
        if speedup is not None and debugName in speedup.internal_result:
            # this value is the result of the other nodes, such as
            # ate::size
            values.append(speedup.internal_result[debugName].item())
            values.append(speedup.internal_result[debugName])
        else:
            # if the corresponding value is a constant
            values.append(_i.toIValue())
    return values

def parse_constant(cvalue, speedup):
def parse_constant(cvalue: torch._C.Value, speedup: ModelSpeedup) -> Any:
    """
    Parse the constant values from this Node

    Parameters
    ----------
    cvalue
        The cpp node of the target constant value.
    speedup
        The Model speedup module.

    Returns
    -------
    value
        The constant values parsed from the node.
    """
    logger.debug('Try to parse the constant value: %s', cvalue.debugName())
    ...

@@ -85,245 +86,13 @@ def parse_constant(cvalue, speedup):
    inputs = op_node.inputs()
    input_values = [parse_constant(_i, speedup) for _i in inputs]
    func = trans_from_jit_to_python[op_node.kind()](op_node, speedup)
    return func(*input_values)
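parse_constant follows a small recursive-evaluation pattern: if the jit value is not a literal, it translates the producer node into a callable and applies it to the (recursively parsed) inputs. A minimal standalone sketch of that pattern, kept outside NNI on purpose (the ConstNode class below is hypothetical, not part of the library):

    from dataclasses import dataclass, field
    from typing import Callable, List, Optional
    import operator

    @dataclass
    class ConstNode:
        value: Optional[float] = None                    # set for literal leaves
        op: Optional[Callable] = None                    # set for computed nodes
        inputs: List['ConstNode'] = field(default_factory=list)

    def eval_const(node: ConstNode):
        if node.op is None:                              # leaf: a literal constant
            return node.value
        args = [eval_const(i) for i in node.inputs]      # recurse into producer nodes
        return node.op(*args)                            # run the producing op on the parsed inputs

    # (2 + 3) * 4 evaluated from its constant expression tree
    tree = ConstNode(op=operator.mul,
                     inputs=[ConstNode(op=operator.add,
                                       inputs=[ConstNode(2.0), ConstNode(3.0)]),
                             ConstNode(4.0)])
    assert eval_const(tree) == 20.0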
def dropout_python(node, speedup):
    return torch.dropout

def flatten_python(node, speedup):
    c_node = node.key_node
    inputs = list(c_node.inputs())
    start_dim = inputs[1].toIValue()
    end_dim = inputs[2].toIValue()
    new_flatten = partial(torch.flatten, start_dim=start_dim, end_dim=end_dim)
    return new_flatten

def relu_inplace_python(node, speedup):
    return torch.relu_

def relu_python(node, speedup):
    return torch.relu

def sigmoid_python(node, speedup):
    return torch.sigmoid

def mean_python(node, speedup):
    c_node = node.key_node
    inputs = list(c_node.inputs())
    dim_list = translate_list(inputs[1], speedup)
    keep_dim = inputs[2].toIValue()
    new_mean = partial(torch.mean, dim=tuple(dim_list), keepdim=keep_dim)
    return new_mean

def add_python(node, speedup):
    c_node = node.key_node
    inputs = list(c_node.inputs())
    constant = None
    for i in range(2):
        input_i = inputs[i]
        debug_name = input_i.debugName()
        if debug_name not in speedup.internal_result:
            # this input is a constant value
            # TODO: what if this input is a constant tensor
            if input_i.toIValue() is not None:
                constant = parse_constant(input_i, speedup)
                break
    if constant is None:
        return torch.add
    else:
        new_add = partial(torch.add, constant)
        return new_add

def sub_python(node, speedup):
    c_node = node.key_node
    inputs = list(c_node.inputs())
    constant = [None, None]
    for i in range(2):
        input_i = inputs[i]
        debug_name = input_i.debugName()
        if debug_name not in speedup.internal_result:
            # this input is a constant value
            # TODO: what if this input is a constant tensor
            if input_i.toIValue() is not None:
                constant[i] = parse_constant(input_i, speedup)
                break
    if constant[0] is None and constant[1] is None:
        new_sub = torch.sub
    elif constant[0] is not None:
        new_sub = partial(torch.sub, input=constant)
    else:
        new_sub = partial(torch.sub, other=constant)
    return new_sub

def floor_div_python(node, speedup):
    c_node = node.key_node
    inputs = list(c_node.inputs())
    divisor = inputs[1]
    constant = None
    if divisor.debugName() not in speedup.internal_result:
        # divisor is a constant value/tensor
        constant = parse_constant(divisor, speedup)
    if constant is None:
        return torch.floor_divide
    else:
        new_op = partial(torch.floor_divide, other=constant)
        return new_op

def mul_python(node, speedup):
    c_node = node.key_node
    inputs = list(c_node.inputs())
    constant = None
    for i in range(2):
        input_i = inputs[i]
        debug_name = input_i.debugName()
        if debug_name not in speedup.internal_result:
            constant = parse_constant(input_i, speedup)
            # both two inputs cannot be constants at the same time
            break
    if constant is None:
        return torch.mul
    else:
        new_mul = partial(torch.mul, constant)
        return new_mul

def transpose_python(node, speedup):
    return torch.t

def transpose2_python(node, speedup):
    c_node = node.key_node
    inputs = list(c_node.inputs())
    dim_1 = inputs[1].toIValue()
    dim_2 = inputs[2].toIValue()
    new_transpose = partial(torch.transpose, dim0=dim_1, dim1=dim_2)
    return new_transpose

def matmul_python(node, speedup):
    return torch.matmul

def div_python(node, speedup):
    # The second input parameter of torch.div can be a
    # tensor or a constant, if it is a constant, we need
    # to return
    c_node = node.key_node
    inputs = list(c_node.inputs())
    if inputs[1].debugName() in speedup.internal_result:
        # the second input parameters is the output of the other
        # nodes
        return torch.div
    else:
        other = inputs[1].toIValue()
        new_div = partial(torch.div, other=other)
        return new_div

def softmax_python(node, speedup):
    c_node = node.key_node
    inputs = list(c_node.inputs())
    dim = inputs[1].toIValue()
    new_softmax = partial(torch.softmax, dim=dim)
    return new_softmax

def contiguous_python(node, speedup):
    class contiguousModule(torch.nn.Module):
        def forward(self, x):
            return x.contiguous().clone()
    return contiguousModule()

def gelu_python(node, speedup):
    return torch.nn.GELU()

def silu_python(node, speedup):
    return torch.nn.SiLU()

def avgpool2d_python(node, speedup):
    c_node = node.key_node
    inputs = list(c_node.inputs())
    kernel_size = translate_list(inputs[1], speedup)
    stride = translate_list(inputs[2], speedup)
    padding = translate_list(inputs[3], speedup)
    new_avgpool = partial(torch.nn.functional.avg_pool2d, kernel_size=kernel_size,
                          stride=stride, padding=padding)
    return new_avgpool

def adaptive_avgpool_python(node, speedup):
    c_node = node.key_node
    inputs = list(c_node.inputs())
    output_size = translate_list(inputs[1], speedup)
    new_avgpool = torch.nn.AdaptiveAvgPool2d(output_size)
    return new_avgpool

def tupleunpack_python(node, speedup):
    # Note: tuple unpack should only exists at the
    # the end of the model, and is no need to replace/propagate mask
    return None

def num2tensor_python(node, speedup):
    return torch.nn.Identity()

def exp_python(node, speedup):
    return torch.exp

def squeeze_python(node, speedup):
    c_node = node.key_node
    inputs = list(c_node.inputs())
    dim = None
    if len(inputs) > 1:
        dim = parse_constant(inputs[1], speedup)
    new_squeeze = partial(torch.squeeze, dim=dim)
    return new_squeeze

def unsqueeze_python(node, speedup):
    c_node = node.key_node
    inputs = list(c_node.inputs())
    dim = parse_constant(inputs[1], speedup)
    new_unsqueeze = partial(torch.unsqueeze, dim=dim)
    return new_unsqueeze

def constant_pad_nd_python(node, speedup):
    c_node = node.key_node
    inputs = list(c_node.inputs())
    pad = translate_list(inputs[1], speedup)
    value = parse_constant(inputs[2], speedup)
    new_constant_pad_nd = partial(torch.nn.functional.pad, pad=pad, value=value)
    return new_constant_pad_nd

##########################################################
# Split Line
# Following module/functions cannot be translated into a
# single function, so we use torch.nn.Module to wrap the
# the core function, and return the torch.nn.Module instead
##########################################################

    if op_node.kind() not in trans_func_dict:
        raise RuntimeError('Unsupported function op node type: {}'.format(op_node.kind()))
    func = trans_func_dict[op_node.kind()](op_node, speedup)
    return func(*input_values)
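All of the removed per-op translators above follow one pattern: read the constant arguments off the TorchScript node, then bind them onto the matching torch function with functools.partial so the result can be called with only the runtime tensors. A minimal standalone sketch of that binding step (plain PyTorch, no NNI objects):

    import torch
    from functools import partial

    # constants that a translator would have read from the jit node
    start_dim, end_dim = 1, -1
    new_flatten = partial(torch.flatten, start_dim=start_dim, end_dim=end_dim)

    x = torch.randn(2, 3, 4)
    assert new_flatten(x).shape == (2, 12)   # the callable now only needs the tensor input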
def slice_python(node, speedup):
def slice_python(node: NodePyGroup, speedup: ModelSpeedup):
    class SliceMoudle(torch.nn.Module):
        def __init__(self, sliceobj):
            super(SliceMoudle, self).__init__()
            ...

@@ -368,102 +137,38 @@ def slice_python(node, speedup):
    else:
        return SliceMoudle(tuple(slice_list))

def select_python(node, speedup):
    class SelectModule(torch.nn.Module):
        def __init__(self, dim, index):
            super(SelectModule, self).__init__()
            self.dim = copy.deepcopy(dim)
            self.index = copy.deepcopy(index)
        def forward(self, x):
            return x.select(self.dim, self.index)
    c_node = node.key_node
    inputs = list(c_node.inputs())
    dim = inputs[1].toIValue()
    index = inputs[2].toIValue()
    return SelectModule(dim, index)

def size_python(node, speedup):
    # return None
    class SizeMoudle(torch.nn.Module):
        def __init__(self, sizedim):
            super(SizeMoudle, self).__init__()
            self.sizedim = sizedim
        def forward(self, x):
            return torch.as_tensor([x.size(self.sizedim)], dtype=torch.long)
            # return torch.tensor(x.size(self.sizedim))
    c_node = node.key_node
    inputs = list(c_node.inputs())
    size_dim = inputs[1].toIValue()
    return SizeMoudle(size_dim)

def toint_python(node, speedup):
    class ToIntModule(torch.nn.Module):
        def forward(self, x):
            return x.to(torch.int)
    return ToIntModule()

def view_python(node, speedup):
    class ViewModule(torch.nn.Module):
        def __init__(self, shape):
            super(ViewModule, self).__init__()
            self.shape = shape
            logger.info('View Module output size: %s', str(self.shape))
        def forward(self, *args):
            return args[0].view(self.shape)
    c_node = node.key_node
    inputs = list(c_node.inputs())
    shape = translate_list(inputs[1], speedup)
    return ViewModule(shape)

def reshape_python(node, speedup):
    class ReshapeModule(torch.nn.Module):
        def __init__(self, shape):
            super(ReshapeModule, self).__init__()
            self.shape = shape
            logger.info('Reshape Module output size: %s', str(self.shape))
        def forward(self, *args):
            return args[0].reshape(self.shape)
    c_node = node.key_node
    inputs = list(c_node.inputs())
    shape = translate_list(inputs[1], speedup)
    return ReshapeModule(shape)

def permute_python(node, speedup):
    class PermuteModule(torch.nn.Module):
        def __init__(self, dimlist):
            super(PermuteModule, self).__init__()
            # deepcopy the values here, because the following randomize operation
            # will change the value of the dimlist
            self.dimlist = copy.deepcopy(dimlist)
        def forward(self, x):
            return x.permute(self.dimlist)
    c_node = node.key_node
    inputs = list(c_node.inputs())
    dim_list = translate_list(inputs[1], speedup)
    return PermuteModule(dim_list)

def cat_python(node: NodePyGroup, speedup: ModelSpeedup):
    class CatModule(torch.nn.Module):
        def __init__(self, cat_dim):
            super(CatModule, self).__init__()
            self.cat_dim = cat_dim
        def forward(self, *args):
            return torch.cat(args, dim=self.cat_dim)
    c_node = node.key_node
    inputs = list(c_node.inputs())
    dim = inputs[1].toIValue()
    return CatModule(dim)

def tupleunpack_python(_node: NodePyGroup, _speedup: ModelSpeedup) -> Optional[Callable]:
    # Note: tuple unpack should only exists at the
    # the end of the model, and is no need to replace/propagate mask
    return None

def num2tensor_python(_node: NodePyGroup, _speedup: ModelSpeedup):
    return torch.nn.Identity()

def getattr_python(node, speedup):
def getattr_python(node: NodePyGroup, _speedup: ModelSpeedup):
    """
    Note: Ops started with Prim:: is not taken as the key node,
    so we directly pass the Cpp node into this funciton.

    Parameters
    ----------
    node
        The cpp node of prim::Getattr
    speedup
        The corresponding speedup object.
    """
    class GetModule(torch.nn.Module):
        ...

@@ -482,316 +187,332 @@ def getattr_python(node, speedup):
    assert len(key_words) == 1
    return GetModule(key_words[0])
def constant_python(node, speedup):
    """
    get the constant value of constant operator node.

    Parameters
    ----------
    node: torch._C.Node
        The cpp node of prim::Getattr
    speedup: ModelSpeedup
        The corresponding speedup object.
    """
    class ConstantModule(torch.nn.Module):
        def __init__(self, constant):
            super(ConstantModule, self).__init__()
            self.constant = constant
        def forward(self):
            return self.constant
    assert node.kind() == 'prim::Constant'
    pattern = '\[value=(.*?)\]'
    key_words = re.findall(pattern, str(node))
    if len(key_words) == 0:
        return ConstantModule(None)
    assert len(key_words) == 1
    # parse the constant value
    value = key_words[0]
    if value.startswith("\""):
        value = torch.device(value[1:-1])
    elif value.startswith('{'):
        # TODO Support set values in the future
        value = set()
    elif '.' in value:
        # float value
        value = float(value)
    else:
        # integer value
        value = int(value)
    return ConstantModule(value)

class FuncAdapter:
    """
    A function adapter which can reorder arguments.
    It can be initialate with constant argument, and positions of each non-constant
    argument. When called, it can put arguments into correct position, then call the
    function.

    Attributes
    ----------
    func
        The function or method to be called.
    positional
        Positional arguments values. The placeholder is None if it's non-constant.
    keyword
        Keyword arguments values. The placeholder is None if it's non-constant.
    undetermined
        A list of the right positions of arguments.
        Position is an int in positional or a str in keyword.
    special_treat
        A Dict of the positions and methods.
        The values of these positions should be treat by those methods.
    """

    def __init__(self, func: Callable, positional: List[Any], keyword: Dict[str, Any],
                 undetermined: List[Union[int, str]], special_treat: Dict[Union[int, str], Callable]):
        if not callable(func):
            raise TypeError('the "func" argument must be callable')
        self.func = func
        self.positional = positional
        self.keyword = keyword
        self.undetermined = undetermined
        self.special_treat = special_treat

    def __call__(self, /, *args):
        assert len(args) >= len(self.undetermined)
        if len(args) > len(self.undetermined):
            logger.warning('throw some args away when calling the function "%s"', self.func.__name__)
        for i, p in enumerate(self.undetermined):
            v = args[i]
            if isinstance(p, int):
                self.positional[p] = v
            else:
                self.keyword[p] = v
        for p, fs in self.special_treat.items():
            if isinstance(p, int):
                for f in fs:
                    self.positional[p] = f(self.positional[p])
            else:
                for f in fs:
                    self.keyword[p] = f(self.keyword[p])
        result = self.func(*self.positional, **self.keyword)
        if isinstance(result, int):
            # turn result of 'size' into tensor
            result = torch.as_tensor([result], dtype=torch.long)
        return result

def upsample_bilinear2d_python(node, speedup):
    class UpsampleModule(torch.nn.Module):
        def __init__(self, size_list, scale_list):
            super(UpsampleModule, self).__init__()
            self.size_list = size_list
            self.scale_list = scale_list
        def forward(self, *args):
            """
            The first input of args is the target tensor to upsample
            , the following parameters is useless, because we already
            get the size_list and the scale_list by parsing the cpp_nodes.
            """
            return torch.nn.functional.upsample_bilinear(args[0], size=self.size_list,
                                                         scale_factor=self.scale_list)
    c_node = node.key_node
    inputs = list(c_node.inputs())
    size_list_node = inputs[1].node()
    scale_list_node = inputs[3].node()
    size_list = None
    scale_list = None
    if size_list_node.kind() == 'prim::ListConstruct':
        size_list = translate_list(inputs[1], speedup)
    if scale_list_node.kind() == 'prim::ListConstruct':
        scale_list = translate_list(inputs[3], speedup)
    return UpsampleModule(size_list, scale_list)

# There are some types that will be convert into enums after jit.
# So we should recover them back:
# device, dtype, layout, memory_format, qscheme, qengine, dispatchkey
enum_to_dtype_names = {
    0: 'uint8', 1: 'int8', 2: 'int16', 3: 'int32', 4: 'int64',
    5: 'float16', 6: 'float32', 7: 'float64', 8: 'complex32', 9: 'complex64',
    10: 'complex128', 11: 'bool', 12: 'qint8', 13: 'quint8', 14: 'qint32',
    15: 'bfloat16', 16: 'quint4x2', 17: 'quint2x4',
}
enum_to_dtype_dict = {}
for enum_value, dtype_name in enum_to_dtype_names.items():
    if hasattr(torch, dtype_name):
        enum_to_dtype_dict[enum_value] = getattr(torch, dtype_name)

def upsample_nearest2d_python(node, speedup):
    class UpsampleModule(torch.nn.Module):
        def __init__(self, size_list, scale_list):
            super(UpsampleModule, self).__init__()
            self.size_list = size_list
            self.scale_list = scale_list
        def forward(self, *args):
            """
            The first input of args is the target tensor to upsample
            , the following parameters is useless, because we already
            get the size_list and the scale_list by parsing the cpp_nodes.
            """
            return torch.nn.functional.upsample_nearest(args[0], size=self.size_list,
                                                        scale_factor=self.scale_list)
    c_node = node.key_node
    inputs = list(c_node.inputs())
    size_list_node = inputs[1].node()
    scale_list_node = inputs[2].node()
    size_list = None
    scale_list = None
    if size_list_node.kind() == 'prim::ListConstruct':
        size_list = translate_list(inputs[1], speedup)
    if scale_list_node.kind() == 'prim::ListConstruct':
        scale_list = translate_list(inputs[2], speedup)
    return UpsampleModule(size_list, scale_list)

def dtype_trans(ivalue: Union[int, torch.dtype]):
    """
    Special process for dtype.
    Torch will transform dtype to an enum in cpp, so the value of dtype we get in jit is an int.
    This function is used to recover the int to torch.dtype in python.

    Parameters
    ----------
    ivalue
        The value of dtype or method to be recovered.
    """
    if ivalue is None or isinstance(ivalue, torch.dtype):
        return ivalue
    elif isinstance(ivalue, int):
        if ivalue in enum_to_dtype_dict:
            return enum_to_dtype_dict[ivalue]
    raise TypeError('No torch.dtype corresponding to the value "%s"', ivalue)

enum_to_memory_format_dict = {
    0: torch.contiguous_format,
    1: torch.preserve_format,
    2: torch.channels_last,
    3: torch.channels_last_3d,
}

def memory_format_trans(ivalue: Union[int, torch.memory_format]):
    """
    Special process for memory_format.
    Torch will transform memory_format to an enum in cpp, so the value of memory_format we get in jit is an int.
    This function is used to recover the int to torch.memory_format in python.

    Parameters
    ----------
    ivalue
        The value of memory_format or method to be recovered.
    """
    if ivalue is None or isinstance(ivalue, torch.memory_format):
        return ivalue
    elif isinstance(ivalue, int):
        global enum_to_memory_format_dict
        if ivalue in enum_to_memory_format_dict:
            return enum_to_memory_format_dict[ivalue]
    raise TypeError('No torch.memory_format corresponding to the value "%s"', ivalue)

special_treat_dict = {
    'dtype': dtype_trans,
    'memory_format': memory_format_trans,
}
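A quick way to see what dtype_trans and memory_format_trans undo: TorchScript stores dtype and memory_format arguments as integer enum values, so a translated call receives e.g. 6 instead of torch.float32. A small standalone check of that mapping (mirroring a few entries of the table above rather than importing NNI):

    import torch

    # same mapping as enum_to_dtype_dict above, trimmed to a few entries
    jit_dtype_enum = {4: torch.int64, 5: torch.float16, 6: torch.float32, 11: torch.bool}

    assert jit_dtype_enum[6] is torch.float32          # what 'dtype=6' in the jit graph means
    assert torch.zeros(1, dtype=jit_dtype_enum[4]).dtype is torch.int64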
def typeas_python(node, speedup):
    """
    currently only support type_as float.
    TODO: support more types in the type_as, need to figure out
    how to get the scalar type from torch._C.TensorType.
    """
    class TypeasModule(torch.nn.Module):
        def __init__(self, dtype=torch.float):
            self.example = torch.zeros(1, dtype=dtype)
        def forward(self, x):
            return x.type_as(self.example)
    return TypeasModule()

schema_fix_dict = {
    # functinon 'to', 'randint', and 'sparse_coo_tensor' has different schema between python and c++.
    # https://pytorch.org/docs/stable/jit_unsupported.html#ops-with-divergent-schemas-between-torch-python
    'aten::to.device(Tensor(a) self, Device device, int dtype, bool non_blocking=False, bool copy=False, int? memory_format=None) -> (Tensor(a))':
        'aten::to.device(Tensor(a) self, Device device, int dtype, bool non_blocking=False, bool copy=False, *, int? memory_format=None) -> (Tensor(a))',
    'aten::to.dtype(Tensor(a) self, int dtype, bool non_blocking=False, bool copy=False, int? memory_format=None) -> (Tensor(a))':
        'aten::to.dtype(Tensor(a) self, int dtype, bool non_blocking=False, bool copy=False, *, int? memory_format=None) -> (Tensor(a))',
    'aten::to.other(Tensor(a) self, Tensor other, bool non_blocking=False, bool copy=False, int? memory_format=None) -> (Tensor(a))':
        'aten::to.other(Tensor(a) self, Tensor other, bool non_blocking=False, bool copy=False, *, int? memory_format=None) -> (Tensor(a))',
    # todo: are the arguments 'pin_memory' and 'requires_grad' related?
    # functions in the python have only 'requires_grad' and functions in the aten have only 'pin_memory'
    # 'aten::randint(int high, int[] size, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None) -> (Tensor)',
    # 'aten::randint.generator(int high, int[] size, *, Generator? generator, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None) -> (Tensor)',
    # 'aten::randint.low(int low, int high, int[] size, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None) -> (Tensor)',
    # 'aten::randint.low_generator(int low, int high, int[] size, *, Generator? generator, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None) -> (Tensor)',
    # 'aten::sparse_coo_tensor.size(int[] size, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=False) -> (Tensor)',
    # 'aten::sparse_coo_tensor.indices(Tensor indices, Tensor values, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None) -> (Tensor)',
    # 'aten::sparse_coo_tensor.indices_size(Tensor indices, Tensor values, int[] size, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None) -> (Tensor)',
}

@lru_cache(maxsize=256)
def parse_aten_schema(schema: str):
    """
    Parse the schema, to positional_num and keyword_list, and detect if the argument should be specially treated.
    """
    if schema in schema_fix_dict:
        schema = schema_fix_dict[schema]

    positional_num = 0
    keyword_list = list()
    special_treat = dict()  # for dtype and memory_format trans now

    for arg in torch._C.parse_schema(schema).arguments:
        if not arg.kwarg_only:
            key = positional_num
            positional_num += 1
        else:
            key = arg.name
            keyword_list.append(key)

        if arg.name in special_treat_dict:
            if key not in special_treat:
                special_treat[key] = [special_treat_dict[arg.name]]
            else:
                special_treat[key].append(special_treat_dict[arg.name])

    return positional_num, keyword_list, special_treat

def to_python(node, speedup):
    # for the time being, only device parameters are supported
    class ToModule(torch.nn.Module):
        def __init__(self, device, dtype):
            super(ToModule, self).__init__()
            self.device = device
            self.dtype = dtype
        def forward(self, x):
            return x.to(device, dtype=self.dtype)
    c_node = node.key_node
    inputs = list(c_node.inputs())
    in_debugname = inputs[0].debugName()
    # device of the input tensor
    device = speedup.internal_result[in_debugname].device
    for _, _node in enumerate(inputs[1:]):
        val = parse_constant(_node, speedup)
        if isinstance(val, torch.device):
            device = val
    dtype = jitid_2_dtype[parse_constant(inputs[1], speedup)]
    return ToModule(device, dtype)

def cat_python(node, speedup):
    class CatModule(torch.nn.Module):
        def __init__(self, cat_dim):
            super(CatModule, self).__init__()
            self.cat_dim = cat_dim
        def forward(self, *args):
            return torch.cat(args, dim=self.cat_dim)
    c_node = node.key_node
    inputs = list(c_node.inputs())
    dim = inputs[1].toIValue()
    return CatModule(dim)

def parse_input_value(speedup: ModelSpeedup, input_nodes: List[torch._C.Node],
                      positional_num: int, keyword_list: List[str]):
    """
    translate inputs, to constant positional arguments, constant keyword arguments, and undetermined positions
    """
    positional = list()
    keyword = dict()
    undetermined = list()

    for ainput in input_nodes:
        if ainput.node().kind() == 'prim::ListConstruct':
            arg = translate_list(ainput, speedup)
        elif ainput.node().kind() == 'prim::Constant':
            arg = ainput.toIValue()
        else:
            assert 'aten::' in ainput.node().kind() or 'prim::' in ainput.node().kind()
            if len(positional) < positional_num:
                undetermined.append(len(positional))
            else:
                undetermined.append(keyword_list[positional_num - len(positional)])
            arg = None

        if len(positional) < positional_num:
            positional.append(arg)
        else:
            keyword[keyword_list[positional_num - len(positional)]] = arg

    return positional, keyword, undetermined

def special_treat_to_constant_value(positional: List, keyword: Dict[str], undetermined: List[Union[int, str]],
                                    special_treat: Dict[Union[int, str], Callable]):
    """
    if any argument with special_treat is not in undetermined, do the treat
    """
    undetermined_special_treat = dict()
    for p, fs in special_treat.items():
        if p in undetermined:
            undetermined_special_treat[p] = fs
        elif isinstance(p, int):
            for f in fs:
                positional[p] = f(positional[p])
        else:
            for f in fs:
                keyword[p] = f(keyword[p])
    return undetermined_special_treat
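parse_aten_schema leans on torch._C.parse_schema, which turns an ATen schema string into a FunctionSchema whose arguments expose name and kwarg_only. A small probe of that split, using an illustrative schema string rather than one tied to a registered op:

    import torch

    # illustrative schema string (the op name 'aten::my_op' is made up for this sketch)
    schema = torch._C.parse_schema(
        'aten::my_op(Tensor self, int dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor')
    positional = [a.name for a in schema.arguments if not a.kwarg_only]
    keyword_only = [a.name for a in schema.arguments if a.kwarg_only]
    print(positional)    # ['self', 'dim', 'keepdim']  -> positional_num would be 3
    print(keyword_only)  # ['dtype']                   -> would get dtype_trans via special_treat_dict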
def ones_python(node, speedup):
    class OnesModule(torch.nn.Module):
        def __init__(self, out_size, dtype_id, device, require_grad):
            super(OnesModule, self).__init__()
            self.out_size = out_size
            self.device = device
            self.require_grad = require_grad
            self.dtype = jitid_2_dtype[dtype_id]
        def forward(self, *args):
            return torch.ones(size=self.out_size, dtype=self.dtype,
                              device=self.device, requires_grad=self.require_grad)
    c_node = node.key_node
    inputs = list(c_node.inputs())
    output_shape = translate_list(inputs[0], speedup)
    dtype_id = parse_constant(inputs[1], speedup)
    # layout = parse_constant(inputs[2], speedup)
    device = parse_constant(inputs[3], speedup)
    require_grad = parse_constant(inputs[4], speedup)
    return OnesModule(output_shape, dtype_id, device, require_grad)

def generate_aten_to_python(func: Callable, node: NodePyGroup, speedup: ModelSpeedup) -> FuncAdapter:
    """
    Return a callable object to inference the mask according to the node.op_type.

    Parameters
    ---------
    func
        The torch function one-to-one correspondence with the node.
    node
        The target node to inference the mask
    speedup
        The speedup object of the target model.

    Returns
    ------
    func
        Return the translated function that used to inference the mask
        , if current op_type is not supported, then we return None.
    """
    c_node = node.key_node
    schema = c_node.schema()
    positional_num, keyword_list, special_treat = parse_aten_schema(schema)

    input_nodes = list(c_node.inputs())
    positional, keyword, undetermined = parse_input_value(speedup, input_nodes, positional_num, keyword_list)

    undetermined_special_treat = special_treat_to_constant_value(positional, keyword, undetermined, special_treat)

    return FuncAdapter(func, positional, keyword, undetermined, undetermined_special_treat)

def zeros_python(node, speedup):
    class ZerosModule(torch.nn.Module):
        def __init__(self, out_size, dtype_id, device, require_grad):
            super(ZerosModule, self).__init__()
            self.out_size = out_size
            self.device = device
            self.require_grad = require_grad
            self.dtype = jitid_2_dtype[dtype_id]
        def forward(self, *args):
            return torch.zeros(size=self.out_size, dtype=self.dtype,
                               device=self.device, requires_grad=self.require_grad)
    c_node = node.key_node
    inputs = list(c_node.inputs())
    output_shape = translate_list(inputs[0], speedup)
    dtype_id = parse_constant(inputs[1], speedup)
    # layout = parse_constant(inputs[2], speedup)
    device = parse_constant(inputs[3], speedup)
    require_grad = parse_constant(inputs[4], speedup)
    return ZerosModule(output_shape, dtype_id, device, require_grad)

def rsub_python(node, speedup):
    c_node = node.key_node
    inputs = list(c_node.inputs())
    constant = None
    other_name = inputs[1].debugName()
    alpha = parse_constant(inputs[2], speedup)
    if other_name not in speedup.internal_result:
        constant = parse_constant(inputs[1], speedup)
    if constant is None:
        return torch.sub()
    else:
        new_sub = partial(torch.sub, other=constant, alpha=alpha)
        return new_sub

def expand_python(node, speedup):
    class ExpandModule(torch.nn.Module):
        def __init__(self, new_size):
            super(ExpandModule, self).__init__()
            # need deepcopy when the input is size-related
            self.new_size = copy.deepcopy(new_size)
        def forward(self, *args):
            return args[0].expand(self.new_size).clone()
    c_node = node.key_node
    inputs = list(c_node.inputs())
    new_size = translate_list(inputs[1], speedup)
    return ExpandModule(new_size)

def expandas_python(node, speedup):
    class ExpandasModule(torch.nn.Module):
        def forward(self, x, y):
            return x.expand_as(y).clone()
    return ExpandasModule()
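Putting the new pieces together, an adapted op is just a FuncAdapter holding the constant arguments and the positions still to be filled at call time. A minimal usage sketch, assuming the FuncAdapter class defined above; the argument layout mimics aten::add(self, other, *, alpha):

    import torch

    # 'self' comes from the graph at call time (undetermined position 0),
    # 'other' and 'alpha' were constants folded out of the jit node.
    adapter = FuncAdapter(func=torch.add,
                          positional=[None, torch.tensor([1., 1.])],  # None marks the undetermined slot
                          keyword={'alpha': 2},
                          undetermined=[0],
                          special_treat={})

    out = adapter(torch.tensor([1., 2.]))   # fills slot 0, then calls torch.add(self, other, alpha=2)
    assert torch.equal(out, torch.tensor([3., 4.]))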
trans_from_jit_to_python = {
    'aten::add': add_python,
    'aten::add_': add_python,
    'aten::sub': sub_python,
    'aten::sub_': sub_python,
    'aten::mul': mul_python,
    'aten::mul_': mul_python,
    'aten::relu': relu_python,
    'aten::relu_': relu_inplace_python,
    'aten::sigmoid': sigmoid_python,
    'aten::sigmoid_': sigmoid_python,
    # tanh behaives like relu
    'aten::tanh': relu_python,
    'aten::tanh_': relu_python,
    'aten::flatten': flatten_python,
    'aten::mean': mean_python,
    'aten::dropout': dropout_python,
trans_func_dict = {
    'aten::slice': slice_python,
    'aten::select': select_python,
    'aten::size': size_python,
    'aten::t': transpose_python,
    'aten::transpose': transpose2_python,
    'aten::Int': toint_python,
    'aten::view': view_python,
    'aten::reshape': reshape_python,
    'aten::permute': permute_python,
    'aten::matmul': matmul_python,
    'aten::div': div_python,
    'aten::floor_divide': floor_div_python,
    'aten::softmax': softmax_python,
    'aten::contiguous': contiguous_python,
    'aten::gelu': gelu_python,
    'aten::cat': cat_python,
    'aten::avg_pool2d': avgpool2d_python,
    'aten::max_pool2d': avgpool2d_python,
    'aten::adaptive_avg_pool2d': adaptive_avgpool_python,
    'aten::to': to_python,
    'aten::type_as': typeas_python,
    'aten::upsample_bilinear2d': upsample_bilinear2d_python,
    'aten::upsample_nearest2d': upsample_nearest2d_python,
    'aten::exp': exp_python,
    'aten::squeeze': squeeze_python,
    'aten::unsqueeze': unsqueeze_python,
    'aten::ones': ones_python,
    'aten::zeros': zeros_python,
    'aten::rsub': rsub_python,
    'aten::expand': expand_python,
    'aten::Int': partial(generate_aten_to_python, torch._C._TensorBase.int),
    'prim::TupleUnpack': tupleunpack_python,
    'prim::ListUnpack': tupleunpack_python,
    'prim::NumToTensor': num2tensor_python,
    'prim::GetAttr': getattr_python,
    'prim::Constant': constant_python,
    'aten::constant_pad_nd': constant_pad_nd_python,
    'aten::silu': silu_python,
    'aten::expand_as': expandas_python
}
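init_add_functions below merges auto-generated entries with this hand-written table via {**new_trans_func_dict, **trans_func_dict}, so the explicit entries win on key collisions. A tiny illustration of that merge order (plain Python, hypothetical keys and values):

    auto_generated = {'aten::relu': 'generated adapter', 'aten::only_auto': 'generated adapter'}
    hand_written = {'aten::relu': 'hand-written translator'}

    merged = {**auto_generated, **hand_written}   # the later dict overrides on duplicate keys
    assert merged['aten::relu'] == 'hand-written translator'
    assert merged['aten::only_auto'] == 'generated adapter'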
def init_add_functions(func_from: Union[ModuleType, Type[Any]]):
    """
    Add function/method attributes from a module/class, to the trans_func_dict

    Parameters
    ---------
    func_from
        The module/class include needed functions
    """
    global trans_func_dict
    new_trans_func_dict = dict()
    for name in dir(func_from):
        attr = getattr(func_from, name)
        if callable(attr) and not name.startswith('__'):
            new_trans_func_dict['aten::' + name] = partial(generate_aten_to_python, attr)
    trans_func_dict = {**new_trans_func_dict, **trans_func_dict}

init_add_functions(torch._C._VariableFunctions)
init_add_functions(torch._C._nn)
init_add_functions(torch._C._TensorBase)

def jit_to_python_function(node, speedup):
def jit_to_python_function(node: NodePyGroup, speedup: ModelSpeedup) -> FuncAdapter:
    """
    Return a callable object to inference the mask according to the node.op_type.

    Parameters
    ---------
    node
        The target node to inference the mask
    speedup
        The speedup object of the target model.

    Returns
    ------
    func
        Return the translated function that used to inference the mask
        , if current op_type is not supported, then we return None.
    """
    logger.debug('Translate C function %s into its python version', node.op_type)
    if node.op_type not in trans_from_jit_to_python:
    if node.op_type not in trans_func_dict:
        logger.error('%s is not Supported! Please report an issue at https://github.com/microsoft/nni. Thanks~', node.op_type)
        # return None to skip the mask inference for this node
        return None
    return trans_from_jit_to_python[node.op_type](node, speedup)
    return trans_func_dict[node.op_type](node, speedup)
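The lookup-with-fallback at the end is the piece speedup relies on: an op kind missing from the table logs an error and returns None, so mask inference for that node is skipped instead of crashing. The same dispatch pattern in isolation (hypothetical table and node, not NNI's):

    import logging

    logger = logging.getLogger(__name__)

    demo_table = {'aten::relu': lambda node: 'relu adapter'}   # trimmed, hypothetical dispatch table

    def lookup(op_type, node):
        if op_type not in demo_table:
            logger.error('%s is not supported, skipping mask inference for this node', op_type)
            return None                  # caller treats None as "skip this node"
        return demo_table[op_type](node)

    assert lookup('aten::relu', node=None) == 'relu adapter'
    assert lookup('aten::nonexistent_op', node=None) is None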
test/algo/compression/v1/test_model_speedup.py  (view file @ 588f299b)
...

@@ -61,7 +61,7 @@ class BackboneModel2(torch.nn.Module):
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.max_pool2d(x, 2, 2)
        x = x.view(x.size(0), -1)
        x = x.reshape(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)

...
test/algo/compression/v2/test_auto_conv.py  (new file, 0 → 100644; view file @ 588f299b)
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import unittest

import torch
import torch.nn.functional as F

from nni.compression.pytorch.pruning import L1NormPruner
from nni.compression.pytorch.speedup import ModelSpeedup
from nni.algorithms.compression.v2.pytorch.utils import (
    compute_sparsity_compact2origin, compute_sparsity_mask2compact)


class CondModel(torch.nn.Module):
    """
    test for:
        prim::If
    """
    the_cond: bool

    def __init__(self):
        super().__init__()
        self.the_cond = True

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.the_cond:
            x = x + 0.00001
        else:
            x = x - 0.00001
        self.the_cond = not self.the_cond
        return x


class ASubModel(torch.nn.Module):
    """
    test for:
        sub model
    """
    def __init__(self):
        super().__init__()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + 0.00001
        return x


class TorchModel1(torch.nn.Module):
    """
    test for:
        add, sub, mul, div, exp, matmul,
        relu, gelu, tanh, silu, sigmod, softmax,
        size, unsqueeze, flatten, cat, slice, reshape, transpose, t, select, permute, constant_pad_nd,
        mean, avg_pool2d, max_pool2d, sum, adaptive_avg_pool2d,
        to, Int, view,
        type_as, expand_as, contiguous,
    notes:
        'floor_divide' have no backward, then not be tested
    """
    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(1, 6, 5, 1)
        self.conv2 = torch.nn.Conv2d(6, 16, 5, 1)
        self.fccond = torch.nn.Linear(16 * 4 * 4, 16 * 4 * 4)
        self.fc1 = torch.nn.Linear(16 * 4 * 4, 120)
        self.fc2 = torch.nn.Linear(120, 84)
        self.fc3 = torch.nn.Linear(84, 10)
        self.pool1 = torch.nn.MaxPool2d((2, 2))
        self.pool2 = torch.nn.MaxPool2d((2, 2))
        self.cond = torch.jit.script(CondModel())
        self.asub = ASubModel()

    def forward(self, x: torch.Tensor):
        x = x.contiguous(memory_format=torch.channels_last)
        x = torch._C._nn.upsample_bilinear2d(x, (28, 28), False)
        x = torch._C._nn.upsample_nearest2d(x, (28, 28))
        x = F.adaptive_avg_pool2d(x, (28, 28))
        x = torch.exp(x)
        x = torch.sigmoid(x)
        x = torch.transpose(x, 1, 2)
        x = torch.transpose(x, 1, 2)
        x = F.avg_pool2d(x, 3, 1, padding=1)
        x = F.max_pool2d(x, 3, 1, padding=1)
        x = x.to(torch.float32)
        x = self.conv1(x)
        y1 = self.pool1(F.relu(x))
        y2 = self.pool1(F.gelu(x))
        x = y1 + y2
        x = x + 0.00001
        x = x * 1.00001
        x = self.conv2(x)
        y1 = self.pool2(F.silu(x))
        y2 = self.pool2(torch.tanh(x))
        x = y1 - y2
        x = x - 0.00001
        x = x / 1.00001
        x = torch.permute(x, (0, 2, 3, 1))
        x = torch.permute(x, (0, 2, 3, 1))
        x = torch.permute(x, (0, 2, 3, 1))
        x = torch.unsqueeze(x, dim=1)
        x = torch.select(x, dim=1, index=0)
        x = torch.unsqueeze(x, dim=1)
        x = torch.mean(x, dim=1)
        x = torch.unsqueeze(x, dim=1)
        x = torch.sum(x, dim=1, dtype=torch.float32)
        x = torch.unsqueeze(x, dim=1)
        x = torch.squeeze(x, dim=1)
        x = torch.flatten(x, 1)
        x = x.reshape(x.shape)
        x = x.view(-1, x.size(1))
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.softmax(self.fc3(x), dim=1)
        y1 = x[:, 0:int(x.size(1) / 2)]
        y2 = x[:, int(x.size(1) / 2):x.size(1)]
        x = torch.cat((y1, y2), dim=1)
        x = x.type_as(x)
        x = x.expand_as(x)
        x = torch.matmul(x, x.t())
        x = torch.cat([x, x], dim=1)
        # x = self.cond(x)
        x = self.asub(x)
        x = torch.constant_pad_nd(x, (1, 1, 1, 1), 3.14159)
        return x


class AutoConvTestCase(unittest.TestCase):
    def test_l1norm_pruner(self):
        model = TorchModel1()
        dummy_input = torch.rand(3, 1, 28, 28)
        config_list = [{'op_types': ['Conv2d'], 'sparsity': 0.5}]
        pruner = L1NormPruner(model=model, config_list=config_list)
        pruned_model, masks = pruner.compress()
        pruner._unwrap_model()
        sparsity_list = compute_sparsity_mask2compact(pruned_model, masks, config_list)
        ModelSpeedup(model, dummy_input, masks).speedup_model()
        real_sparsity_list = compute_sparsity_compact2origin(TorchModel1(), model, config_list)
        print('sparsity_list:', sparsity_list)
        assert 0.45 < sparsity_list[0]['total_sparsity'] < 0.55
        print('real_sparsity_list:', real_sparsity_list)
        assert 0.45 < real_sparsity_list[0]['total_sparsity'] < 0.75
        print('the shape of output of the infer:', model(dummy_input).shape)
        assert model(dummy_input).shape == torch.Size((5, 8))


if __name__ == '__main__':
    unittest.main()
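The new test exercises the auto-converted ops end to end: prune with L1NormPruner, speed up, then run the compacted model. For reference, the prune-then-speed-up sequence it relies on, distilled to its minimum as a sketch that reuses only the calls appearing in the test above, on a throwaway two-layer model:

    import torch
    from nni.compression.pytorch.pruning import L1NormPruner
    from nni.compression.pytorch.speedup import ModelSpeedup

    model = torch.nn.Sequential(torch.nn.Conv2d(1, 8, 3), torch.nn.ReLU(), torch.nn.Conv2d(8, 4, 3))
    dummy_input = torch.rand(1, 1, 28, 28)

    pruner = L1NormPruner(model=model, config_list=[{'op_types': ['Conv2d'], 'sparsity': 0.5}])
    _, masks = pruner.compress()          # generate masks on the wrapped model
    pruner._unwrap_model()                # remove wrappers before speedup
    ModelSpeedup(model, dummy_input, masks).speedup_model()   # rewrite the graph using the converted ops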