"vscode:/vscode.git/clone" did not exist on "a5764016739676a55e5e10ae88e081ebf94d2a38"
Unverified Commit 0fe03cd8 authored by Yuge Zhang's avatar Yuge Zhang Committed by GitHub
Browse files

Make model wrapper work on graph engine (#4017)

parent b463e001
......@@ -602,9 +602,9 @@ class GraphConverter:
elif module.__class__.__module__.startswith('torch.nn') and original_type_name in torch.nn.__dict__:
# this is a basic module from pytorch, no need to parse its graph
m_attrs = get_init_parameters_or_fail(module)
else:
elif getattr(module, '_stop_parsing', False):
# this module is marked as serialize, won't continue to parse
m_attrs = get_init_parameters_or_fail(module, silently=True)
m_attrs = get_init_parameters_or_fail(module)
if m_attrs is not None:
return None, m_attrs
......
......@@ -83,9 +83,10 @@ class Translatable(abc.ABC):
pass
def _create_wrapper_cls(cls, store_init_parameters=True, reset_mutation_uid=False):
def _create_wrapper_cls(cls, store_init_parameters=True, reset_mutation_uid=False, stop_parsing=True):
class wrapper(cls):
def __init__(self, *args, **kwargs):
self._stop_parsing = stop_parsing
if reset_mutation_uid:
reset_uid('mutation')
if store_init_parameters:
......@@ -163,4 +164,4 @@ def model_wrapper(cls):
1. Capture the init parameters of python class so that it can be re-instantiated in another process.
2. Reset uid in `mutation` namespace so that each model counts from zero. Can be useful in unittest and other multi-model scenarios.
"""
return _create_wrapper_cls(cls, reset_mutation_uid=True)
return _create_wrapper_cls(cls, reset_mutation_uid=True, stop_parsing=False)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment