Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
0fe03cd8
Unverified
Commit
0fe03cd8
authored
Aug 05, 2021
by
Yuge Zhang
Committed by
GitHub
Aug 05, 2021
Browse files
Make model wrapper work on graph engine (#4017)
parent
b463e001
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
5 additions
and
4 deletions
+5
-4
nni/retiarii/converter/graph_gen.py
nni/retiarii/converter/graph_gen.py
+2
-2
nni/retiarii/serializer.py
nni/retiarii/serializer.py
+3
-2
No files found.
nni/retiarii/converter/graph_gen.py
View file @
0fe03cd8
...
@@ -602,9 +602,9 @@ class GraphConverter:
...
@@ -602,9 +602,9 @@ class GraphConverter:
elif
module
.
__class__
.
__module__
.
startswith
(
'torch.nn'
)
and
original_type_name
in
torch
.
nn
.
__dict__
:
elif
module
.
__class__
.
__module__
.
startswith
(
'torch.nn'
)
and
original_type_name
in
torch
.
nn
.
__dict__
:
# this is a basic module from pytorch, no need to parse its graph
# this is a basic module from pytorch, no need to parse its graph
m_attrs
=
get_init_parameters_or_fail
(
module
)
m_attrs
=
get_init_parameters_or_fail
(
module
)
else
:
el
if
getattr
(
module
,
'_stop_parsing'
,
Fal
se
)
:
# this module is marked as serialize, won't continue to parse
# this module is marked as serialize, won't continue to parse
m_attrs
=
get_init_parameters_or_fail
(
module
,
silently
=
True
)
m_attrs
=
get_init_parameters_or_fail
(
module
)
if
m_attrs
is
not
None
:
if
m_attrs
is
not
None
:
return
None
,
m_attrs
return
None
,
m_attrs
...
...
nni/retiarii/serializer.py
View file @
0fe03cd8
...
@@ -83,9 +83,10 @@ class Translatable(abc.ABC):
...
@@ -83,9 +83,10 @@ class Translatable(abc.ABC):
pass
pass
def
_create_wrapper_cls
(
cls
,
store_init_parameters
=
True
,
reset_mutation_uid
=
False
):
def
_create_wrapper_cls
(
cls
,
store_init_parameters
=
True
,
reset_mutation_uid
=
False
,
stop_parsing
=
True
):
class
wrapper
(
cls
):
class
wrapper
(
cls
):
def
__init__
(
self
,
*
args
,
**
kwargs
):
def
__init__
(
self
,
*
args
,
**
kwargs
):
self
.
_stop_parsing
=
stop_parsing
if
reset_mutation_uid
:
if
reset_mutation_uid
:
reset_uid
(
'mutation'
)
reset_uid
(
'mutation'
)
if
store_init_parameters
:
if
store_init_parameters
:
...
@@ -163,4 +164,4 @@ def model_wrapper(cls):
...
@@ -163,4 +164,4 @@ def model_wrapper(cls):
1. Capture the init parameters of python class so that it can be re-instantiated in another process.
1. Capture the init parameters of python class so that it can be re-instantiated in another process.
2. Reset uid in `mutation` namespace so that each model counts from zero. Can be useful in unittest and other multi-model scenarios.
2. Reset uid in `mutation` namespace so that each model counts from zero. Can be useful in unittest and other multi-model scenarios.
"""
"""
return
_create_wrapper_cls
(
cls
,
reset_mutation_uid
=
True
)
return
_create_wrapper_cls
(
cls
,
reset_mutation_uid
=
True
,
stop_parsing
=
False
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment