Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
0850ddd6
"vscode:/vscode.git/clone" did not exist on "962d9aee04fbf7af04047d917266b1b40cc6a31a"
Unverified
Commit
0850ddd6
authored
Apr 14, 2022
by
Yuge Zhang
Committed by
GitHub
Apr 14, 2022
Browse files
Fix bug when traced kwargs are traced type (#4765)
parent
faa6bebe
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
16 additions
and
8 deletions
+16
-8
nni/retiarii/nn/pytorch/mutator.py
nni/retiarii/nn/pytorch/mutator.py
+11
-8
test/ut/retiarii/test_highlevel_apis.py
test/ut/retiarii/test_highlevel_apis.py
+5
-0
No files found.
nni/retiarii/nn/pytorch/mutator.py
View file @
0850ddd6
...
...
@@ -6,7 +6,7 @@ from typing import Any, List, Optional, Tuple, Dict, Iterator
import
torch.nn
as
nn
from
nni.common.serializer
import
is_traceable
from
nni.common.serializer
import
is_traceable
,
is_wrapped_with_trace
from
nni.retiarii.graph
import
Cell
,
Graph
,
Model
,
ModelStatus
,
Node
,
Evaluator
from
nni.retiarii.mutator
import
Mutator
from
nni.retiarii.serializer
import
is_basic_unit
,
is_model_wrapped
...
...
@@ -361,7 +361,7 @@ class EvaluatorValueChoiceMutator(Mutator):
# we only need one such mutator for one model/evaluator
def
_mutate_traceable_object
(
self
,
obj
:
Any
,
value_choice_decisions
:
Dict
[
str
,
Any
])
->
Any
:
if
not
is_traceable
(
obj
):
if
not
_
is_traceable
_object
(
obj
):
return
obj
updates
=
{}
...
...
@@ -400,7 +400,7 @@ class EvaluatorValueChoiceMutator(Mutator):
def
process_evaluator_mutations
(
evaluator
:
Evaluator
,
existing_mutators
:
List
[
Mutator
])
->
List
[
Mutator
]:
# take all the value choice in the kwargs of evaluaator into a list
# `existing_mutators` can mutators generated from `model`
if
not
is_traceable
(
evaluator
):
if
not
_
is_traceable
_object
(
evaluator
):
return
[]
mutator_candidates
=
{}
for
param
in
_expand_nested_trace_kwargs
(
evaluator
):
...
...
@@ -464,9 +464,12 @@ def _expand_nested_trace_kwargs(obj: Any) -> Iterator[Any]:
# Get items from `trace_kwargs`.
# If some item is traceable itself, get items recursively.
if
not
is_traceable
(
obj
):
return
if
_is_traceable_object
(
obj
):
for
param
in
obj
.
trace_kwargs
.
values
():
yield
param
yield
from
_expand_nested_trace_kwargs
(
param
)
def
_is_traceable_object
(
obj
:
Any
)
->
bool
:
# Is it a traceable "object" (not class)?
return
is_traceable
(
obj
)
and
not
is_wrapped_with_trace
(
obj
)
test/ut/retiarii/test_highlevel_apis.py
View file @
0850ddd6
...
...
@@ -1229,6 +1229,11 @@ class Shared(unittest.TestCase):
assert
len
(
set
(
values
))
==
3
@
unittest
.
skipIf
(
pytorch_lightning
.
__version__
<
'1.0'
,
'Legacy PyTorch-lightning not supported'
)
def
test_valuechoice_classification
(
self
):
evaluator
=
pl
.
Classification
(
criterion
=
nn
.
CrossEntropyLoss
)
process_evaluator_mutations
(
evaluator
,
[])
def
test_retiarii_nn_import
(
self
):
dummy
=
torch
.
zeros
(
1
,
16
,
32
,
24
)
nn
.
init
.
uniform_
(
dummy
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment