Commit ee9602a1 authored by Yanghan Wang's avatar Yanghan Wang Committed by Facebook GitHub Bot
Browse files

toy example of training model for turing

Summary:
Add toy example to illustrate the Turing workflow.
- modify the model building, add converting to helios step. Note that we need to hide this from OSS, so create FB version of the runner, in order to modify `build_model` and `get_default_cfg`.
- make the `D2 (https://github.com/facebookresearch/d2go/commit/7992f91324aee6ae59795063a007c6837e60cdb8)GoCompatibleMNISTRunner` up-to-date, and use the "tutorial" meta-arch for writing unit test since it's the simplest model. Note that even `TutorialNet` is very simple, there's still a constraint that the FC has to run on 4D tensor with 1x1 spatial dimension because it's been mapped to 1x1 Conv by Helios, modify the `TutorialNet` to make it compatible.

Reviewed By: newstzpz

Differential Revision: D31705305

fbshipit-source-id: 77949dfbf08252be5495e9273210274c8ad86abb
parent 274d3b49
...@@ -232,7 +232,8 @@ class Detectron2GoRunner(BaseRunner): ...@@ -232,7 +232,8 @@ class Detectron2GoRunner(BaseRunner):
_C = super(Detectron2GoRunner, Detectron2GoRunner).get_default_cfg() _C = super(Detectron2GoRunner, Detectron2GoRunner).get_default_cfg()
return get_default_cfg(_C) return get_default_cfg(_C)
def build_model(self, cfg, eval_only=False): # temporary API
def _build_model(self, cfg, eval_only=False):
# build_model might modify the cfg, thus clone # build_model might modify the cfg, thus clone
cfg = cfg.clone() cfg = cfg.clone()
...@@ -266,6 +267,11 @@ class Detectron2GoRunner(BaseRunner): ...@@ -266,6 +267,11 @@ class Detectron2GoRunner(BaseRunner):
if cfg.MODEL_EMA.ENABLED and cfg.MODEL_EMA.USE_EMA_WEIGHTS_FOR_EVAL_ONLY: if cfg.MODEL_EMA.ENABLED and cfg.MODEL_EMA.USE_EMA_WEIGHTS_FOR_EVAL_ONLY:
model_ema.apply_model_ema(model) model_ema.apply_model_ema(model)
return model
def build_model(self, cfg, eval_only=False):
model = self._build_model(cfg, eval_only)
# Note: the _visualize_model API is experimental # Note: the _visualize_model API is experimental
if comm.is_main_process(): if comm.is_main_process():
if hasattr(model, "_visualize_model"): if hasattr(model, "_visualize_model"):
...@@ -602,15 +608,18 @@ class Detectron2GoRunner(BaseRunner): ...@@ -602,15 +608,18 @@ class Detectron2GoRunner(BaseRunner):
return QATHook(cfg, self.build_detection_train_loader) return QATHook(cfg, self.build_detection_train_loader)
def _add_rcnn_default_config(_C):
_C.EXPORT_CAFFE2 = CfgNode()
_C.EXPORT_CAFFE2.USE_HEATMAP_MAX_KEYPOINT = False
_C.RCNN_PREPARE_FOR_EXPORT = "default_rcnn_prepare_for_export"
_C.RCNN_PREPARE_FOR_QUANT = "default_rcnn_prepare_for_quant"
_C.RCNN_PREPARE_FOR_QUANT_CONVERT = "default_rcnn_prepare_for_quant_convert"
class GeneralizedRCNNRunner(Detectron2GoRunner): class GeneralizedRCNNRunner(Detectron2GoRunner):
@staticmethod @staticmethod
def get_default_cfg(): def get_default_cfg():
_C = super(GeneralizedRCNNRunner, GeneralizedRCNNRunner).get_default_cfg() _C = super(GeneralizedRCNNRunner, GeneralizedRCNNRunner).get_default_cfg()
_C.EXPORT_CAFFE2 = CfgNode() _add_rcnn_default_config(_C)
_C.EXPORT_CAFFE2.USE_HEATMAP_MAX_KEYPOINT = False
_C.RCNN_PREPARE_FOR_EXPORT = "default_rcnn_prepare_for_export"
_C.RCNN_PREPARE_FOR_QUANT = "default_rcnn_prepare_for_quant"
_C.RCNN_PREPARE_FOR_QUANT_CONVERT = "default_rcnn_prepare_for_quant_convert"
return _C return _C
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment