"vscode:/vscode.git/clone" did not exist on "962d9aee04fbf7af04047d917266b1b40cc6a31a"
Unverified Commit 0850ddd6 authored by Yuge Zhang's avatar Yuge Zhang Committed by GitHub
Browse files

Fix bug when traced kwargs are traced type (#4765)

parent faa6bebe
......@@ -6,7 +6,7 @@ from typing import Any, List, Optional, Tuple, Dict, Iterator
import torch.nn as nn
from nni.common.serializer import is_traceable
from nni.common.serializer import is_traceable, is_wrapped_with_trace
from nni.retiarii.graph import Cell, Graph, Model, ModelStatus, Node, Evaluator
from nni.retiarii.mutator import Mutator
from nni.retiarii.serializer import is_basic_unit, is_model_wrapped
......@@ -361,7 +361,7 @@ class EvaluatorValueChoiceMutator(Mutator):
# we only need one such mutator for one model/evaluator
def _mutate_traceable_object(self, obj: Any, value_choice_decisions: Dict[str, Any]) -> Any:
if not is_traceable(obj):
if not _is_traceable_object(obj):
return obj
updates = {}
......@@ -400,7 +400,7 @@ class EvaluatorValueChoiceMutator(Mutator):
def process_evaluator_mutations(evaluator: Evaluator, existing_mutators: List[Mutator]) -> List[Mutator]:
# take all the value choice in the kwargs of evaluaator into a list
# `existing_mutators` can mutators generated from `model`
if not is_traceable(evaluator):
if not _is_traceable_object(evaluator):
return []
mutator_candidates = {}
for param in _expand_nested_trace_kwargs(evaluator):
......@@ -464,9 +464,12 @@ def _expand_nested_trace_kwargs(obj: Any) -> Iterator[Any]:
# Get items from `trace_kwargs`.
# If some item is traceable itself, get items recursively.
if not is_traceable(obj):
return
if _is_traceable_object(obj):
for param in obj.trace_kwargs.values():
yield param
yield from _expand_nested_trace_kwargs(param)
def _is_traceable_object(obj: Any) -> bool:
# Is it a traceable "object" (not class)?
return is_traceable(obj) and not is_wrapped_with_trace(obj)
......@@ -1229,6 +1229,11 @@ class Shared(unittest.TestCase):
assert len(set(values)) == 3
@unittest.skipIf(pytorch_lightning.__version__ < '1.0', 'Legacy PyTorch-lightning not supported')
def test_valuechoice_classification(self):
evaluator = pl.Classification(criterion=nn.CrossEntropyLoss)
process_evaluator_mutations(evaluator, [])
def test_retiarii_nn_import(self):
dummy = torch.zeros(1, 16, 32, 24)
nn.init.uniform_(dummy)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment