Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
c53be963
Unverified
Commit
c53be963
authored
Oct 18, 2022
by
Louis-J
Committed by
GitHub
Oct 18, 2022
Browse files
fix(speedup): re-write aten schema parser to support pytorch versions < 1.9.0 (#5138)
parent
860ad5cf
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
317 additions
and
2 deletions
+317
-2
nni/compression/pytorch/speedup/jit_translate.py
nni/compression/pytorch/speedup/jit_translate.py
+267
-2
test/algo/compression/v2/test_schema_parser.py
test/algo/compression/v2/test_schema_parser.py
+50
-0
No files found.
nni/compression/pytorch/speedup/jit_translate.py
View file @
c53be963
...
@@ -11,6 +11,7 @@ if TYPE_CHECKING: # Only imports the below statements during type checking
...
@@ -11,6 +11,7 @@ if TYPE_CHECKING: # Only imports the below statements during type checking
from
nni.common.graph_utils
import
NodePyGroup
from
nni.common.graph_utils
import
NodePyGroup
import
re
import
re
import
string
import
logging
import
logging
from
functools
import
partial
,
lru_cache
from
functools
import
partial
,
lru_cache
import
copy
import
copy
...
@@ -394,10 +395,11 @@ schema_fix_dict = {
...
@@ -394,10 +395,11 @@ schema_fix_dict = {
# ce=None, bool? pin_memory=None) -> (Tensor"""'
# ce=None, bool? pin_memory=None) -> (Tensor"""'
}
}
@
lru_cache
(
maxsize
=
256
)
@
lru_cache
def
parse_aten_schema
(
schema
:
str
):
def
parse_aten_schema
(
schema
:
str
):
"""
"""
Parse the schema, to positional_num and keyword_list, and detect if the argument should be specially treated.
Parse the schema, to positional_num and keyword_list, and detect if the argument should be specially treated.
only available on pytorch >= v1.9.0
"""
"""
if
schema
in
schema_fix_dict
:
if
schema
in
schema_fix_dict
:
schema
=
schema_fix_dict
[
schema
]
schema
=
schema_fix_dict
[
schema
]
...
@@ -422,6 +424,266 @@ def parse_aten_schema(schema: str):
...
@@ -422,6 +424,266 @@ def parse_aten_schema(schema: str):
return
positional_num
,
keyword_list
,
special_treat
return
positional_num
,
keyword_list
,
special_treat
@
lru_cache
def
parse_aten_schema_version_1_8_x
(
schema
:
str
):
"""
Parse the schema, to positional_num and keyword_list, and detect if the argument should be specially treated.
Cannot use 'torch._C.parse_schema' because 'torch._C.Argument' has no 'kwarg_only' on pytorch v1.8.x
Using a lexer-parser like method to parse it.
Re-write from torch/csrc/jit/frontend/function_schema_parser.cpp
"""
if
schema
in
schema_fix_dict
:
schema
=
schema_fix_dict
[
schema
]
single_solid_tokens
=
[
'('
,
')'
,
'['
,
']'
,
'+'
,
'-'
,
'!'
,
'>'
,
'|'
,
'='
,
':'
,
'.'
,
','
,
'?'
,
'*'
,
]
# no '>=', '<=', '&', '/'
# '|' only occurs in 'Tensor(a|b)'
spec_tokens
=
[
'numdigits'
,
'string'
,
'quoted'
,
'unknown'
,
]
str_chars_first
=
(
*
string
.
ascii_letters
,
'_'
)
str_chars
=
(
*
string
.
ascii_letters
,
*
string
.
digits
,
'_'
)
num_chars_first
=
(
*
string
.
digits
,)
num_chars_16
=
(
*
string
.
digits
,
*
string
.
ascii_lowercase
[:
6
],
*
string
.
ascii_uppercase
[:
6
])
tokens
=
list
()
# 1: in ('\'', '"'); 2: in num; 3: in str;
status
=
0
status_esc_char
=
False
for
char
in
schema
:
if
status
==
1
:
if
status_esc_char
:
status_esc_char
=
False
tokens
[
-
1
][
1
]
+=
char
elif
char
==
'
\\
'
:
status_esc_char
=
True
else
:
tokens
[
-
1
][
1
]
+=
char
if
char
==
tokens
[
-
1
][
1
][
0
]:
status
=
0
continue
elif
status
==
2
:
if
char
in
num_chars_16
:
tokens
[
-
1
][
1
]
+=
char
continue
else
:
status
=
0
elif
status
==
3
:
if
char
in
str_chars
:
tokens
[
-
1
][
1
]
+=
char
continue
else
:
status
=
0
if
status
==
0
:
if
char
in
single_solid_tokens
:
tokens
.
append
(
char
)
elif
char
in
(
'
\'
'
,
'
\"
'
):
tokens
.
append
([
'quoted'
,
char
])
status
=
1
elif
char
in
num_chars_first
:
tokens
.
append
([
'numdigits'
,
char
])
status
=
2
elif
char
in
str_chars_first
:
tokens
.
append
([
'string'
,
char
])
status
=
3
elif
char
not
in
(
'
\n
'
,
' '
,
'
\t
'
):
tokens
.
append
([
'unknown'
,
char
])
assert
status
==
0
index
=
0
def
next_pass
(
index_diff
=
1
)
->
str
:
nonlocal
index
index
+=
index_diff
if
index_diff
==
1
:
return
tokens
[
index
-
1
]
def
next_if
(
tk
:
str
,
index_diff
=
0
)
->
bool
:
nonlocal
index
if
tk
in
spec_tokens
:
return
isinstance
(
tokens
[
index
+
index_diff
],
list
)
and
tokens
[
index
+
index_diff
][
0
]
==
tk
else
:
return
tokens
[
index
+
index_diff
]
==
tk
def
next_if_pass_value
(
tk
:
str
,
default_value
=
None
)
->
Optional
[
str
]:
nonlocal
index
if
tk
in
spec_tokens
:
if
isinstance
(
tokens
[
index
],
list
)
and
tokens
[
index
][
0
]
==
tk
:
index
+=
1
return
tokens
[
index
-
1
][
1
]
else
:
if
tokens
[
index
]
==
tk
:
index
+=
1
return
tk
return
default_value
def
next_expect_pass_value
(
tk
:
str
)
->
str
:
nonlocal
index
if
tk
in
spec_tokens
:
if
not
isinstance
(
tokens
[
index
],
list
)
or
tokens
[
index
][
0
]
!=
tk
:
raise
Exception
(
'aten schema parse error'
)
ret
=
tokens
[
index
][
1
]
else
:
if
tokens
[
index
]
!=
tk
:
raise
Exception
(
'aten schema parse error'
)
ret
=
tk
index
+=
1
return
ret
def
parse_number
():
if
next_if
(
'+'
)
or
next_if
(
'-'
):
value
=
next_pass
()
+
next_expect_pass_value
(
'numdigits'
)
elif
(
get
:
=
next_if_pass_value
(
'numdigits'
))
is
not
None
:
value
=
get
else
:
return
None
if
next_if_pass_value
(
'.'
)
is
not
None
:
value
+=
'.'
if
(
get
:
=
next_if_pass_value
(
'numdigits'
)):
value
+=
get
if
value
[
-
1
]
==
'e'
and
next_if_pass_value
(
'-'
)
is
not
None
:
# only occur in versions < 1.9.0
# 1e-10
value
+=
'-'
+
next_expect_pass_value
(
'numdigits'
)
return
value
def
parse_name
():
name
=
next_expect_pass_value
(
'string'
)
if
next_if_pass_value
(
':'
)
is
not
None
:
next_expect_pass_value
(
':'
)
name
+=
'::'
+
next_expect_pass_value
(
'string'
)
overload_name
=
''
if
next_if_pass_value
(
'.'
)
is
not
None
:
overload_name
=
next_expect_pass_value
(
'string'
)
return
name
,
overload_name
def
parse_list
(
sep
,
end
,
callback
):
ret
=
[]
if
end
is
None
or
not
next_if
(
end
):
ret
.
append
(
callback
())
while
(
get
:
=
next_if_pass_value
(
sep
))
is
not
None
:
ret
.
append
(
get
)
ret
.
append
(
callback
())
if
end
is
not
None
:
ret
.
append
(
next_expect_pass_value
(
end
))
return
ret
def
parse_alias_annotation
():
if
next_if_pass_value
(
'('
)
is
not
None
:
def
parse_inner
():
if
next_if_pass_value
(
'*'
)
is
not
None
:
return
'*'
else
:
return
next_expect_pass_value
(
'string'
)
value
=
'('
.
join
(
parse_list
(
'|'
,
None
,
parse_inner
))
value
+=
next_if_pass_value
(
'!'
,
''
)
if
next_if
(
'-'
)
and
next_if
(
'>'
,
1
):
next_pass
(
2
)
value
+=
'->'
value
+=
''
.
join
(
parse_list
(
'|'
,
None
,
parse_inner
))
return
value
+
next_expect_pass_value
(
')'
)
else
:
return
next_if_pass_value
(
'!'
,
''
)
def
parse_type
():
if
next_if_pass_value
(
'('
)
is
not
None
:
value
=
''
.
join
(
parse_list
(
','
,
')'
,
parse_type
))
else
:
value
=
next_expect_pass_value
(
'string'
)
if
value
==
'__torch__'
:
# only occur in versions < 1.9.0
while
(
get
:
=
next_if_pass_value
(
'.'
))
is
not
None
:
value
+=
get
+
next_expect_pass_value
(
'string'
)
if
next_if_pass_value
(
'('
):
the_types
=
''
.
join
(
parse_list
(
','
,
')'
,
parse_type
))
value
+=
'(%s)'
%
the_types
value
+=
parse_alias_annotation
()
while
True
:
if
next_if
(
'['
)
and
next_if
(
']'
,
1
):
next_pass
(
2
)
value
+=
'[]'
value
+=
parse_alias_annotation
()
elif
next_if_pass_value
(
'?'
)
is
not
None
:
value
+=
'?'
elif
next_if_pass_value
(
'-'
)
is
not
None
:
# only occur in versions < 1.9.0
# t(x -> *)
value
+=
'-'
+
next_expect_pass_value
(
'>'
)
+
next_expect_pass_value
(
'*'
)
break
else
:
break
return
value
def
parse_default_value
():
if
next_if_pass_value
(
'['
)
is
not
None
:
return
parse_list
(
','
,
']'
,
parse_default_value
)
elif
(
get
:
=
parse_number
())
is
not
None
:
return
get
elif
(
get
:
=
next_if_pass_value
(
'quoted'
))
is
not
None
:
return
get
else
:
return
next_expect_pass_value
(
'string'
)
def
parse_argument
():
the_type
=
parse_type
()
if
next_if_pass_value
(
'['
)
is
not
None
:
the_type
+=
'['
+
parse_number
()
+
next_expect_pass_value
(
']'
)
the_type
+=
parse_alias_annotation
()
the_type
+=
next_if_pass_value
(
'?'
,
''
)
name
=
next_expect_pass_value
(
'string'
)
default_value
=
''
if
next_if_pass_value
(
'='
)
is
not
None
:
default_value
=
parse_default_value
()
return
the_type
,
name
,
default_value
def
parse_declaration
():
name
,
overload_name
=
parse_name
()
arguments
=
list
()
kwarg_only
=
False
is_vararg
=
False
next_expect_pass_value
(
'('
)
def
parse_inner
():
nonlocal
kwarg_only
nonlocal
is_vararg
if
is_vararg
:
raise
Exception
(
'"..." must be the last element'
)
elif
next_if_pass_value
(
'*'
)
is
not
None
:
kwarg_only
=
True
elif
next_if_pass_value
(
'.'
)
is
not
None
:
next_expect_pass_value
(
'.'
)
next_expect_pass_value
(
'.'
)
is_vararg
=
True
else
:
arguments
.
append
((
parse_argument
()[
1
],
kwarg_only
))
parse_list
(
','
,
')'
,
parse_inner
)
return
name
,
overload_name
,
arguments
,
is_vararg
positional_num
=
0
keyword_list
=
list
()
special_treat
=
dict
()
# for dtype and memory_format trans now
for
name
,
kwarg_only
in
parse_declaration
()[
2
]:
if
not
kwarg_only
:
key
=
positional_num
positional_num
+=
1
else
:
key
=
name
keyword_list
.
append
(
key
)
if
name
in
special_treat_dict
:
if
key
not
in
special_treat
:
special_treat
[
key
]
=
[
special_treat_dict
[
name
]]
else
:
special_treat
[
key
].
append
(
special_treat_dict
[
name
])
return
positional_num
,
keyword_list
,
special_treat
def
parse_input_value
(
speedup
:
ModelSpeedup
,
input_nodes
:
List
[
torch
.
_C
.
Node
],
positional_num
:
int
,
keyword_list
:
List
[
str
]):
def
parse_input_value
(
speedup
:
ModelSpeedup
,
input_nodes
:
List
[
torch
.
_C
.
Node
],
positional_num
:
int
,
keyword_list
:
List
[
str
]):
"""
"""
translate inputs, to constant positional arguments, constant keyword arguments, and undetermined positions
translate inputs, to constant positional arguments, constant keyword arguments, and undetermined positions
...
@@ -486,7 +748,10 @@ def generate_aten_to_python(func: Callable, node: NodePyGroup, speedup: ModelSpe
...
@@ -486,7 +748,10 @@ def generate_aten_to_python(func: Callable, node: NodePyGroup, speedup: ModelSpe
c_node
=
node
.
key_node
c_node
=
node
.
key_node
schema
=
c_node
.
schema
()
schema
=
c_node
.
schema
()
positional_num
,
keyword_list
,
special_treat
=
parse_aten_schema
(
schema
)
if
torch
.
__version__
<
'1.9.0'
:
positional_num
,
keyword_list
,
special_treat
=
parse_aten_schema_version_1_8_x
(
schema
)
else
:
positional_num
,
keyword_list
,
special_treat
=
parse_aten_schema
(
schema
)
input_nodes
=
list
(
c_node
.
inputs
())
input_nodes
=
list
(
c_node
.
inputs
())
positional
,
keyword
,
undetermined
=
parse_input_value
(
speedup
,
input_nodes
,
positional_num
,
keyword_list
)
positional
,
keyword
,
undetermined
=
parse_input_value
(
speedup
,
input_nodes
,
positional_num
,
keyword_list
)
...
...
test/algo/compression/v2/test_schema_parser.py
0 → 100644
View file @
c53be963
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import
unittest
import
torch
from
nni.compression.pytorch.speedup.jit_translate
import
parse_aten_schema_version_1_8_x
,
schema_fix_dict
,
special_treat_dict
def
parse_aten_schema_origin
(
schema
:
str
):
if
schema
in
schema_fix_dict
:
schema
=
schema_fix_dict
[
schema
]
positional_num
=
0
keyword_list
=
list
()
special_treat
=
dict
()
# for dtype and memory_format trans now
for
arg
in
torch
.
_C
.
parse_schema
(
schema
).
arguments
:
if
torch
.
__version__
<
'1.9.0'
or
not
arg
.
kwarg_only
:
key
=
positional_num
positional_num
+=
1
else
:
key
=
arg
.
name
keyword_list
.
append
(
key
)
if
arg
.
name
in
special_treat_dict
:
if
key
not
in
special_treat
:
special_treat
[
key
]
=
[
special_treat_dict
[
arg
.
name
]]
else
:
special_treat
[
key
].
append
(
special_treat_dict
[
arg
.
name
])
return
positional_num
,
keyword_list
,
special_treat
class
SchemaParserTestCase
(
unittest
.
TestCase
):
def
test_diff_manual_parser
(
self
):
all_schema_list
=
(
str
(
i
)
for
i
in
torch
.
_C
.
_jit_get_all_schemas
())
for
schema
in
all_schema_list
:
if
not
schema
.
startswith
(
'aten::'
):
continue
if
torch
.
__version__
<
'1.9.0'
and
'*,'
in
schema
:
continue
positional_num_origin
,
keyword_list_origin
,
special_treat_origin
=
parse_aten_schema_origin
(
schema
)
positional_num_manual
,
keyword_list_manual
,
special_treat_manual
=
parse_aten_schema_version_1_8_x
(
schema
)
assert
positional_num_origin
==
positional_num_manual
assert
keyword_list_origin
==
keyword_list_manual
assert
special_treat_origin
==
special_treat_manual
if
__name__
==
'__main__'
:
unittest
.
main
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment