Unverified Commit 0fe03cd8 authored by Yuge Zhang's avatar Yuge Zhang Committed by GitHub
Browse files

Make model wrapper work on graph engine (#4017)

parent b463e001
...@@ -602,9 +602,9 @@ class GraphConverter: ...@@ -602,9 +602,9 @@ class GraphConverter:
elif module.__class__.__module__.startswith('torch.nn') and original_type_name in torch.nn.__dict__: elif module.__class__.__module__.startswith('torch.nn') and original_type_name in torch.nn.__dict__:
# this is a basic module from pytorch, no need to parse its graph # this is a basic module from pytorch, no need to parse its graph
m_attrs = get_init_parameters_or_fail(module) m_attrs = get_init_parameters_or_fail(module)
else: elif getattr(module, '_stop_parsing', False):
# this module is marked as serialize, won't continue to parse # this module is marked as serialize, won't continue to parse
m_attrs = get_init_parameters_or_fail(module, silently=True) m_attrs = get_init_parameters_or_fail(module)
if m_attrs is not None: if m_attrs is not None:
return None, m_attrs return None, m_attrs
......
...@@ -83,9 +83,10 @@ class Translatable(abc.ABC): ...@@ -83,9 +83,10 @@ class Translatable(abc.ABC):
pass pass
def _create_wrapper_cls(cls, store_init_parameters=True, reset_mutation_uid=False): def _create_wrapper_cls(cls, store_init_parameters=True, reset_mutation_uid=False, stop_parsing=True):
class wrapper(cls): class wrapper(cls):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
self._stop_parsing = stop_parsing
if reset_mutation_uid: if reset_mutation_uid:
reset_uid('mutation') reset_uid('mutation')
if store_init_parameters: if store_init_parameters:
...@@ -163,4 +164,4 @@ def model_wrapper(cls): ...@@ -163,4 +164,4 @@ def model_wrapper(cls):
1. Capture the init parameters of python class so that it can be re-instantiated in another process. 1. Capture the init parameters of python class so that it can be re-instantiated in another process.
2. Reset uid in `mutation` namespace so that each model counts from zero. Can be useful in unittest and other multi-model scenarios. 2. Reset uid in `mutation` namespace so that each model counts from zero. Can be useful in unittest and other multi-model scenarios.
""" """
return _create_wrapper_cls(cls, reset_mutation_uid=True) return _create_wrapper_cls(cls, reset_mutation_uid=True, stop_parsing=False)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment