Commit aae8381a authored by Mircea Cimpoi, committed by Facebook GitHub Bot

One EMAState in D2go 1/N - model_ema.py --> ema.py

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/436

Renaming `model_ema.py` to `ema.py` (as `modeling` is already in the folder name). Fixing dependencies after the rename.

Reviewed By: wat3rBro

Differential Revision: D41685115

fbshipit-source-id: 006999a020a901ea8be4b71e072d688bd36cdce2
parent 40a6a453
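
For downstream code, the practical effect of this diff is a one-line import change. A minimal before/after sketch, using exactly the symbols that appear in the hunks below:

    # Before this diff:
    #   from d2go.modeling.model_ema import EMAState
    #   from d2go.modeling import model_ema
    # After this diff (same symbols, new module path):
    from d2go.modeling.ema import EMAState
    from d2go.modeling import ema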
@@ -3,7 +3,7 @@ import os
 import detectron2.utils.comm as comm
 import torch
-from d2go.modeling.model_ema import EMAState
+from d2go.modeling.ema import EMAState
 from d2go.quantization.modeling import QATCheckpointer
 from d2go.trainer.fsdp import FSDPWrapper
...
@@ -8,7 +8,7 @@ from d2go.data.build import (
     add_weighted_training_sampler_default_configs,
 )
 from d2go.data.config import add_d2go_data_default_configs
-from d2go.modeling import kmeans_anchors, model_ema
+from d2go.modeling import ema, kmeans_anchors
 from d2go.modeling.backbone.fbnet_cfg import add_fbnet_v2_default_configs
 from d2go.modeling.distillation import add_distillation_configs
 from d2go.modeling.meta_arch.fcos import add_fcos_configs
@@ -38,7 +38,7 @@ def _add_detectron2go_runner_default_cfg(_C: CN) -> None:
     # _C.MODEL.FROZEN_LAYER_REG_EXP
     add_model_freezing_configs(_C)
     # _C.MODEL other models
-    model_ema.add_model_ema_configs(_C)
+    ema.add_model_ema_configs(_C)
     # _C.D2GO_DATA...
     add_d2go_data_default_configs(_C)
     # _C.TENSORBOARD...
...
@@ -25,7 +25,7 @@ from d2go.data.utils import (
 )
 from d2go.distributed import D2GoSharedContext
 from d2go.evaluation.evaluator import inference_on_dataset
-from d2go.modeling import kmeans_anchors, model_ema
+from d2go.modeling import ema, kmeans_anchors
 from d2go.modeling.api import build_d2go_model
 from d2go.modeling.model_freezing_utils import freeze_matched_bn, set_requires_grad
 from d2go.optimizer import build_optimizer_mapper
@@ -265,7 +265,7 @@ class Detectron2GoRunner(D2GoDataAPIMixIn, BaseRunner):
         cfg = cfg.clone()
         model = build_d2go_model(cfg).model
-        model_ema.may_build_model_ema(cfg, model)
+        ema.may_build_model_ema(cfg, model)
         if cfg.QUANTIZATION.QAT.ENABLED:
             # Disable fake_quant and observer so that the model will be trained normally
@@ -300,7 +300,7 @@ class Detectron2GoRunner(D2GoDataAPIMixIn, BaseRunner):
             model.eval()
         if cfg.MODEL_EMA.ENABLED and cfg.MODEL_EMA.USE_EMA_WEIGHTS_FOR_EVAL_ONLY:
-            model_ema.apply_model_ema(model)
+            ema.apply_model_ema(model)
         return model
@@ -318,7 +318,7 @@ class Detectron2GoRunner(D2GoDataAPIMixIn, BaseRunner):
         return model

     def build_checkpointer(self, cfg, model, save_dir, **kwargs):
-        kwargs.update(model_ema.may_get_ema_checkpointer(cfg, model))
+        kwargs.update(ema.may_get_ema_checkpointer(cfg, model))
         checkpointer = FSDPCheckpointer(model, save_dir=save_dir, **kwargs)
         return checkpointer
@@ -443,7 +443,7 @@ class Detectron2GoRunner(D2GoDataAPIMixIn, BaseRunner):
             # model with ema weights
             if cfg.MODEL_EMA.ENABLED and not isinstance(model, PredictorWrapper):
                 logger.info("Run evaluation with EMA.")
-                with model_ema.apply_model_ema_and_restore(model):
+                with ema.apply_model_ema_and_restore(model):
                     cur_results = self._do_test(
                         new_cfg, model, train_iter=train_iter, model_tag="ema"
                     )
@@ -456,7 +456,7 @@ class Detectron2GoRunner(D2GoDataAPIMixIn, BaseRunner):
     ):
         return [
             hooks.IterationTimer(),
-            model_ema.EMAHook(cfg, model) if cfg.MODEL_EMA.ENABLED else None,
+            ema.EMAHook(cfg, model) if cfg.MODEL_EMA.ENABLED else None,
             self._create_data_loader_hook(cfg),
             self._create_after_step_hook(
                 cfg, model, optimizer, scheduler, periodic_checkpointer
...
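
Taken together, the runner hunks above touch the renamed module at four call sites: building the EMA state, registering it with the checkpointer, swapping EMA weights in for evaluation, and updating the state each training step. A condensed sketch of that wiring, based only on the call sites visible in this diff (the `wire_ema` helper itself is hypothetical, not part of d2go):

    from d2go.modeling import ema

    def wire_ema(cfg, model, checkpointer_kwargs):
        # Attach an EMAState to the model when cfg.MODEL_EMA.ENABLED (no-op otherwise).
        ema.may_build_model_ema(cfg, model)
        # Register the EMA state with the checkpointer so it is saved and restored
        # alongside the regular model weights.
        checkpointer_kwargs.update(ema.may_get_ema_checkpointer(cfg, model))
        # Per-iteration EMA update, registered as a trainer hook.
        return ema.EMAHook(cfg, model) if cfg.MODEL_EMA.ENABLED else None

    # At evaluation time the runner temporarily swaps in the EMA weights:
    #   with ema.apply_model_ema_and_restore(model):
    #       ...  # run inference; the original weights are restored on exit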
@@ -8,7 +8,7 @@ import unittest
 import d2go.runner.default_runner as default_runner
 import torch
-from d2go.modeling import model_ema
+from d2go.modeling import ema
 from d2go.utils.testing import helper
@@ -58,7 +58,7 @@ def _compare_state_dict(model1, model2, abs_error=1e-3):
 class TestModelingModelEMA(unittest.TestCase):
     def test_emastate(self):
         model = TestArch()
-        state = model_ema.EMAState.FromModel(model)
+        state = ema.EMAState.FromModel(model)
         # two for conv (conv.weight, conv.bias),
         # five for bn (bn.weight, bn.bias, bn.running_mean, bn.running_var, bn.num_batches_tracked)
         self.assertEqual(len(state.state), 7)
@@ -74,12 +74,12 @@ class TestModelingModelEMA(unittest.TestCase):
     def test_emastate_saveload(self):
         model = TestArch()
-        state = model_ema.EMAState.FromModel(model)
+        state = ema.EMAState.FromModel(model)
         model1 = TestArch()
         self.assertFalse(_compare_state_dict(model, model1))

-        state1 = model_ema.EMAState()
+        state1 = ema.EMAState()
         state1.load_state_dict(state.state_dict())
         state1.apply_to(model1)
         self.assertTrue(_compare_state_dict(model, model1))
@@ -89,7 +89,7 @@ class TestModelingModelEMA(unittest.TestCase):
         model = TestArch()
         model.cuda()
         # state on gpu
-        state = model_ema.EMAState.FromModel(model)
+        state = ema.EMAState.FromModel(model)
         self.assertEqual(state.device, torch.device("cuda:0"))
         # target model on cpu
         model1 = TestArch()
@@ -98,7 +98,7 @@ class TestModelingModelEMA(unittest.TestCase):
         self.assertTrue(_compare_state_dict(copy.deepcopy(model).cpu(), model1))
         # state on cpu
-        state1 = model_ema.EMAState.FromModel(model, device="cpu")
+        state1 = ema.EMAState.FromModel(model, device="cpu")
         self.assertEqual(state1.device, torch.device("cpu"))
         # target model on gpu
         model2 = TestArch()
@@ -109,11 +109,11 @@ class TestModelingModelEMA(unittest.TestCase):
     def test_ema_updater(self):
         model = TestArch()
-        state = model_ema.EMAState()
+        state = ema.EMAState()

         updated_model = TestArch()
-        updater = model_ema.EMAUpdater(state, decay=0.0)
+        updater = ema.EMAUpdater(state, decay=0.0)
         updater.init_state(model)
         for _ in range(3):
             cur = TestArch()
@@ -122,7 +122,7 @@ class TestModelingModelEMA(unittest.TestCase):
             # weight decay == 0.0, always use new model
             self.assertTrue(_compare_state_dict(updated_model, cur))

-        updater = model_ema.EMAUpdater(state, decay=1.0)
+        updater = ema.EMAUpdater(state, decay=1.0)
         updater.init_state(model)
         for _ in range(3):
             cur = TestArch()
@@ -132,9 +132,9 @@ class TestModelingModelEMA(unittest.TestCase):
         self.assertTrue(_compare_state_dict(updated_model, model))

     def test_ema_updater_decay(self):
-        state = model_ema.EMAState()
-        updater = model_ema.EMAUpdater(state, decay=0.7)
+        state = ema.EMAState()
+        updater = ema.EMAUpdater(state, decay=0.7)
         updater.init_state(TestArch(1.0))
         gt_val = 1.0
         gt_val_int = 1
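
For context on the decay test above: assuming the conventional EMA update rule state = decay * state + (1 - decay) * new_value, which is consistent with the earlier tests (decay=0.0 always adopts the new model, decay=1.0 keeps the initial one), a state initialized at 1.0 and updated once with a weight of 2.0 at decay=0.7 yields:

    decay = 0.7
    ema_val = 1.0  # initialized from TestArch(1.0)
    new_val = 2.0  # hypothetical updated model weight
    ema_val = decay * ema_val + (1 - decay) * new_val
    assert abs(ema_val - 1.3) < 1e-9  # 0.7 * 1.0 + 0.3 * 2.0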
@@ -158,17 +158,17 @@ class TestModelingModelEMAHook(unittest.TestCase):
         cfg.MODEL_EMA.DECAY = 0.0
         model = TestArch()
-        model_ema.may_build_model_ema(cfg, model)
+        ema.may_build_model_ema(cfg, model)
         self.assertTrue(hasattr(model, "ema_state"))

-        ema_hook = model_ema.EMAHook(cfg, model)
+        ema_hook = ema.EMAHook(cfg, model)
         ema_hook.before_train()
         ema_hook.before_step()
         model.set_const_weights(2.0)
         ema_hook.after_step()
         ema_hook.after_train()

-        ema_checkpointers = model_ema.may_get_ema_checkpointer(cfg, model)
+        ema_checkpointers = ema.may_get_ema_checkpointer(cfg, model)
         self.assertEqual(len(ema_checkpointers), 1)

         out_model = TestArch()
...