Commit aae8381a authored by Mircea Cimpoi's avatar Mircea Cimpoi Committed by Facebook GitHub Bot
Browse files

One EMAState in D2go 1/N - model_ema.py --> ema.py

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/436

Renaming `model_ema.py` to `ema.py` (as `modeling` is already in the folder name). Fixing dependencies after the rename.

Reviewed By: wat3rBro

Differential Revision: D41685115

fbshipit-source-id: 006999a020a901ea8be4b71e072d688bd36cdce2
parent 40a6a453
......@@ -3,7 +3,7 @@ import os
import detectron2.utils.comm as comm
import torch
from d2go.modeling.model_ema import EMAState
from d2go.modeling.ema import EMAState
from d2go.quantization.modeling import QATCheckpointer
from d2go.trainer.fsdp import FSDPWrapper
......
......@@ -8,7 +8,7 @@ from d2go.data.build import (
add_weighted_training_sampler_default_configs,
)
from d2go.data.config import add_d2go_data_default_configs
from d2go.modeling import kmeans_anchors, model_ema
from d2go.modeling import ema, kmeans_anchors
from d2go.modeling.backbone.fbnet_cfg import add_fbnet_v2_default_configs
from d2go.modeling.distillation import add_distillation_configs
from d2go.modeling.meta_arch.fcos import add_fcos_configs
......@@ -38,7 +38,7 @@ def _add_detectron2go_runner_default_cfg(_C: CN) -> None:
# _C.MODEL.FROZEN_LAYER_REG_EXP
add_model_freezing_configs(_C)
# _C.MODEL other models
model_ema.add_model_ema_configs(_C)
ema.add_model_ema_configs(_C)
# _C.D2GO_DATA...
add_d2go_data_default_configs(_C)
# _C.TENSORBOARD...
......
......@@ -25,7 +25,7 @@ from d2go.data.utils import (
)
from d2go.distributed import D2GoSharedContext
from d2go.evaluation.evaluator import inference_on_dataset
from d2go.modeling import kmeans_anchors, model_ema
from d2go.modeling import ema, kmeans_anchors
from d2go.modeling.api import build_d2go_model
from d2go.modeling.model_freezing_utils import freeze_matched_bn, set_requires_grad
from d2go.optimizer import build_optimizer_mapper
......@@ -265,7 +265,7 @@ class Detectron2GoRunner(D2GoDataAPIMixIn, BaseRunner):
cfg = cfg.clone()
model = build_d2go_model(cfg).model
model_ema.may_build_model_ema(cfg, model)
ema.may_build_model_ema(cfg, model)
if cfg.QUANTIZATION.QAT.ENABLED:
# Disable fake_quant and observer so that the model will be trained normally
......@@ -300,7 +300,7 @@ class Detectron2GoRunner(D2GoDataAPIMixIn, BaseRunner):
model.eval()
if cfg.MODEL_EMA.ENABLED and cfg.MODEL_EMA.USE_EMA_WEIGHTS_FOR_EVAL_ONLY:
model_ema.apply_model_ema(model)
ema.apply_model_ema(model)
return model
......@@ -318,7 +318,7 @@ class Detectron2GoRunner(D2GoDataAPIMixIn, BaseRunner):
return model
def build_checkpointer(self, cfg, model, save_dir, **kwargs):
kwargs.update(model_ema.may_get_ema_checkpointer(cfg, model))
kwargs.update(ema.may_get_ema_checkpointer(cfg, model))
checkpointer = FSDPCheckpointer(model, save_dir=save_dir, **kwargs)
return checkpointer
......@@ -443,7 +443,7 @@ class Detectron2GoRunner(D2GoDataAPIMixIn, BaseRunner):
# model with ema weights
if cfg.MODEL_EMA.ENABLED and not isinstance(model, PredictorWrapper):
logger.info("Run evaluation with EMA.")
with model_ema.apply_model_ema_and_restore(model):
with ema.apply_model_ema_and_restore(model):
cur_results = self._do_test(
new_cfg, model, train_iter=train_iter, model_tag="ema"
)
......@@ -456,7 +456,7 @@ class Detectron2GoRunner(D2GoDataAPIMixIn, BaseRunner):
):
return [
hooks.IterationTimer(),
model_ema.EMAHook(cfg, model) if cfg.MODEL_EMA.ENABLED else None,
ema.EMAHook(cfg, model) if cfg.MODEL_EMA.ENABLED else None,
self._create_data_loader_hook(cfg),
self._create_after_step_hook(
cfg, model, optimizer, scheduler, periodic_checkpointer
......
......@@ -8,7 +8,7 @@ import unittest
import d2go.runner.default_runner as default_runner
import torch
from d2go.modeling import model_ema
from d2go.modeling import ema
from d2go.utils.testing import helper
......@@ -58,7 +58,7 @@ def _compare_state_dict(model1, model2, abs_error=1e-3):
class TestModelingModelEMA(unittest.TestCase):
def test_emastate(self):
model = TestArch()
state = model_ema.EMAState.FromModel(model)
state = ema.EMAState.FromModel(model)
# two for conv (conv.weight, conv.bias),
# five for bn (bn.weight, bn.bias, bn.running_mean, bn.running_var, bn.num_batches_tracked)
self.assertEqual(len(state.state), 7)
......@@ -74,12 +74,12 @@ class TestModelingModelEMA(unittest.TestCase):
def test_emastate_saveload(self):
model = TestArch()
state = model_ema.EMAState.FromModel(model)
state = ema.EMAState.FromModel(model)
model1 = TestArch()
self.assertFalse(_compare_state_dict(model, model1))
state1 = model_ema.EMAState()
state1 = ema.EMAState()
state1.load_state_dict(state.state_dict())
state1.apply_to(model1)
self.assertTrue(_compare_state_dict(model, model1))
......@@ -89,7 +89,7 @@ class TestModelingModelEMA(unittest.TestCase):
model = TestArch()
model.cuda()
# state on gpu
state = model_ema.EMAState.FromModel(model)
state = ema.EMAState.FromModel(model)
self.assertEqual(state.device, torch.device("cuda:0"))
# target model on cpu
model1 = TestArch()
......@@ -98,7 +98,7 @@ class TestModelingModelEMA(unittest.TestCase):
self.assertTrue(_compare_state_dict(copy.deepcopy(model).cpu(), model1))
# state on cpu
state1 = model_ema.EMAState.FromModel(model, device="cpu")
state1 = ema.EMAState.FromModel(model, device="cpu")
self.assertEqual(state1.device, torch.device("cpu"))
# target model on gpu
model2 = TestArch()
......@@ -109,11 +109,11 @@ class TestModelingModelEMA(unittest.TestCase):
def test_ema_updater(self):
model = TestArch()
state = model_ema.EMAState()
state = ema.EMAState()
updated_model = TestArch()
updater = model_ema.EMAUpdater(state, decay=0.0)
updater = ema.EMAUpdater(state, decay=0.0)
updater.init_state(model)
for _ in range(3):
cur = TestArch()
......@@ -122,7 +122,7 @@ class TestModelingModelEMA(unittest.TestCase):
# weight decay == 0.0, always use new model
self.assertTrue(_compare_state_dict(updated_model, cur))
updater = model_ema.EMAUpdater(state, decay=1.0)
updater = ema.EMAUpdater(state, decay=1.0)
updater.init_state(model)
for _ in range(3):
cur = TestArch()
......@@ -132,9 +132,9 @@ class TestModelingModelEMA(unittest.TestCase):
self.assertTrue(_compare_state_dict(updated_model, model))
def test_ema_updater_decay(self):
state = model_ema.EMAState()
state = ema.EMAState()
updater = model_ema.EMAUpdater(state, decay=0.7)
updater = ema.EMAUpdater(state, decay=0.7)
updater.init_state(TestArch(1.0))
gt_val = 1.0
gt_val_int = 1
......@@ -158,17 +158,17 @@ class TestModelingModelEMAHook(unittest.TestCase):
cfg.MODEL_EMA.DECAY = 0.0
model = TestArch()
model_ema.may_build_model_ema(cfg, model)
ema.may_build_model_ema(cfg, model)
self.assertTrue(hasattr(model, "ema_state"))
ema_hook = model_ema.EMAHook(cfg, model)
ema_hook = ema.EMAHook(cfg, model)
ema_hook.before_train()
ema_hook.before_step()
model.set_const_weights(2.0)
ema_hook.after_step()
ema_hook.after_train()
ema_checkpointers = model_ema.may_get_ema_checkpointer(cfg, model)
ema_checkpointers = ema.may_get_ema_checkpointer(cfg, model)
self.assertEqual(len(ema_checkpointers), 1)
out_model = TestArch()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment