"...git@developer.sourcefind.cn:modelzoo/diffbir_pytorch.git" did not exist on "8e2fc43fd305747cb84128387c579fedd1cb9c70"
Unverified Commit 0850ddd6 authored by Yuge Zhang's avatar Yuge Zhang Committed by GitHub
Browse files

Fix bug when traced kwargs are traced type (#4765)

parent faa6bebe
...@@ -6,7 +6,7 @@ from typing import Any, List, Optional, Tuple, Dict, Iterator ...@@ -6,7 +6,7 @@ from typing import Any, List, Optional, Tuple, Dict, Iterator
import torch.nn as nn import torch.nn as nn
from nni.common.serializer import is_traceable from nni.common.serializer import is_traceable, is_wrapped_with_trace
from nni.retiarii.graph import Cell, Graph, Model, ModelStatus, Node, Evaluator from nni.retiarii.graph import Cell, Graph, Model, ModelStatus, Node, Evaluator
from nni.retiarii.mutator import Mutator from nni.retiarii.mutator import Mutator
from nni.retiarii.serializer import is_basic_unit, is_model_wrapped from nni.retiarii.serializer import is_basic_unit, is_model_wrapped
...@@ -361,7 +361,7 @@ class EvaluatorValueChoiceMutator(Mutator): ...@@ -361,7 +361,7 @@ class EvaluatorValueChoiceMutator(Mutator):
# we only need one such mutator for one model/evaluator # we only need one such mutator for one model/evaluator
def _mutate_traceable_object(self, obj: Any, value_choice_decisions: Dict[str, Any]) -> Any: def _mutate_traceable_object(self, obj: Any, value_choice_decisions: Dict[str, Any]) -> Any:
if not is_traceable(obj): if not _is_traceable_object(obj):
return obj return obj
updates = {} updates = {}
...@@ -400,7 +400,7 @@ class EvaluatorValueChoiceMutator(Mutator): ...@@ -400,7 +400,7 @@ class EvaluatorValueChoiceMutator(Mutator):
def process_evaluator_mutations(evaluator: Evaluator, existing_mutators: List[Mutator]) -> List[Mutator]: def process_evaluator_mutations(evaluator: Evaluator, existing_mutators: List[Mutator]) -> List[Mutator]:
# take all the value choice in the kwargs of evaluaator into a list # take all the value choice in the kwargs of evaluaator into a list
# `existing_mutators` can mutators generated from `model` # `existing_mutators` can mutators generated from `model`
if not is_traceable(evaluator): if not _is_traceable_object(evaluator):
return [] return []
mutator_candidates = {} mutator_candidates = {}
for param in _expand_nested_trace_kwargs(evaluator): for param in _expand_nested_trace_kwargs(evaluator):
...@@ -464,9 +464,12 @@ def _expand_nested_trace_kwargs(obj: Any) -> Iterator[Any]: ...@@ -464,9 +464,12 @@ def _expand_nested_trace_kwargs(obj: Any) -> Iterator[Any]:
# Get items from `trace_kwargs`. # Get items from `trace_kwargs`.
# If some item is traceable itself, get items recursively. # If some item is traceable itself, get items recursively.
if not is_traceable(obj): if _is_traceable_object(obj):
return for param in obj.trace_kwargs.values():
yield param
yield from _expand_nested_trace_kwargs(param)
for param in obj.trace_kwargs.values():
yield param def _is_traceable_object(obj: Any) -> bool:
yield from _expand_nested_trace_kwargs(param) # Is it a traceable "object" (not class)?
return is_traceable(obj) and not is_wrapped_with_trace(obj)
...@@ -1229,6 +1229,11 @@ class Shared(unittest.TestCase): ...@@ -1229,6 +1229,11 @@ class Shared(unittest.TestCase):
assert len(set(values)) == 3 assert len(set(values)) == 3
@unittest.skipIf(pytorch_lightning.__version__ < '1.0', 'Legacy PyTorch-lightning not supported')
def test_valuechoice_classification(self):
evaluator = pl.Classification(criterion=nn.CrossEntropyLoss)
process_evaluator_mutations(evaluator, [])
def test_retiarii_nn_import(self): def test_retiarii_nn_import(self):
dummy = torch.zeros(1, 16, 32, 24) dummy = torch.zeros(1, 16, 32, 24)
nn.init.uniform_(dummy) nn.init.uniform_(dummy)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment