Unverified Commit e7fccfb4 authored by liuzhe-lz's avatar liuzhe-lz Committed by GitHub
Browse files

TF NAS fix: avoid checking member during forward (#2781)


Co-authored-by: default avatarliuzhe <zhliu1@microsoft.com>
parent 5623dbf3
......@@ -136,10 +136,10 @@ class EnasTrainer:
meters = AverageMeterGroup()
for x, y in test_loader:
self.mutator.reset()
logits = self.model(x)
logits = self.model(x, training=False)
if isinstance(logits, tuple):
logits, _ = logits
metrics = self.metrics(logits, y)
metrics = self.metrics(y, logits)
loss = self.loss(y, logits)
metrics['loss'] = tf.reduce_mean(loss).numpy()
meters.update(metrics)
......@@ -151,8 +151,8 @@ class EnasTrainer:
def _create_train_loader(self):
train_set = self.train_set.shuffle(1000000).repeat().batch(self.batch_size)
test_set = self.test_set.shuffle(1000000).repeat().batch(self.batch_size)
test_set = self.valid_set.shuffle(1000000).repeat().batch(self.batch_size)
return iter(train_set), iter(test_set)
def _create_validate_loader(self):
return iter(self.test_set.shuffle(1000000).repeat().batch(self.batch_size))
return iter(self.test_set.shuffle(1000000).batch(self.batch_size))
......@@ -28,20 +28,19 @@ class Mutable(Model):
def __deepcopy__(self, memodict=None):
raise NotImplementedError("Deep copy doesn't work for mutables.")
def __call__(self, *args, **kwargs):
self._check_built()
return super().__call__(*args, **kwargs)
def set_mutator(self, mutator):
if 'mutator' in self.__dict__:
if hasattr(self, 'mutator'):
raise RuntimeError('`set_mutator is called more than once. '
'Did you parse the search space multiple times? '
'Or did you apply multiple fixed architectures?')
self.__dict__['mutator'] = mutator
self.mutator = mutator
def call(self, *inputs):
raise NotImplementedError('Method `call` of Mutable must be overridden')
def build(self, input_shape):
self._check_built()
@property
def key(self):
return self._key
......@@ -68,7 +67,6 @@ class Mutable(Model):
class MutableScope(Mutable):
def __call__(self, *args, **kwargs):
try:
self._check_built()
self.mutator.enter_mutable_scope(self)
return super().__call__(*args, **kwargs)
finally:
......@@ -80,7 +78,7 @@ class LayerChoice(Mutable):
super().__init__(key=key)
self.names = []
if isinstance(op_candidates, OrderedDict):
for name, _ in op_candidates.items():
for name in op_candidates:
assert name not in ["length", "reduction", "return_mask", "_key", "key", "names"], \
"Please don't use a reserved name '{}' for your module.".format(name)
self.names.append(name)
......@@ -94,21 +92,18 @@ class LayerChoice(Mutable):
self.choices = op_candidates
self.reduction = reduction
self.return_mask = return_mask
self._built = False
def call(self, *inputs):
if not self._built:
for op in self.choices:
if len(inputs) > 1: # FIXME: not tested
op.build([inp.shape for inp in inputs])
elif len(inputs) == 1:
op.build(inputs[0].shape)
self._built = True
out, mask = self.mutator.on_forward_layer_choice(self, *inputs)
if self.return_mask:
return out, mask
return out
def build(self, input_shape):
self._check_built()
for op in self.choices:
op.build(input_shape)
def __len__(self):
return len(self.choices)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment