Commit c9211c19 authored by Matthew Yu's avatar Matthew Yu Committed by Facebook GitHub Bot
Browse files

support overriding a teacher config where model device is gpu

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/415

The user can build a teacher by providing a trained config. However this model may have been trained using gpu whereas the user wants to load the model on cpu, this diff supports this use case by allowing the user to specify `cfg.DISTILLATION.TEACHER.DEVICE` as override.

Reviewed By: sstsai-adl

Differential Revision: D40125236

fbshipit-source-id: f1fd797a155e12b31bb7fcbc5e4997ee8eb23539
parent cc3e0e4d
......@@ -315,6 +315,14 @@ def _build_teacher(cfg) -> nn.Module:
from d2go.runner import import_runner
from d2go.setup import create_cfg_from_cli
# teacher config may be set to cuda
# if user wants to run teacher on cpu only machine by specifying teacher.device,
# need to override device to cpu before building model
if cfg.DISTILLATION.TEACHER.DEVICE:
cfg.DISTILLATION.TEACHER.OVERWRITE_OPTS.extend(
["MODEL.DEVICE", cfg.DISTILLATION.TEACHER.DEVICE]
)
teacher_cfg = create_cfg_from_cli(
cfg.DISTILLATION.TEACHER.CONFIG_FNAME,
cfg.DISTILLATION.TEACHER.OVERWRITE_OPTS,
......
......@@ -187,6 +187,29 @@ class TestDistillation(unittest.TestCase):
model = _build_teacher(cfg)
self.assertEqual(gt_model.weight, model.weight)
def test_override_teacher_config_gpu_on_cpu(self):
"""Teacher cuda model can be run on cpu if specified in config"""
# build model where teacher is specified on gpu but user overrides cpu
cfg = _get_default_cfg()
cfg.MODEL.META_ARCHITECTURE = "TestMetaArchAddRand"
gt_model = BaseRunner().build_model(cfg)
with make_temp_directory("tmp") as output_dir:
# save model
checkpointer = DetectionCheckpointer(gt_model, save_dir=output_dir)
checkpointer.save("checkpoint")
cfg.MODEL.WEIGHTS = f"{output_dir}/checkpoint.pth"
cfg.MODEL.DEVICE = "cuda"
config_fname = f"{output_dir}/config.yaml"
with PathManager.open(config_fname, "w") as f:
f.write(cfg.dump())
# load model and compare to gt
cfg.DISTILLATION.TEACHER.TYPE = "config"
cfg.DISTILLATION.TEACHER.CONFIG_FNAME = config_fname
cfg.DISTILLATION.TEACHER.DEVICE = "cpu"
model = _build_teacher(cfg)
self.assertEqual(gt_model.weight, model.weight)
class TestPseudoLabeler(unittest.TestCase):
def test_noop(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment