Commit 39f08a33 authored by Tsahi Glik's avatar Tsahi Glik Committed by Facebook GitHub Bot
Browse files

Improve supernet training

Summary:
Add support for applying the "sandwich rule" in supernet training, which uses the max and min subnets in each training step.

Also fix some issues with config logging when using SearchSpace.

Reviewed By: zhanghang1989

Differential Revision: D33676025

fbshipit-source-id: 753a1509bf592e7470ada360815447a3f52d06c7
parent 5a068943
#!/usr/bin/env python3 #!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import contextlib import contextlib
import copy
import logging import logging
from typing import List from typing import List
...@@ -57,8 +57,9 @@ class CfgNode(_CfgNode): ...@@ -57,8 +57,9 @@ class CfgNode(_CfgNode):
return res return res
def dump(self, *args, **kwargs): def dump(self, *args, **kwargs):
self._run_custom_processing(is_dump=True) cfg = copy.deepcopy(self)
return super().dump(*args, **kwargs) cfg._run_custom_processing(is_dump=True)
return super(CfgNode, cfg).dump(*args, **kwargs)
@staticmethod @staticmethod
def load_yaml_with_base(filename: str, *args, **kwargs): def load_yaml_with_base(filename: str, *args, **kwargs):
......
...@@ -170,7 +170,7 @@ class DefaultTask(pl.LightningModule): ...@@ -170,7 +170,7 @@ class DefaultTask(pl.LightningModule):
def training_step(self, batch, batch_idx): def training_step(self, batch, batch_idx):
if hasattr(self.model, "training_step"): if hasattr(self.model, "training_step"):
self._meta_arch_training_step(batch, batch_idx) return self._meta_arch_training_step(batch, batch_idx)
return self._standard_training_step(batch, batch_idx) return self._standard_training_step(batch, batch_idx)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment