Unverified Commit b168b016 authored by Yuge Zhang's avatar Yuge Zhang Committed by GitHub
Browse files

Fix visualization for ENAS micro (#2813)

parent 8c6a6407
......@@ -162,7 +162,7 @@ class Mutator(BaseMutator):
if self._connect_all:
return self._all_connect_tensor_reduction(mutable.reduction,
[op(*args, **kwargs) for op in mutable]), \
torch.ones(len(mutable))
torch.ones(len(mutable)).bool()
def _map_fn(op, args, kwargs):
return op(*args, **kwargs)
......@@ -192,7 +192,7 @@ class Mutator(BaseMutator):
"""
if self._connect_all:
return self._all_connect_tensor_reduction(mutable.reduction, tensor_list), \
torch.ones(mutable.n_candidates)
torch.ones(mutable.n_candidates).bool()
mask = self._get_decision(mutable)
assert len(mask) == mutable.n_candidates, \
"Invalid mask, expected {} to be of length {}.".format(mask, mutable.n_candidates)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment