Unverified commit b372abf8, authored by HeekangPark and committed by GitHub
Browse files

Fix error in NAS SPOS trainer and apply_fixed_architecture (#3051)

parent 45e82b3e
......@@ -63,6 +63,7 @@ class SPOSSupernetTrainer(Trainer):
self.model.train()
meters = AverageMeterGroup()
for step, (x, y) in enumerate(self.train_loader):
x, y = x.to(self.device), y.to(self.device)
self.optimizer.zero_grad()
self.mutator.reset()
logits = self.model(x)
......@@ -82,6 +83,7 @@ class SPOSSupernetTrainer(Trainer):
meters = AverageMeterGroup()
with torch.no_grad():
for step, (x, y) in enumerate(self.valid_loader):
x, y = x.to(self.device), y.to(self.device)
self.mutator.reset()
logits = self.model(x)
loss = self.loss(logits, y)
......
......@@ -24,11 +24,14 @@ class FixedArchitecture(Mutator):
Preloaded architecture object.
strict : bool
Force everything that appears in ``fixed_arc`` to be used at least once.
verbose : bool
Print log messages if set to True
"""
def __init__(self, model, fixed_arc, strict=True):
def __init__(self, model, fixed_arc, strict=True, verbose=True):
super().__init__(model)
self._fixed_arc = fixed_arc
self.verbose = verbose
mutable_keys = set([mutable.key for mutable in self.mutables if not isinstance(mutable, MutableScope)])
fixed_arc_keys = set(self._fixed_arc.keys())
......@@ -99,10 +102,11 @@ class FixedArchitecture(Mutator):
if sum(chosen) == 1 and max(chosen) == 1 and not mutable.return_mask:
# sum is one, max is one, there has to be an only one
# this is compatible with both integer arrays, boolean arrays and float arrays
_logger.info("Replacing %s with candidate number %d.", global_name, chosen.index(1))
if self.verbose:
_logger.info("Replacing %s with candidate number %d.", global_name, chosen.index(1))
setattr(module, name, mutable[chosen.index(1)])
else:
if mutable.return_mask:
if mutable.return_mask and self.verbose:
_logger.info("`return_mask` flag of %s is true. As it relies on the behavior of LayerChoice, " \
"LayerChoice will not be replaced.")
# remove unused parameters
......@@ -113,7 +117,7 @@ class FixedArchitecture(Mutator):
self.replace_layer_choice(mutable, global_name)
def apply_fixed_architecture(model, fixed_arc):
def apply_fixed_architecture(model, fixed_arc, verbose=True):
"""
Load architecture from `fixed_arc` and apply to model.
......@@ -123,6 +127,8 @@ def apply_fixed_architecture(model, fixed_arc):
Model with mutables.
fixed_arc : str or dict
Path to the JSON that stores the architecture, or dict that stores the exported architecture.
verbose : bool
Print log messages if set to True
Returns
-------
......@@ -133,7 +139,7 @@ def apply_fixed_architecture(model, fixed_arc):
if isinstance(fixed_arc, str):
with open(fixed_arc) as f:
fixed_arc = json.load(f)
architecture = FixedArchitecture(model, fixed_arc)
architecture = FixedArchitecture(model, fixed_arc, verbose)
architecture.reset()
# for the convenience of parameters counting
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment