Unverified commit c53be963, authored by Louis-J, committed by GitHub

fix(speedup): re-write aten schema parser to support pytorch versions < 1.9.0 (#5138)

parent 860ad5cf
@@ -11,6 +11,7 @@ if TYPE_CHECKING: # Only imports the below statements during type checking
    from nni.common.graph_utils import NodePyGroup

import re
import string
import logging
from functools import partial, lru_cache
import copy
@@ -394,10 +395,11 @@ schema_fix_dict = {
    # ce=None, bool? pin_memory=None) -> (Tensor"""'
}

@lru_cache(maxsize=256)
@lru_cache
def parse_aten_schema(schema: str):
    """
    Parse the schema into positional_num and keyword_list, and detect whether an argument should be specially treated.
    Only available on pytorch >= v1.9.0.
    """
    if schema in schema_fix_dict:
        schema = schema_fix_dict[schema]
@@ -422,6 +424,266 @@ def parse_aten_schema(schema: str):
    return positional_num, keyword_list, special_treat
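For orientation, an editorial sketch (not part of the patch) of the triple this parser produces. The schema string is a standard aten signature; the assumption that 'dtype' appears in special_treat_dict is based on the "for dtype and memory_format trans now" comment attached to special_treat_dict later in this diff.

schema = 'aten::sum.dim_IntList(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor'
positional_num, keyword_list, special_treat = parse_aten_schema(schema)
# positional_num == 3        -> self, dim and keepdim come before the '*'
# keyword_list == ['dtype']  -> arguments after '*' are keyword-only
# 'dtype' in special_treat   -> dtype values get a translation step before the aten call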
@lru_cache
def parse_aten_schema_version_1_8_x(schema: str):
    """
    Parse the schema into positional_num and keyword_list, and detect whether an argument should be specially treated.
    'torch._C.parse_schema' cannot be used because 'torch._C.Argument' has no 'kwarg_only' attribute on pytorch v1.8.x,
    so the schema is parsed with a lexer/parser-style method, rewritten from torch/csrc/jit/frontend/function_schema_parser.cpp.
    """
    if schema in schema_fix_dict:
        schema = schema_fix_dict[schema]

    single_solid_tokens = [
        '(', ')', '[', ']',
        '+', '-', '!', '>',
        '|', '=', ':', '.', ',',
        '?', '*',
    ]
    # no '>=', '<=', '&', '/'
    # '|' only occurs in 'Tensor(a|b)'
    spec_tokens = [
        'numdigits', 'string', 'quoted', 'unknown',
    ]
    str_chars_first = (*string.ascii_letters, '_')
    str_chars = (*string.ascii_letters, *string.digits, '_')
    num_chars_first = (*string.digits,)
    num_chars_16 = (*string.digits, *string.ascii_lowercase[:6], *string.ascii_uppercase[:6])

    tokens = list()
    # 1: in ('\'', '"'); 2: in num; 3: in str;
    status = 0
    status_esc_char = False
    for char in schema:
        if status == 1:
            if status_esc_char:
                status_esc_char = False
                tokens[-1][1] += char
            elif char == '\\':
                status_esc_char = True
            else:
                tokens[-1][1] += char
                if char == tokens[-1][1][0]:
                    status = 0
            continue
        elif status == 2:
            if char in num_chars_16:
                tokens[-1][1] += char
                continue
            else:
                status = 0
        elif status == 3:
            if char in str_chars:
                tokens[-1][1] += char
                continue
            else:
                status = 0
        if status == 0:
            if char in single_solid_tokens:
                tokens.append(char)
            elif char in ('\'', '\"'):
                tokens.append(['quoted', char])
                status = 1
            elif char in num_chars_first:
                tokens.append(['numdigits', char])
                status = 2
            elif char in str_chars_first:
                tokens.append(['string', char])
                status = 3
            elif char not in ('\n', ' ', '\t'):
                tokens.append(['unknown', char])
    assert status == 0

    index = 0

    def next_pass(index_diff=1) -> str:
        nonlocal index
        index += index_diff
        if index_diff == 1:
            return tokens[index - 1]

    def next_if(tk: str, index_diff=0) -> bool:
        nonlocal index
        if tk in spec_tokens:
            return isinstance(tokens[index + index_diff], list) and tokens[index + index_diff][0] == tk
        else:
            return tokens[index + index_diff] == tk

    def next_if_pass_value(tk: str, default_value=None) -> Optional[str]:
        nonlocal index
        if tk in spec_tokens:
            if isinstance(tokens[index], list) and tokens[index][0] == tk:
                index += 1
                return tokens[index - 1][1]
        else:
            if tokens[index] == tk:
                index += 1
                return tk
        return default_value

    def next_expect_pass_value(tk: str) -> str:
        nonlocal index
        if tk in spec_tokens:
            if not isinstance(tokens[index], list) or tokens[index][0] != tk:
                raise Exception('aten schema parse error')
            ret = tokens[index][1]
        else:
            if tokens[index] != tk:
                raise Exception('aten schema parse error')
            ret = tk
        index += 1
        return ret

    def parse_number():
        if next_if('+') or next_if('-'):
            value = next_pass() + next_expect_pass_value('numdigits')
        elif (get := next_if_pass_value('numdigits')) is not None:
            value = get
        else:
            return None
        if next_if_pass_value('.') is not None:
            value += '.'
            if (get := next_if_pass_value('numdigits')):
                value += get
        if value[-1] == 'e' and next_if_pass_value('-') is not None:
            # only occur in versions < 1.9.0
            # 1e-10
            value += '-' + next_expect_pass_value('numdigits')
        return value

    def parse_name():
        name = next_expect_pass_value('string')
        if next_if_pass_value(':') is not None:
            next_expect_pass_value(':')
            name += '::' + next_expect_pass_value('string')
        overload_name = ''
        if next_if_pass_value('.') is not None:
            overload_name = next_expect_pass_value('string')
        return name, overload_name

    def parse_list(sep, end, callback):
        ret = []
        if end is None or not next_if(end):
            ret.append(callback())
            while (get := next_if_pass_value(sep)) is not None:
                ret.append(get)
                ret.append(callback())
        if end is not None:
            ret.append(next_expect_pass_value(end))
        return ret

    def parse_alias_annotation():
        if next_if_pass_value('(') is not None:
            def parse_inner():
                if next_if_pass_value('*') is not None:
                    return '*'
                else:
                    return next_expect_pass_value('string')
            value = '('.join(parse_list('|', None, parse_inner))
            value += next_if_pass_value('!', '')
            if next_if('-') and next_if('>', 1):
                next_pass(2)
                value += '->'
                value += ''.join(parse_list('|', None, parse_inner))
            return value + next_expect_pass_value(')')
        else:
            return next_if_pass_value('!', '')

    def parse_type():
        if next_if_pass_value('(') is not None:
            value = ''.join(parse_list(',', ')', parse_type))
        else:
            value = next_expect_pass_value('string')
            if value == '__torch__':
                # only occur in versions < 1.9.0
                while (get := next_if_pass_value('.')) is not None:
                    value += get + next_expect_pass_value('string')
            if next_if_pass_value('('):
                the_types = ''.join(parse_list(',', ')', parse_type))
                value += '(%s)' % the_types
        value += parse_alias_annotation()
        while True:
            if next_if('[') and next_if(']', 1):
                next_pass(2)
                value += '[]'
                value += parse_alias_annotation()
            elif next_if_pass_value('?') is not None:
                value += '?'
            elif next_if_pass_value('-') is not None:
                # only occur in versions < 1.9.0
                # t(x -> *)
                value += '-' + next_expect_pass_value('>') + next_expect_pass_value('*')
                break
            else:
                break
        return value

    def parse_default_value():
        if next_if_pass_value('[') is not None:
            return parse_list(',', ']', parse_default_value)
        elif (get := parse_number()) is not None:
            return get
        elif (get := next_if_pass_value('quoted')) is not None:
            return get
        else:
            return next_expect_pass_value('string')

    def parse_argument():
        the_type = parse_type()
        if next_if_pass_value('[') is not None:
            the_type += '[' + parse_number() + next_expect_pass_value(']')
            the_type += parse_alias_annotation()
            the_type += next_if_pass_value('?', '')
        name = next_expect_pass_value('string')
        default_value = ''
        if next_if_pass_value('=') is not None:
            default_value = parse_default_value()
        return the_type, name, default_value

    def parse_declaration():
        name, overload_name = parse_name()
        arguments = list()
        kwarg_only = False
        is_vararg = False
        next_expect_pass_value('(')

        def parse_inner():
            nonlocal kwarg_only
            nonlocal is_vararg
            if is_vararg:
                raise Exception('"..." must be the last element')
            elif next_if_pass_value('*') is not None:
                kwarg_only = True
            elif next_if_pass_value('.') is not None:
                next_expect_pass_value('.')
                next_expect_pass_value('.')
                is_vararg = True
            else:
                arguments.append((parse_argument()[1], kwarg_only))

        parse_list(',', ')', parse_inner)
        return name, overload_name, arguments, is_vararg

    positional_num = 0
    keyword_list = list()
    special_treat = dict()  # for dtype and memory_format trans now

    for name, kwarg_only in parse_declaration()[2]:
        if not kwarg_only:
            key = positional_num
            positional_num += 1
        else:
            key = name
            keyword_list.append(key)

        if name in special_treat_dict:
            if key not in special_treat:
                special_treat[key] = [special_treat_dict[name]]
            else:
                special_treat[key].append(special_treat_dict[name])

    return positional_num, keyword_list, special_treat
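As an illustration of the two stages above (an editorial sketch, not part of the patch): the lexer flattens the schema into plain one-character tokens plus tagged [kind, text] pairs, and the recursive-descent helpers then consume that stream.

# Token stream for 'aten::relu(Tensor self) -> Tensor':
# [['string', 'aten'], ':', ':', ['string', 'relu'], '(',
#  ['string', 'Tensor'], ['string', 'self'], ')', '-', '>', ['string', 'Tensor']]
positional_num, keyword_list, special_treat = parse_aten_schema_version_1_8_x('aten::relu(Tensor self) -> Tensor')
# positional_num == 1, keyword_list == [], special_treat == {}
# (parsing stops after the closing ')' of the argument list; the return type is never needed)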
def parse_input_value(speedup: ModelSpeedup, input_nodes: List[torch._C.Node], positional_num: int, keyword_list: List[str]):
    """
    translate inputs, to constant positional arguments, constant keyword arguments, and undetermined positions
@@ -486,7 +748,10 @@ def generate_aten_to_python(func: Callable, node: NodePyGroup, speedup: ModelSpe
    c_node = node.key_node
    schema = c_node.schema()

    positional_num, keyword_list, special_treat = parse_aten_schema(schema)
    if torch.__version__ < '1.9.0':
        positional_num, keyword_list, special_treat = parse_aten_schema_version_1_8_x(schema)
    else:
        positional_num, keyword_list, special_treat = parse_aten_schema(schema)

    input_nodes = list(c_node.inputs())
    positional, keyword, undetermined = parse_input_value(speedup, input_nodes, positional_num, keyword_list)
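A short usage sketch of the dispatch above (editorial, not part of the patch; at runtime the schema string comes from node.key_node.schema(), the literal below is only an example). Both branches return the same triple for a given schema, which is essentially what the unit test below checks.

if torch.__version__ < '1.9.0':
    parse = parse_aten_schema_version_1_8_x
else:
    parse = parse_aten_schema
positional_num, keyword_list, special_treat = parse('aten::clamp(Tensor self, Scalar? min=None, Scalar? max=None) -> Tensor')
# positional_num == 3, keyword_list == [], special_treat == {}
# (no '*' in this schema, so every argument is positional and none needs special translation)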
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import unittest
import torch

from nni.compression.pytorch.speedup.jit_translate import parse_aten_schema_version_1_8_x, schema_fix_dict, special_treat_dict


def parse_aten_schema_origin(schema: str):
    if schema in schema_fix_dict:
        schema = schema_fix_dict[schema]

    positional_num = 0
    keyword_list = list()
    special_treat = dict()  # for dtype and memory_format trans now

    for arg in torch._C.parse_schema(schema).arguments:
        if torch.__version__ < '1.9.0' or not arg.kwarg_only:
            key = positional_num
            positional_num += 1
        else:
            key = arg.name
            keyword_list.append(key)

        if arg.name in special_treat_dict:
            if key not in special_treat:
                special_treat[key] = [special_treat_dict[arg.name]]
            else:
                special_treat[key].append(special_treat_dict[arg.name])

    return positional_num, keyword_list, special_treat


class SchemaParserTestCase(unittest.TestCase):

    def test_diff_manual_parser(self):
        all_schema_list = (str(i) for i in torch._C._jit_get_all_schemas())
        for schema in all_schema_list:
            if not schema.startswith('aten::'):
                continue
            if torch.__version__ < '1.9.0' and '*,' in schema:
                continue

            positional_num_origin, keyword_list_origin, special_treat_origin = parse_aten_schema_origin(schema)
            positional_num_manual, keyword_list_manual, special_treat_manual = parse_aten_schema_version_1_8_x(schema)

            assert positional_num_origin == positional_num_manual
            assert keyword_list_origin == keyword_list_manual
            assert special_treat_origin == special_treat_manual


if __name__ == '__main__':
    unittest.main()