Unverified commit e7fccfb4, authored by liuzhe-lz, committed by GitHub
Browse files

TF NAS fix: avoid checking member during forward (#2781)


Co-authored-by: liuzhe <zhliu1@microsoft.com>
parent 5623dbf3
...@@ -136,10 +136,10 @@ class EnasTrainer: ...@@ -136,10 +136,10 @@ class EnasTrainer:
meters = AverageMeterGroup() meters = AverageMeterGroup()
for x, y in test_loader: for x, y in test_loader:
self.mutator.reset() self.mutator.reset()
logits = self.model(x) logits = self.model(x, training=False)
if isinstance(logits, tuple): if isinstance(logits, tuple):
logits, _ = logits logits, _ = logits
metrics = self.metrics(logits, y) metrics = self.metrics(y, logits)
loss = self.loss(y, logits) loss = self.loss(y, logits)
metrics['loss'] = tf.reduce_mean(loss).numpy() metrics['loss'] = tf.reduce_mean(loss).numpy()
meters.update(metrics) meters.update(metrics)
...@@ -151,8 +151,8 @@ class EnasTrainer: ...@@ -151,8 +151,8 @@ class EnasTrainer:
def _create_train_loader(self): def _create_train_loader(self):
train_set = self.train_set.shuffle(1000000).repeat().batch(self.batch_size) train_set = self.train_set.shuffle(1000000).repeat().batch(self.batch_size)
test_set = self.test_set.shuffle(1000000).repeat().batch(self.batch_size) test_set = self.valid_set.shuffle(1000000).repeat().batch(self.batch_size)
return iter(train_set), iter(test_set) return iter(train_set), iter(test_set)
def _create_validate_loader(self): def _create_validate_loader(self):
return iter(self.test_set.shuffle(1000000).repeat().batch(self.batch_size)) return iter(self.test_set.shuffle(1000000).batch(self.batch_size))
...@@ -28,20 +28,19 @@ class Mutable(Model): ...@@ -28,20 +28,19 @@ class Mutable(Model):
def __deepcopy__(self, memodict=None): def __deepcopy__(self, memodict=None):
raise NotImplementedError("Deep copy doesn't work for mutables.") raise NotImplementedError("Deep copy doesn't work for mutables.")
def __call__(self, *args, **kwargs):
self._check_built()
return super().__call__(*args, **kwargs)
def set_mutator(self, mutator): def set_mutator(self, mutator):
if 'mutator' in self.__dict__: if hasattr(self, 'mutator'):
raise RuntimeError('`set_mutator is called more than once. ' raise RuntimeError('`set_mutator is called more than once. '
'Did you parse the search space multiple times? ' 'Did you parse the search space multiple times? '
'Or did you apply multiple fixed architectures?') 'Or did you apply multiple fixed architectures?')
self.__dict__['mutator'] = mutator self.mutator = mutator
def call(self, *inputs): def call(self, *inputs):
raise NotImplementedError('Method `call` of Mutable must be overridden') raise NotImplementedError('Method `call` of Mutable must be overridden')
def build(self, input_shape):
self._check_built()
@property @property
def key(self): def key(self):
return self._key return self._key
...@@ -68,7 +67,6 @@ class Mutable(Model): ...@@ -68,7 +67,6 @@ class Mutable(Model):
class MutableScope(Mutable): class MutableScope(Mutable):
def __call__(self, *args, **kwargs): def __call__(self, *args, **kwargs):
try: try:
self._check_built()
self.mutator.enter_mutable_scope(self) self.mutator.enter_mutable_scope(self)
return super().__call__(*args, **kwargs) return super().__call__(*args, **kwargs)
finally: finally:
...@@ -80,7 +78,7 @@ class LayerChoice(Mutable): ...@@ -80,7 +78,7 @@ class LayerChoice(Mutable):
super().__init__(key=key) super().__init__(key=key)
self.names = [] self.names = []
if isinstance(op_candidates, OrderedDict): if isinstance(op_candidates, OrderedDict):
for name, _ in op_candidates.items(): for name in op_candidates:
assert name not in ["length", "reduction", "return_mask", "_key", "key", "names"], \ assert name not in ["length", "reduction", "return_mask", "_key", "key", "names"], \
"Please don't use a reserved name '{}' for your module.".format(name) "Please don't use a reserved name '{}' for your module.".format(name)
self.names.append(name) self.names.append(name)
...@@ -94,21 +92,18 @@ class LayerChoice(Mutable): ...@@ -94,21 +92,18 @@ class LayerChoice(Mutable):
self.choices = op_candidates self.choices = op_candidates
self.reduction = reduction self.reduction = reduction
self.return_mask = return_mask self.return_mask = return_mask
self._built = False
def call(self, *inputs): def call(self, *inputs):
if not self._built:
for op in self.choices:
if len(inputs) > 1: # FIXME: not tested
op.build([inp.shape for inp in inputs])
elif len(inputs) == 1:
op.build(inputs[0].shape)
self._built = True
out, mask = self.mutator.on_forward_layer_choice(self, *inputs) out, mask = self.mutator.on_forward_layer_choice(self, *inputs)
if self.return_mask: if self.return_mask:
return out, mask return out, mask
return out return out
def build(self, input_shape):
self._check_built()
for op in self.choices:
op.build(input_shape)
def __len__(self): def __len__(self):
return len(self.choices) return len(self.choices)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment