OpenDAS / nni / Commits / 97d067e6

Unverified commit 97d067e6, authored Aug 08, 2022 by Ningxin Zheng, committed by GitHub on Aug 08, 2022.

Speedup enhancement (#4925)

parent: 4ab85d3d
Showing 5 changed files with 314 additions and 51 deletions (+314, -51).
nni/compression/pytorch/speedup/compress_modules.py   +44  -12
nni/compression/pytorch/speedup/compressor.py         +17  -10
nni/compression/pytorch/speedup/infer_mask.py         +13  -10
nni/compression/pytorch/speedup/jit_translate.py      +158 -15
nni/compression/pytorch/utils/utils.py                +82  -4
nni/compression/pytorch/speedup/compress_modules.py (view file @ 97d067e6)
@@ -45,6 +45,7 @@ replace_module = {
     'Upsample': lambda module, masks: no_replace(module, masks),
     'LayerNorm': lambda module, masks: replace_layernorm(module, masks),
     'ConvTranspose2d': lambda module, masks: replace_convtranspose2d(module, masks),
+    'Embedding': lambda module, masks: replace_embedding(module, masks),
     'PixelShuffle': lambda module, masks: replace_pixelshuffle(module, masks),
     'Flatten': lambda module, masks: no_replace(module, masks)
 }
@@ -85,6 +86,30 @@ def convert_to_coarse_mask(t_mask, dim):
     return indexes, remained_indexes
 
+
+def convert_dense_shape(mask):
+    """
+    Get the dense shape of the tensor after removing the sparsity values.
+
+    Parameters
+    ----------
+    mask: torch.Tensor
+        The mask tensor.
+
+    Returns
+    -------
+    dense_shape: tuple
+        The dense shape after removing the sparsity values.
+    """
+    assert isinstance(mask, torch.Tensor)
+    n_dim = len(mask.size())
+    dense_shape = []
+    for dim in range(n_dim):
+        _, remained = convert_to_coarse_mask(mask, dim)
+        dense_shape.append(remained.size(0))
+    return tuple(dense_shape)
+
+
 def no_replace(module, masks):
     """
     No need to replace
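For intuition, a minimal self-contained sketch of what the new helper computes (the inlined re-implementation and the mask values are illustrative, not part of the commit):

```python
import torch

def dense_shape(mask):
    # mimic convert_dense_shape: along each dim, count the slices that
    # still contain at least one unmasked (non-zero) value
    return tuple(
        int((mask.sum(dim=[d for d in range(mask.dim()) if d != dim]) > 0).sum())
        for dim in range(mask.dim()))

mask = torch.ones(4, 4)
mask[2, :] = 0   # row 2 fully pruned
mask[:, 1] = 0   # column 1 fully pruned
print(dense_shape(mask))  # (3, 3): the shape once pruned slices are removed
```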
@@ -165,9 +190,12 @@ def replace_linear(linear, masks):
     in_mask = in_masks[0]
     weight_mask = weight_mask['weight']
-    # N C K
-    pruned_in, remained_in = convert_to_coarse_mask(in_mask, 1)
-    pruned_out, remained_out = convert_to_coarse_mask(output_mask, 1)
+    # the input of the linear may have two dimensions (CV models) or three
+    # dimensions (Bert, for example)
+    n_dim = len(in_mask.size())
+    pruned_in, remained_in = convert_to_coarse_mask(in_mask, n_dim - 1)
+    pruned_out, remained_out = convert_to_coarse_mask(output_mask, n_dim - 1)
     n_remained_in = weight_mask.size(1) - pruned_in.size(0)
     n_remained_out = weight_mask.size(0) - pruned_out.size(0)
     remained_in, remained_out = remained_in.to(
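The switch from the hard-coded dim 1 to n_dim - 1 is what makes the linear replacement work for 3-D transformer inputs; a rough illustration with a hypothetical (batch, seq, hidden) mask:

```python
import torch

# input mask of a Linear inside a Bert-style model: (batch, seq_len, hidden)
in_mask = torch.ones(2, 8, 16)
in_mask[:, :, 4:8] = 0        # hidden features 4..7 pruned everywhere

n_dim = in_mask.dim()         # 3; for a CV model with (N, C) input it is 2
# the channel dimension of a Linear input is always the last one
remained = (in_mask.sum(dim=tuple(range(n_dim - 1))) > 0).sum()
print(int(remained))          # 12 of 16 input features remain
```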
@@ -582,16 +610,20 @@ def replace_layernorm(layernorm, masks):
     if len(in_masks) != 1:
         raise InputsNumberError()
     in_mask = in_masks[0]
-    dim_n = len(in_mask.size())
-    new_shape = []
-    for i in range(1, dim_n):
-        sum_dims = list(range(0, dim_n))
-        sum_dims.remove(i)
-        reduced = torch.sum(in_mask, sum_dims)
-        n_remained = torch.sum(reduced > 0)
-        new_shape.append(n_remained)
-    return nn.LayerNorm(tuple(new_shape), layernorm.eps, layernorm.elementwise_affine)
+    dense_shape = convert_dense_shape(in_mask)
+    norm_shape = layernorm.normalized_shape
+    dim_n = len(dense_shape) - len(norm_shape)
+    return nn.LayerNorm(dense_shape[dim_n:], layernorm.eps, layernorm.elementwise_affine)
+
+
+def replace_embedding(embedding, masks):
+    """
+    Replace the embedding layer according to the inferred masks.
+    We replace the embedding layer according to the weight masks.
+    """
+    # currently we do not support replacing the embedding layer,
+    # because we do not have the corresponding pruner
+    return embedding
+
 
 def replace_pixelshuffle(pixelshuffle, masks):
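A sketch of what the new LayerNorm replacement produces, with hypothetical shapes: only the trailing len(normalized_shape) dimensions of the dense shape are normalized, so leading batch and sequence dims no longer have to be inferred one by one:

```python
from torch import nn

dense_shape = (2, 8, 12)   # dense input shape after pruning, e.g. (N, S, H')
norm_shape = (8, 16)       # the original layernorm.normalized_shape

# keep only the trailing len(norm_shape) dims of the dense shape
dim_n = len(dense_shape) - len(norm_shape)
new_norm = nn.LayerNorm(dense_shape[dim_n:])
print(new_norm.normalized_shape)   # (8, 12)
```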
nni/compression/pytorch/speedup/compressor.py (view file @ 97d067e6)
@@ -2,6 +2,7 @@
 # Licensed under the MIT license.
 
 import copy
 import logging
+from pathlib import Path
 import queue
@@ -66,6 +67,7 @@ class ModelSpeedup:
         self.bound_model = model
         self.inferred_masks = dict()  # key: module_name, value: ModuleMasks
         self.batch_dim = batch_dim
+        self.confidence = confidence
         self.dummy_input, self.device = self._random_model_input(dummy_input, confidence, batch_dim)
         self.torch_graph = build_module_graph(model, self.dummy_input)
@@ -196,6 +198,7 @@ class ModelSpeedup:
             # The detach operation here is for the in-place operation. We cannot
             # directly call backward on the output tensor of an in-place operator.
             dummy_input.append(self.internal_result[_input].detach())
             debugnames.append(_input)
         return dummy_input, debugnames
@@ -229,15 +232,15 @@ class ModelSpeedup:
                 return
             # function doesn't have weights
-            _auto_infer = AutoMaskInference(
-                func, dummy_input, in_masks, in_constants=in_constants,
-                batch_dim=self.batch_dim)
+            _auto_infer = AutoMaskInference(
+                func, dummy_input, self, in_masks, in_constants=in_constants)
         else:
             weight_mask = None
             if module_name in self.masks:
                 weight_mask = self.masks[module_name]
             _, module = get_module_by_name(self.bound_model, module_name)
-            _auto_infer = AutoMaskInference(
-                module, dummy_input, in_masks, weight_mask, in_constants=in_constants,
-                state_dict=copy.deepcopy(module.state_dict()), batch_dim=self.batch_dim)
+            _auto_infer = AutoMaskInference(
+                module, dummy_input, self, in_masks, weight_mask, in_constants=in_constants,
+                state_dict=copy.deepcopy(module.state_dict()))
         self.auto_inferences[unique_name] = _auto_infer
         _auto_infer.name = node.unique_name
@@ -280,6 +283,7 @@ class ModelSpeedup:
         The target node to update the indirect sparsity
         """
         unique_name = node.unique_name
         if unique_name in self.auto_inferences and self.auto_inferences[unique_name] is not None:
             # if the auto inference object already in self.auto_inference, then
             # directly update the previous one
@@ -291,13 +295,18 @@ class ModelSpeedup:
             # pass the gradient to the predecessor nodes
             for in_id, tin in enumerate(auto_infer.dummy_input):
                 debug_name = auto_infer.input_debugname[in_id]
                 last_output = self.internal_result[debug_name]
                 # if isinstance(last_output, torch.Tensor):
                 # TODO what if last output is tuple/list of tensor
                 if last_output.grad is not None and tin.grad is not None:
                     last_output.grad.data += tin.grad.data
-                else:
+                elif last_output.grad is None:
                     last_output.grad = tin.grad
+                elif last_output.grad is not None and tin.grad is None:
+                    # for example, tin.view(batch, tin.size(1)/2, tin.size(2)*2):
+                    # the size operation of tin will have no gradient
+                    continue
         else:
             _logger.warning('Note: %s does not have corresponding mask inference object', node.name)
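The three branches distinguish whether either side of the gradient hand-off exists; a toy illustration (the tensors are hypothetical):

```python
import torch

last_output = torch.randn(4, requires_grad=True)
tin = last_output * 1.0       # a downstream use of the tensor
tin.retain_grad()
tin.sum().backward()

# branch 1: both grads exist -> accumulate into last_output.grad
# branch 2: last_output.grad is None -> adopt tin.grad directly
# branch 3: tin.grad is None (e.g. tin only fed a size() op) -> skip
print(last_output.grad is not None, tin.grad is not None)  # True True
```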
@@ -388,6 +397,7 @@ class ModelSpeedup:
                 if out_degree[predecessor] == 0:
                     visit_queue.put(self.torch_graph.name_to_node[predecessor])
 
     def replace_compressed_modules(self):
         """
         Replace all the modules that have changed (weights/inputs/output) shape.
@@ -401,6 +411,7 @@ class ModelSpeedup:
         for unique_name in self.auto_inferences:
             self.replace_submodule(unique_name)
 
     def replace_submodule(self, unique_name, reindex_dim=None, reindex=None):
         """
         Replace the submodule according to the inferred sparsity.
@@ -443,7 +454,6 @@ class ModelSpeedup:
                                       requires_grad=tmpout.requires_grad)
                 out[self.t_index] = tmpout
                 return out
         assert unique_name in self.auto_inferences
         g_node = self.torch_graph.name_to_node[unique_name]
         _logger.debug("replace %s, in %s type, with op_type %s",
@@ -483,12 +493,9 @@ class ModelSpeedup:
             setattr(super_module, g_node.name.split('.')[-1], new_submodule)
             return new_submodule
         elif g_node.type == 'func':
             _logger.info("Warning: cannot replace (name: %s, op_type: %s) which is func type",
                          unique_name, g_node.op_type)
             return None
         else:
             raise RuntimeError("Unsupported node type: {}".format(g_node.type))
-        return None
 
     def initialize_speedup(self):
         """
nni/compression/pytorch/speedup/infer_mask.py (view file @ 97d067e6)
@@ -12,8 +12,8 @@ STD_DELTA = 1e-6
 class AutoMaskInference:
-    def __init__(self, module, dummy_input, in_masks=None, weight_mask=None, \
-                 output_mask=None, name=None, in_constants=None, state_dict=None, batch_dim=0):
+    def __init__(self, module, dummy_input, speedup, in_masks=None, weight_mask=None,
+                 output_mask=None, name=None, in_constants=None, state_dict=None):
         """
         This class will infer the mask of the target module automatically.
         This update_direct_sparsity will infer the output mask according
@@ -28,6 +28,8 @@ class AutoMaskInference:
         The target module to infer the mask. Need to be callable.
     dummy_input: torch.Tensor/list of Tensor
         The dummy_input of the target module.
+    speedup: ModelSpeedup
+        The reference of the ModelSpeedup object.
     in_masks: list of torch.Tensor
         The input masks of the target module, if in_masks is not None, then
         update_direct_sparsity and update_indirect_sparsity will incrementally
@@ -47,8 +49,6 @@ class AutoMaskInference:
         The corresponding constant values of the in_masks.
     state_dict: dict of torch.Tensor
         The original values of the weights.
-    batch_dim: int
-        The index of the batch dimension of the input tensors.
     """
     errmsg = '%s is not callable, should pass the nn.Module/function' % str(
@@ -112,7 +112,8 @@ class AutoMaskInference:
                 self.weight_mask[name] = torch.ones_like(para.data)
         self.state_dict = state_dict
-        # TODO support the other batch dimension in the future
-        self.batch_dim = batch_dim
+        self.batch_dim = speedup.batch_dim
+        self.batch_size = speedup.confidence
 
     def random_init(self, start=0.1, end=8.0):
         """
@@ -125,13 +126,17 @@ class AutoMaskInference:
         # rules for ReLU6 to break this range constraint.
         with torch.no_grad():
             for tensor in self.dummy_input:
-                if isinstance(tensor, torch.Tensor) and len(tensor.size()) > 0:
-                    # if the tensor is a scalar, then skip this tensor
+                if isinstance(tensor, torch.Tensor) and len(tensor.size()) > self.batch_dim \
+                        and tensor.size(self.batch_dim) == self.batch_size:
+                    # if the input tensor only has one dimension, which means
+                    # it doesn't have the batch dimension, then we don't randomize
+                    # this tensor, because our tensor scrambling is on the batch
+                    # dimension. For example, if the tensor is a scalar (returned
+                    # by the size operator), then we will skip this tensor
                     randomize_tensor(tensor, start, end)
             for para in self.weights:
                 randomize_tensor(self.weights[para].data, start, end)
 
     def zero_grad(self):
         """
         Set the gradient of the weight, input tensor to be zeros.
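Why the extra condition: the dummy input's batch size equals the speedup's confidence value, so only tensors that actually carry the batch dimension get scrambled, while scalars produced by size() ops are left alone. A hedged sketch of the check:

```python
import torch

batch_dim, batch_size = 0, 8    # batch_size comes from speedup.confidence

def should_randomize(t):
    # mirrors the new condition in random_init
    return isinstance(t, torch.Tensor) and t.dim() > batch_dim \
        and t.size(batch_dim) == batch_size

activation = torch.randn(8, 3, 32, 32)   # real input: randomized
scalar = torch.tensor(32)                # came from a size() op: skipped
print(should_randomize(activation), should_randomize(scalar))  # True False
```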
@@ -240,7 +245,6 @@ class AutoMaskInference:
             constant[:, mask_pos] = mean[mask_pos]
         return out_mask, constant
 
     def update_indirect_sparsity(self):
         """
         This function will update the indirect sparsity. To explain what's
@@ -379,4 +383,3 @@ class AutoMaskInference:
     def get_masks(self):
         return (self.in_masks, self.output_mask, self.weight_mask)
nni/compression/pytorch/speedup/jit_translate.py (view file @ 97d067e6)
@@ -4,11 +4,13 @@
 import re
 import logging
 from functools import partial
+import copy
 
 import torch
 
 logger = logging.getLogger(__name__)
 logger.setLevel(logging.INFO)
 
+jitid_2_dtype = {4: torch.long, 6: torch.float32}
 
 # to exclude partial
@@ -243,7 +245,7 @@ def softmax_python(node, speedup):
 def contiguous_python(node, speedup):
     class contiguousModule(torch.nn.Module):
         def forward(self, x):
-            return x.contiguous()
+            return x.contiguous().clone()
     return contiguousModule()
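A plausible motivation for the added .clone(): contiguous() returns the very same tensor when it is already contiguous, so later in-place scrambling of the output would silently mutate the input. A quick check:

```python
import torch

x = torch.randn(4, 4)                 # already contiguous
print(x.contiguous().data_ptr() == x.data_ptr())          # True: same storage
print(x.contiguous().clone().data_ptr() == x.data_ptr())  # False: a real copy
```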
@@ -297,6 +299,7 @@ def squeeze_python(node, speedup):
     new_squeeze = partial(torch.squeeze, dim=dim)
     return new_squeeze
 
 def unsqueeze_python(node, speedup):
     c_node = node.key_node
     inputs = list(c_node.inputs())
@@ -324,7 +327,10 @@ def slice_python(node, speedup):
     class SliceMoudle(torch.nn.Module):
         def __init__(self, sliceobj):
             super(SliceMoudle, self).__init__()
-            self.sliceobj = sliceobj
+            # we need to deepcopy the value here, because, in the
+            # following steps, we may randomize the input tensor,
+            # which will change the values of the sliceobj
+            self.sliceobj = copy.deepcopy(sliceobj)
 
         def forward(self, x, *args):
             # args is for the slice dimension and indexes, however,
@@ -344,20 +350,31 @@ def slice_python(node, speedup):
     slice_end = parse_constant(inputs[3], speedup)
     slice_step = parse_constant(inputs[4], speedup)
     slice_obj = slice(slice_start, slice_end, slice_step)
     slice_list = []
     for _ in range(slice_dim):
         slice_list.append(slice(None, None))
     logger.info('Slice dim:%s, Slice obj:%s', str(slice_dim), str(slice_obj))
     slice_list.append(slice_obj)
-    return SliceMoudle(tuple(slice_list))
+    if inputs[0].debugName() not in speedup.internal_result:
+        # The inputs of slice operator may be the constant
+        target_tensor = parse_constant(inputs[0], speedup)
+        slice_list = tuple(slice_list)
+
+        def constant_slice(*args):
+            return target_tensor[slice_list]
+        return constant_slice
+    else:
+        return SliceMoudle(tuple(slice_list))
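A sketch of the constant-input branch (names illustrative): when the sliced tensor is itself a graph constant rather than an intermediate result, there is no module input to slice at call time, so a closure over the parsed constant is returned:

```python
import torch

target_tensor = torch.arange(12).reshape(3, 4)    # a parsed graph constant
slice_list = (slice(None, None), slice(0, 2, 1))  # equivalent to [:, 0:2]

def constant_slice(*args):
    # runtime args are ignored; the captured constant is always sliced
    return target_tensor[slice_list]

print(constant_slice().shape)   # torch.Size([3, 2])
```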
 def select_python(node, speedup):
     class SelectModule(torch.nn.Module):
         def __init__(self, dim, index):
             super(SelectModule, self).__init__()
-            self.dim = dim
-            self.index = index
+            self.dim = copy.deepcopy(dim)
+            self.index = copy.deepcopy(index)
 
         def forward(self, x):
             return x.select(self.dim, self.index)
@@ -425,7 +442,9 @@ def permute_python(node, speedup):
     class PermuteModule(torch.nn.Module):
         def __init__(self, dimlist):
             super(PermuteModule, self).__init__()
-            self.dimlist = dimlist
+            # deepcopy the values here, because the following randomize operation
+            # will change the value of the dimlist
+            self.dimlist = copy.deepcopy(dimlist)
 
         def forward(self, x):
             return x.permute(self.dimlist)
@@ -439,6 +458,7 @@ def getattr_python(node, speedup):
     """
     Note: Ops started with Prim:: is not taken as the key node,
     so we directly pass the Cpp node into this function.
+
     Parameters
     ----------
     node: torch._C.Node
@@ -462,6 +482,44 @@ def getattr_python(node, speedup):
     assert len(key_words) == 1
     return GetModule(key_words[0])
 
+
+def constant_python(node, speedup):
+    """
+    Get the constant value of a constant operator node.
+
+    Parameters
+    ----------
+    node: torch._C.Node
+        The cpp node of prim::Constant
+    speedup: ModelSpeedup
+        The corresponding speedup object.
+    """
+    class ConstantModule(torch.nn.Module):
+        def __init__(self, constant):
+            super(ConstantModule, self).__init__()
+            self.constant = constant
+
+        def forward(self):
+            return self.constant
+
+    assert node.kind() == 'prim::Constant'
+    pattern = '\[value=(.*?)\]'
+    key_words = re.findall(pattern, str(node))
+    if len(key_words) == 0:
+        return ConstantModule(None)
+    assert len(key_words) == 1
+    # parse the constant value
+    value = key_words[0]
+    if value.startswith("\""):
+        # a quoted string constant is treated as a device string
+        value = torch.device(value[1:-1])
+    elif value.startswith('{'):
+        # TODO Support set values in the future
+        value = set()
+    elif '.' in value:
+        # float value
+        value = float(value)
+    else:
+        # integer value
+        value = int(value)
+    return ConstantModule(value)
+
 
 def upsample_bilinear2d_python(node, speedup):
     class UpsampleModule(torch.nn.Module):
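For reference, the parsing logic applied to a few illustrative prim::Constant node strings (the node texts below are made up but follow the TorchScript [value=...] format):

```python
import re
import torch

pattern = r'\[value=(.*?)\]'
for raw in ('%4 : int = prim::Constant[value=1]()',
            '%5 : float = prim::Constant[value=0.5]()',
            '%6 : Device = prim::Constant[value="cuda:0"]()'):
    value = re.findall(pattern, raw)[0]
    if value.startswith('"'):
        value = torch.device(value[1:-1])
    elif '.' in value:
        value = float(value)
    else:
        value = int(value)
    print(repr(value))  # 1, then 0.5, then device(type='cuda', index=0)
```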
@@ -539,16 +597,25 @@ def typeas_python(node, speedup):
 def to_python(node, speedup):
-    # for the time being, only device parameters are supported
     class ToModule(torch.nn.Module):
-        def __init__(self, device):
+        def __init__(self, device, dtype):
             super(ToModule, self).__init__()
             self.device = device
+            self.dtype = dtype
 
         def forward(self, x):
-            return x.to(device)
+            return x.to(device, dtype=self.dtype)
 
     c_node = node.key_node
     inputs = list(c_node.inputs())
-    device = inputs[3].toIValue()
-    return ToModule(device)
+    in_debugname = inputs[0].debugName()
+    # device of the input tensor
+    device = speedup.internal_result[in_debugname].device
+    for _, _node in enumerate(inputs[1:]):
+        val = parse_constant(_node, speedup)
+        if isinstance(val, torch.device):
+            device = val
+    dtype = jitid_2_dtype[parse_constant(inputs[1], speedup)]
+    return ToModule(device, dtype)
 
 
 def cat_python(node, speedup):
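A standalone sketch of the new ToModule behavior (only dtype ids 4 and 6 are mapped by the commit's jitid_2_dtype table; the sketch uses self.device rather than the enclosing closure for self-containment):

```python
import torch

jitid_2_dtype = {4: torch.long, 6: torch.float32}

class ToModule(torch.nn.Module):
    def __init__(self, device, dtype):
        super(ToModule, self).__init__()
        self.device = device
        self.dtype = dtype

    def forward(self, x):
        # the old version only moved devices; dtype is now honored too
        return x.to(self.device, dtype=self.dtype)

x = torch.randn(2, 3)
print(ToModule(torch.device('cpu'), jitid_2_dtype[4])(x).dtype)  # torch.int64
```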
@@ -566,6 +633,77 @@ def cat_python(node, speedup):
     return CatModule(dim)
 
+
+def ones_python(node, speedup):
+    class OnesModule(torch.nn.Module):
+        def __init__(self, out_size, dtype_id, device, require_grad):
+            super(OnesModule, self).__init__()
+            self.out_size = out_size
+            self.device = device
+            self.require_grad = require_grad
+            self.dtype = jitid_2_dtype[dtype_id]
+
+        def forward(self, *args):
+            return torch.ones(size=self.out_size, dtype=self.dtype,
+                              device=self.device, requires_grad=self.require_grad)
+
+    c_node = node.key_node
+    inputs = list(c_node.inputs())
+    output_shape = translate_list(inputs[0], speedup)
+    dtype_id = parse_constant(inputs[1], speedup)
+    # layout = parse_constant(inputs[2], speedup)
+    device = parse_constant(inputs[3], speedup)
+    require_grad = parse_constant(inputs[4], speedup)
+    return OnesModule(output_shape, dtype_id, device, require_grad)
+
+
+def zeros_python(node, speedup):
+    class ZerosModule(torch.nn.Module):
+        def __init__(self, out_size, dtype_id, device, require_grad):
+            super(ZerosModule, self).__init__()
+            self.out_size = out_size
+            self.device = device
+            self.require_grad = require_grad
+            self.dtype = jitid_2_dtype[dtype_id]
+
+        def forward(self, *args):
+            return torch.zeros(size=self.out_size, dtype=self.dtype,
+                               device=self.device, requires_grad=self.require_grad)
+
+    c_node = node.key_node
+    inputs = list(c_node.inputs())
+    output_shape = translate_list(inputs[0], speedup)
+    dtype_id = parse_constant(inputs[1], speedup)
+    # layout = parse_constant(inputs[2], speedup)
+    device = parse_constant(inputs[3], speedup)
+    require_grad = parse_constant(inputs[4], speedup)
+    return ZerosModule(output_shape, dtype_id, device, require_grad)
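What OnesModule (and its ZerosModule twin) produce at forward time, with hypothetical parsed arguments:

```python
import torch

jitid_2_dtype = {4: torch.long, 6: torch.float32}

# OnesModule(out_size=[2, 3], dtype_id=6, device='cpu', require_grad=False)
out = torch.ones(size=[2, 3], dtype=jitid_2_dtype[6],
                 device='cpu', requires_grad=False)
print(out.shape, out.dtype)   # torch.Size([2, 3]) torch.float32
```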
+def rsub_python(node, speedup):
+    c_node = node.key_node
+    inputs = list(c_node.inputs())
+    constant = None
+    other_name = inputs[1].debugName()
+    alpha = parse_constant(inputs[2], speedup)
+    if other_name not in speedup.internal_result:
+        constant = parse_constant(inputs[1], speedup)
+    if constant is None:
+        return torch.sub
+    else:
+        new_sub = partial(torch.sub, other=constant, alpha=alpha)
+        return new_sub
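A sanity check of what the returned partial computes. Note that torch.sub(input, other, alpha) is input - alpha * other, whereas aten::rsub itself computes other - alpha * input, which may be worth double-checking when reviewing:

```python
import torch
from functools import partial

constant, alpha = 2.0, 1
new_sub = partial(torch.sub, other=constant, alpha=alpha)
print(new_sub(torch.tensor([5.0, 7.0])))   # tensor([3., 5.])
```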
+def expand_python(node, speedup):
+    class ExpandModule(torch.nn.Module):
+        def __init__(self, new_size):
+            super(ExpandModule, self).__init__()
+            # need deepcopy when the input is size-related
+            self.new_size = copy.deepcopy(new_size)
+
+        def forward(self, *args):
+            return args[0].expand(self.new_size).clone()
+
+    c_node = node.key_node
+    inputs = list(c_node.inputs())
+    new_size = translate_list(inputs[1], speedup)
+    return ExpandModule(new_size)
+
 
 def expandas_python(node, speedup):
     class ExpandasModule(torch.nn.Module):
         def forward(self, x, y):
@@ -616,13 +754,18 @@ trans_from_jit_to_python = {
     'aten::exp': exp_python,
     'aten::squeeze': squeeze_python,
     'aten::unsqueeze': unsqueeze_python,
-    'aten::constant_pad_nd': constant_pad_nd_python,
-    'aten::silu': silu_python,
-    'aten::expand_as': expandas_python,
+    'aten::ones': ones_python,
+    'aten::zeros': zeros_python,
+    'aten::rsub': rsub_python,
+    'aten::expand': expand_python,
     'prim::TupleUnpack': tupleunpack_python,
     'prim::ListUnpack': tupleunpack_python,
     'prim::NumToTensor': num2tensor_python,
-    'prim::GetAttr': getattr_python
+    'prim::GetAttr': getattr_python,
+    'prim::Constant': constant_python,
+    'aten::constant_pad_nd': constant_pad_nd_python,
+    'aten::silu': silu_python,
+    'aten::expand_as': expandas_python
 }
nni/compression/pytorch/utils/utils.py (view file @ 97d067e6)
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import re
import logging

import torch

torch_float_dtype = [torch.float, torch.float16,
                     torch.float32, torch.float64, torch.half, torch.double]
torch_integer_dtype = [torch.uint8, torch.int16, torch.short,
                       torch.int32, torch.long, torch.bool]

_logger = logging.getLogger(__name__)
_logger.setLevel(logging.INFO)


def get_module_by_name(model, module_name):
    """
@@ -46,11 +54,13 @@ def rand_like_with_shape(shape, ori_t):
     require_grad = ori_t.requires_grad
     lower_bound = torch.min(ori_t)
     higher_bound = torch.max(ori_t)
     if dtype in [torch.uint8, torch.int16, torch.short, torch.int32, torch.long, torch.bool]:
         return torch.randint(lower_bound, higher_bound + 1, shape, dtype=dtype, device=device)
     else:
         return torch.rand(shape, dtype=dtype, device=device, requires_grad=require_grad)
 
 def randomize_tensor(tensor, start=1, end=100):
     """
     Randomize the target tensor according to the given
@@ -59,11 +69,79 @@ def randomize_tensor(tensor, start=1, end=100):
     assert isinstance(tensor, torch.Tensor)
     if tensor.dtype in torch_integer_dtype:
         # integer tensor can only be randomized by torch.randint
-        # torch.randint(int(start), int(end), tensor.size(), out=tensor.data, dtype=tensor.dtype)
-        pass
+        torch.randint(int(start), int(end), tensor.size(), out=tensor.data, dtype=tensor.dtype)
+        # pass
     else:
         # we can use nn.init.uniform_ to randomize this tensor
         # Note: a tensor with an integer type cannot be randomized
         # with nn.init.uniform_
         torch.nn.init.uniform_(tensor.data, start, end)
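The previously commented-out randint path is now active; a small demonstration of in-place integer randomization via out=:

```python
import torch

t = torch.zeros(3, 4, dtype=torch.long)
# nn.init.uniform_ rejects integer tensors, so randint fills them in place
torch.randint(1, 100, t.size(), out=t.data, dtype=t.dtype)
print(t.dtype, bool((t >= 1).all()))   # torch.int64 True
```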
+jit_python_code_replacement = {
+    'torch.slice': lambda tmpstr: python_slice_replace(tmpstr)
+}
+
+
+def translate_jit_code(code):
+    pattern = 'torch\.(.*?)\('
+    func_names = re.findall(pattern, code)
+    modules = {'torch.': torch, 'torch.nn.functional.': torch.nn.functional,
+               'torch.Tensor.': torch.Tensor, 'torch._C._nn.': torch._C._nn}
+    replace = {}
+    # rebase the namespace to get the runnable python code
+    for full_name in func_names:
+        func = re.split('\.', full_name)[-1]
+        for module_name in modules:
+            torch_module = modules[module_name]
+            if hasattr(torch_module, func):
+                replace['torch.' + full_name] = module_name + func
+                break
+        # assert found == True, 'Cannot find the function call %s' % full_name
+    for key, value in replace.items():
+        code = code.replace(key, value)
+    # several functions cannot be found under the namespaces torch.Tensor and
+    # torch. (for example torch.slice), so we need to handle these functions
+    # manually
+    lines = code.split('\n')
+    for i, line in enumerate(lines):
+        for fname in jit_python_code_replacement:
+            if fname in line:
+                lines[i] = jit_python_code_replacement[fname](line)
+    code = '\n'.join(lines)
+    code = 'import torch\nfrom torch import Tensor, tensor\nfrom typing import *\n' + code
+    with open('nni_jit_tmp_forward.py', 'w') as f:
+        f.write(code)
+    from nni_jit_tmp_forward import forward  # pylint: disable=import-error
+    return forward
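A hedged usage sketch (the forward source below is illustrative, and importing nni_jit_tmp_forward assumes the working directory is on sys.path): the TorchScript-style code is rebased onto real torch namespaces, written to nni_jit_tmp_forward.py, and imported back as a callable:

```python
from nni.compression.pytorch.utils.utils import translate_jit_code  # path per this commit

code = '''
def forward(x):
    return torch.relu(torch.flatten(x, 1))
'''
forward = translate_jit_code(code)   # writes nni_jit_tmp_forward.py
```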
+def python_slice_replace(funcstr):
+    """
+    Translate the torch.slice call into the equivalent python subscript
+    string that can replace the original one in the forward function string.
+
+    Parameters
+    ----------
+    funcstr: str
+        the str that calls torch.slice, for example:
+        _8 = torch.slice(attention_mask, 0, 0, 9223372036854775807, 1)
+
+    Returns
+    -------
+    new_str: str
+        the string that should replace the original one
+    """
+    # parse the input parameters
+    pattern = 'torch\.slice\((.*)\)'
+    parameter_str = re.findall(pattern, funcstr)
+    parameters = re.split(',', parameter_str[0])
+    target_tensor = parameters[0]
+    dim = int(parameters[1])
+    dim_str = ','.join([':'] * dim + [':'.join(parameters[2:])])
+    print('%s[%s]' % (target_tensor, dim_str))
+    new_str = funcstr.replace('torch.slice(%s)' % parameter_str[0],
+                              '%s[%s]' % (target_tensor, dim_str))
+    return new_str
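A worked example of the rewrite (whitespace from the comma split is preserved in the subscript, and the function also prints the subscript as a side effect):

```python
from nni.compression.pytorch.utils.utils import python_slice_replace  # path per this commit

line = '_8 = torch.slice(attention_mask, 0, 0, 9223372036854775807, 1)'
print(python_slice_replace(line))
# -> '_8 = attention_mask[ 0: 9223372036854775807: 1]', a full slice on dim 0
```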