Unverified commit b372abf8, authored by HeekangPark and committed by GitHub
Browse files

Fix Error in nas SPOS trainer, apply_fixed_architecture (#3051)

parent 45e82b3e
...@@ -63,6 +63,7 @@ class SPOSSupernetTrainer(Trainer): ...@@ -63,6 +63,7 @@ class SPOSSupernetTrainer(Trainer):
self.model.train() self.model.train()
meters = AverageMeterGroup() meters = AverageMeterGroup()
for step, (x, y) in enumerate(self.train_loader): for step, (x, y) in enumerate(self.train_loader):
x, y = x.to(self.device), y.to(self.device)
self.optimizer.zero_grad() self.optimizer.zero_grad()
self.mutator.reset() self.mutator.reset()
logits = self.model(x) logits = self.model(x)
...@@ -82,6 +83,7 @@ class SPOSSupernetTrainer(Trainer): ...@@ -82,6 +83,7 @@ class SPOSSupernetTrainer(Trainer):
meters = AverageMeterGroup() meters = AverageMeterGroup()
with torch.no_grad(): with torch.no_grad():
for step, (x, y) in enumerate(self.valid_loader): for step, (x, y) in enumerate(self.valid_loader):
x, y = x.to(self.device), y.to(self.device)
self.mutator.reset() self.mutator.reset()
logits = self.model(x) logits = self.model(x)
loss = self.loss(logits, y) loss = self.loss(logits, y)
......
...@@ -24,11 +24,14 @@ class FixedArchitecture(Mutator): ...@@ -24,11 +24,14 @@ class FixedArchitecture(Mutator):
Preloaded architecture object. Preloaded architecture object.
strict : bool strict : bool
Force everything that appears in ``fixed_arc`` to be used at least once. Force everything that appears in ``fixed_arc`` to be used at least once.
verbose : bool
Print log messages if set to True
""" """
def __init__(self, model, fixed_arc, strict=True): def __init__(self, model, fixed_arc, strict=True, verbose=True):
super().__init__(model) super().__init__(model)
self._fixed_arc = fixed_arc self._fixed_arc = fixed_arc
self.verbose = verbose
mutable_keys = set([mutable.key for mutable in self.mutables if not isinstance(mutable, MutableScope)]) mutable_keys = set([mutable.key for mutable in self.mutables if not isinstance(mutable, MutableScope)])
fixed_arc_keys = set(self._fixed_arc.keys()) fixed_arc_keys = set(self._fixed_arc.keys())
...@@ -99,10 +102,11 @@ class FixedArchitecture(Mutator): ...@@ -99,10 +102,11 @@ class FixedArchitecture(Mutator):
if sum(chosen) == 1 and max(chosen) == 1 and not mutable.return_mask: if sum(chosen) == 1 and max(chosen) == 1 and not mutable.return_mask:
# sum is one, max is one, there has to be an only one # sum is one, max is one, there has to be an only one
# this is compatible with both integer arrays, boolean arrays and float arrays # this is compatible with both integer arrays, boolean arrays and float arrays
_logger.info("Replacing %s with candidate number %d.", global_name, chosen.index(1)) if self.verbose:
_logger.info("Replacing %s with candidate number %d.", global_name, chosen.index(1))
setattr(module, name, mutable[chosen.index(1)]) setattr(module, name, mutable[chosen.index(1)])
else: else:
if mutable.return_mask: if mutable.return_mask and self.verbose:
_logger.info("`return_mask` flag of %s is true. As it relies on the behavior of LayerChoice, " \ _logger.info("`return_mask` flag of %s is true. As it relies on the behavior of LayerChoice, " \
"LayerChoice will not be replaced.") "LayerChoice will not be replaced.")
# remove unused parameters # remove unused parameters
...@@ -113,7 +117,7 @@ class FixedArchitecture(Mutator): ...@@ -113,7 +117,7 @@ class FixedArchitecture(Mutator):
self.replace_layer_choice(mutable, global_name) self.replace_layer_choice(mutable, global_name)
def apply_fixed_architecture(model, fixed_arc): def apply_fixed_architecture(model, fixed_arc, verbose=True):
""" """
Load architecture from `fixed_arc` and apply to model. Load architecture from `fixed_arc` and apply to model.
...@@ -123,6 +127,8 @@ def apply_fixed_architecture(model, fixed_arc): ...@@ -123,6 +127,8 @@ def apply_fixed_architecture(model, fixed_arc):
Model with mutables. Model with mutables.
fixed_arc : str or dict fixed_arc : str or dict
Path to the JSON that stores the architecture, or dict that stores the exported architecture. Path to the JSON that stores the architecture, or dict that stores the exported architecture.
verbose : bool
Print log messages if set to True
Returns Returns
------- -------
...@@ -133,7 +139,7 @@ def apply_fixed_architecture(model, fixed_arc): ...@@ -133,7 +139,7 @@ def apply_fixed_architecture(model, fixed_arc):
if isinstance(fixed_arc, str): if isinstance(fixed_arc, str):
with open(fixed_arc) as f: with open(fixed_arc) as f:
fixed_arc = json.load(f) fixed_arc = json.load(f)
architecture = FixedArchitecture(model, fixed_arc) architecture = FixedArchitecture(model, fixed_arc, verbose)
architecture.reset() architecture.reset()
# for the convenience of parameters counting # for the convenience of parameters counting
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment