Unverified Commit f42e53d2 authored by Yuge Zhang's avatar Yuge Zhang Committed by GitHub
Browse files

Replace layer choice with selected module after applied fixed architecture (#2420)

parent ef90e7d2
......@@ -176,7 +176,7 @@ For example,
}
```
After applying, the model is then fixed and ready for final training. The model works as a single model, although it might contain more parameters than expected. This comes with pros and cons. The good side is, you can directly load the checkpoint dumped from supernet during the search phase and start retraining from there. However, this is also a model with redundant parameters and this may cause problems when trying to count the number of parameters in the model. For deeper reasons and possible workarounds, see [Trainers](./NasReference.md).
After applying, the model is then fixed and ready for final training. The model works as a single model, and unused parameters and modules are pruned.
Also, refer to [DARTS](./DARTS.md) for code exemplifying retraining.
......
......@@ -2,12 +2,16 @@
# Licensed under the MIT license.
import json
import logging
from .mutables import InputChoice, LayerChoice, MutableScope
from .mutator import Mutator
from .utils import to_list
_logger = logging.getLogger(__name__)
class FixedArchitecture(Mutator):
"""
Fixed architecture mutator that always selects a certain graph.
......@@ -73,6 +77,41 @@ class FixedArchitecture(Mutator):
"""
return self._fixed_arc
def replace_layer_choice(self, module=None, prefix=""):
    """
    Replace layer choices with selected candidates. It's done with best effort.

    If exactly one candidate is selected (one-hot choice) and ``return_mask`` is off,
    the whole ``LayerChoice`` module is replaced with the chosen candidate module.
    Otherwise (weighted choices, multiple choices, or ``return_mask`` enabled), the
    ``LayerChoice`` is kept, but candidates whose selection value is an exact
    integer/boolean zero are deleted to prune unused parameters.

    Parameters
    ----------
    module : nn.Module
        Module to be processed. Defaults to the mutator's model when ``None``.
    prefix : str
        Module name under the global namespace, used for logging.
    """
    if module is None:
        module = self.model
    for name, mutable in module.named_children():
        global_name = (prefix + "." if prefix else "") + name
        if isinstance(mutable, LayerChoice):
            chosen = self._fixed_arc[mutable.key]
            if sum(chosen) == 1 and max(chosen) == 1 and not mutable.return_mask:
                # sum is one, max is one, there has to be exactly one selected.
                # This test is compatible with integer, boolean and float arrays.
                _logger.info("Replacing %s with candidate number %d.", global_name, chosen.index(1))
                setattr(module, name, mutable[chosen.index(1)])
            else:
                if mutable.return_mask:
                    # BUG FIX: the original call had a %s placeholder but never
                    # passed global_name, producing a logging format error.
                    _logger.info("`return_mask` flag of %s is true. As it relies on the behavior of LayerChoice, "
                                 "LayerChoice will not be replaced.", global_name)
                # Remove unused candidates to free their parameters. Float zeros are
                # deliberately kept: an exact 0.0 weight may still be meaningful for
                # weighted (differentiable) architectures.
                for ch, n in zip(chosen, mutable.names):
                    if ch == 0 and not isinstance(ch, float):
                        setattr(mutable, n, None)
        else:
            # Recurse into ordinary submodules to find nested LayerChoices.
            self.replace_layer_choice(mutable, global_name)
def apply_fixed_architecture(model, fixed_arc):
"""
......@@ -96,4 +135,7 @@ def apply_fixed_architecture(model, fixed_arc):
fixed_arc = json.load(f)
architecture = FixedArchitecture(model, fixed_arc)
architecture.reset()
# for the convenience of parameters counting
architecture.replace_layer_choice()
return architecture
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment