"src/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "85cbe589a7c9c9b687b9d8790b84b0119eab9514"
Commit 6940fa9c authored by Yanghan Wang, committed by Facebook GitHub Bot

Allow specifying extra lightning trainer params via `_DEFAULTS_` in yaml

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/461

There is a need to extend trainer parameters that are not in (or conflict with) the base d2go config; this diff adds a way to inject those configs without touching the base d2go config.
- In `get_trainer_params`, simply check for a `LIGHTNING_TRAINER` field and use whatever configs are under it.
- Adds `GeneralizedRCNNTaskNoDefaultConfig`, which allows specifying the default config for `GeneralizedRCNNTask` via a yaml file (along with some prerequisite changes).
- (next diff) Users can add their own config updater by registering it in `CONFIG_UPDATER_REGISTRY`; a sketch of this is shown below.
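A minimal sketch, assuming the registry pattern shown in the diff below, of how a project could register its own config updater that injects `LIGHTNING_TRAINER` defaults; the registry key, function name, and default values are hypothetical:

```python
# Hypothetical sketch: a project-specific config updater that adds LIGHTNING_TRAINER
# defaults, mirroring the CONFIG_UPDATER_REGISTRY pattern introduced in this diff.
from d2go.registry.builtin import CONFIG_UPDATER_REGISTRY
from detectron2.config import CfgNode as CN


@CONFIG_UPDATER_REGISTRY.register("MyLightningTask")  # hypothetical registry key
def get_my_lightning_task_default_cfg(cfg: CN) -> CN:
    assert len(cfg) == 0, f"start from scratch, but previous cfg is non-empty: {cfg}"
    # A real updater would first build the full d2go default config (e.g. by reusing one
    # of the registered runner defaults); only the part relevant to this diff is shown.
    cfg.LIGHTNING_TRAINER = CN()
    cfg.LIGHTNING_TRAINER.RELOAD_DATALOADERS_EVERY_N_EPOCHS = 0  # illustrative values
    cfg.LIGHTNING_TRAINER.SYNC_BATCHNORM = False
    cfg.LIGHTNING_TRAINER.BENCHMARK = False
    return cfg
```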

Differential Revision: D42928992

fbshipit-source-id: f2a1d8a3f2bec9908bb1af03928611d963b92c0e
parent eb184a78
@@ -15,6 +15,7 @@ from d2go.modeling.meta_arch.fcos import add_fcos_configs
 from d2go.modeling.model_freezing_utils import add_model_freezing_configs
 from d2go.modeling.subclass import add_subclass_configs
 from d2go.quantization.modeling import add_quantization_default_configs
+from d2go.registry.builtin import CONFIG_UPDATER_REGISTRY
 from d2go.trainer.fsdp import add_fsdp_configs
 from d2go.utils.visualization import add_tensorboard_default_configs
 from detectron2.config import get_cfg as get_d2_cfg
@@ -122,6 +123,7 @@ def _add_rcnn_default_config(_C: CN) -> None:
     _C.register_deprecated_key("RCNN_PREPARE_FOR_QUANT_CONVERT")
+@CONFIG_UPDATER_REGISTRY.register("BaseRunner")
 def get_base_runner_default_cfg(cfg: CN) -> CN:
     assert len(cfg) == 0, f"start from scratch, but previous cfg is non-empty: {cfg}"
@@ -141,6 +143,7 @@ def get_base_runner_default_cfg(cfg: CN) -> CN:
     return cfg
+@CONFIG_UPDATER_REGISTRY.register("Detectron2GoRunner")
 def get_detectron2go_runner_default_cfg(cfg: CN) -> CN:
     assert len(cfg) == 0, f"start from scratch, but previous cfg is non-empty: {cfg}"
@@ -150,6 +153,7 @@ def get_detectron2go_runner_default_cfg(cfg: CN) -> CN:
     return cfg
+@CONFIG_UPDATER_REGISTRY.register("GeneralizedRCNNRunner")
 def get_generalized_rcnn_runner_default_cfg(cfg: CN) -> CN:
     assert len(cfg) == 0, f"start from scratch, but previous cfg is non-empty: {cfg}"
...
@@ -15,6 +15,7 @@ from d2go.data.utils import update_cfg_if_using_adhoc_dataset
 from d2go.modeling.api import build_meta_arch
 from d2go.modeling.model_freezing_utils import set_requires_grad
 from d2go.optimizer import build_optimizer_mapper
+from d2go.runner.api import RunnerV2Mixin
 from d2go.runner.callbacks.quantization import maybe_prepare_for_quantization, PREPARED
 from d2go.runner.default_runner import (
     _get_tbx_writer,
@@ -472,3 +473,12 @@ class GeneralizedRCNNTask(DefaultTask):
     @classmethod
     def get_default_cfg(cls):
         return GeneralizedRCNNRunner.get_default_cfg()
+
+
+# TODO(T123654122): subclass of DefaultTask will be refactored
+class GeneralizedRCNNTaskNoDefaultConfig(RunnerV2Mixin, DefaultTask):
+    """
+    Similar to `GeneralizedRCNNTask` but allowing specifying default config in yaml via `_defaults_`
+    """
+
+    pass
@@ -281,8 +281,8 @@ def setup_after_launch(
     # save the diff config
     default_cfg = (
-        runner.get_default_cfg()
-        if runner and not isinstance(runner, RunnerV2Mixin)
+        runner_class.get_default_cfg()
+        if runner_class and not issubclass(runner_class, RunnerV2Mixin)
         else cfg.get_default_cfg()
     )
     dump_cfg(
...
@@ -95,6 +95,19 @@ def get_trainer_params(cfg: CfgNode) -> Dict[str, Any]:
     params["gradient_clip_val"] = cfg.SOLVER.CLIP_GRADIENTS.CLIP_VALUE
     params["gradient_clip_algorithm"] = cfg.SOLVER.CLIP_GRADIENTS.CLIP_TYPE
+
+    # Allow specifying additional trainer parameters under the `LIGHTNING_TRAINER` field.
+    # Please note that:
+    # - `LIGHTNING_TRAINER` is not part of the "base" config; users need to add it to their default config via `_DEFAULTS_` or `get_default_cfg`.
+    # - this is a temporary solution pending a future refactor of the config system.
+    if hasattr(cfg, "LIGHTNING_TRAINER"):
+        params.update(
+            {
+                "reload_dataloaders_every_n_epochs": cfg.LIGHTNING_TRAINER.RELOAD_DATALOADERS_EVERY_N_EPOCHS,
+                "sync_batchnorm": cfg.LIGHTNING_TRAINER.SYNC_BATCHNORM,
+                "benchmark": cfg.LIGHTNING_TRAINER.BENCHMARK,
+            }
+        )
+
     return params
...
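For completeness, a minimal sketch of the other route the comment above mentions for getting `LIGHTNING_TRAINER` into the config: overriding `get_default_cfg` on a task subclass. The subclass name, the values, and the `d2go.runner.lightning_task` import path are assumptions; with `GeneralizedRCNNTaskNoDefaultConfig`, the same keys would instead come from a yaml file via `_DEFAULTS_`.

```python
# Hypothetical sketch: extend a task's default config with LIGHTNING_TRAINER so that
# get_trainer_params (above) forwards these extra parameters to the lightning Trainer.
from detectron2.config import CfgNode as CN

from d2go.runner.lightning_task import GeneralizedRCNNTask  # import path is an assumption


class MyRCNNTask(GeneralizedRCNNTask):  # hypothetical subclass
    @classmethod
    def get_default_cfg(cls):
        cfg = super().get_default_cfg()
        cfg.defrost()  # ensure new keys can be added
        cfg.LIGHTNING_TRAINER = CN()
        cfg.LIGHTNING_TRAINER.RELOAD_DATALOADERS_EVERY_N_EPOCHS = 1  # illustrative values
        cfg.LIGHTNING_TRAINER.SYNC_BATCHNORM = True
        cfg.LIGHTNING_TRAINER.BENCHMARK = True
        return cfg
```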