Commit 65dad512 authored by Yanghan Wang's avatar Yanghan Wang Committed by Facebook GitHub Bot
Browse files

make get_default_cfg a classmethod

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/293

In order to pass runner during the workflow using "runner name" instead of runner instance, we need to make sure the `get_default_cfg` is not instance method. It can be either staticmethod or classmethod, but I choose classmethod for better inheritance.

code mode using following script:
```
#!/usr/bin/env python3

import json
import os
import subprocess

result = subprocess.check_output("fbgs --json 'def get_default_cfg('", shell=True)
fbgs = json.loads(result)
fbsource_root = os.path.expanduser("~")

def _indent(s):
    return len(s) - len(s.lstrip())

def resolve_instance_method(content):
    lines = content.split("\n")
    for idx, line in enumerate(lines):
        if "def get_default_cfg(self" in line:
            indent = _indent(line)
            # find the class
            for j in range(idx, 0, -1):
                if lines[j].startswith(" " * (indent - 4) + "class "):
                    class_line = lines[j]
                    break
            else:
                raise RuntimeError("Can't find class")
            print("class_line: ", class_line)
            if "Runner" in class_line:
                # check self if not used
                for j in range(idx + 1, len(lines)):
                    if _indent(lines[j]) < indent:
                        break
                    assert "self" not in lines[j], (j, lines[j])
                # update the content
                assert "def get_default_cfg(self)" in line
                lines[idx] = lines[idx].replace(
                    "def get_default_cfg(self)", "def get_default_cfg(cls)"
                )
                lines.insert(idx, " " * indent + "classmethod")
                return "\n".join(lines)
    return content

def resolve_static_method(content):
    lines = content.split("\n")
    for idx, line in enumerate(lines):
        if "def get_default_cfg()" in line:
            indent = _indent(line)
            # find the class
            for j in range(idx, 0, -1):
                if "class " in lines[j]:
                    class_line = lines[j]
                    break
            else:
                print("[WARNING] Can't find class!!!")
                continue
            if "Runner" in class_line:
                # check staticmethod is used
                for j in range(idx, 0, -1):
                    if lines[j] == " " * indent + "staticmethod":
                        staticmethod_line_idx = j
                        break
                else:
                    raise RuntimeError("Can't find staticmethod")
                # update the content
                lines[idx] = lines[idx].replace(
                    "def get_default_cfg()", "def get_default_cfg(cls)"
                )
                lines[staticmethod_line_idx] = " " * indent + "classmethod"
                return "\n".join(lines)
    return content

for result in fbgs["results"]:
    filename = os.path.join(fbsource_root, result["file_name"])
    print(f"processing: {filename}")
    with open(filename) as f:
        content = f.read()
    orig_content = content
    while True:
        old_content = content
        content = resolve_instance_method(content)
        content = resolve_static_method(content)
        if content == old_content:
            break
    if content != orig_content:
        print("Updating ...")
        with open(filename, "w") as f:
            f.write(content)
```

Reviewed By: tglik

Differential Revision: D37059264

fbshipit-source-id: b09d5518f4232de95d8313621468905cf10a731c
parent 8cf2b879
...@@ -14,7 +14,8 @@ from detectron2.utils.file_io import PathManager ...@@ -14,7 +14,8 @@ from detectron2.utils.file_io import PathManager
class DebugRunner(BaseRunner): class DebugRunner(BaseRunner):
def get_default_cfg(self): @classmethod
def get_default_cfg(cls):
_C = super().get_default_cfg() _C = super().get_default_cfg()
# _C.TENSORBOARD... # _C.TENSORBOARD...
......
...@@ -160,8 +160,8 @@ class BaseRunner(object): ...@@ -160,8 +160,8 @@ class BaseRunner(object):
""" """
pass pass
@staticmethod @classmethod
def get_default_cfg(): def get_default_cfg(cls):
""" """
Override `get_default_cfg` for adding non common config. Override `get_default_cfg` for adding non common config.
""" """
...@@ -211,8 +211,8 @@ class Detectron2GoRunner(BaseRunner): ...@@ -211,8 +211,8 @@ class Detectron2GoRunner(BaseRunner):
update_cfg_if_using_adhoc_dataset(cfg) update_cfg_if_using_adhoc_dataset(cfg)
patch_d2_meta_arch() patch_d2_meta_arch()
@staticmethod @classmethod
def get_default_cfg(): def get_default_cfg(cls):
cfg = super(Detectron2GoRunner, Detectron2GoRunner).get_default_cfg() cfg = super(Detectron2GoRunner, Detectron2GoRunner).get_default_cfg()
cfg.PROFILERS = ["default_flop_counter"] cfg.PROFILERS = ["default_flop_counter"]
...@@ -634,8 +634,8 @@ def _add_rcnn_default_config(_C): ...@@ -634,8 +634,8 @@ def _add_rcnn_default_config(_C):
class GeneralizedRCNNRunner(Detectron2GoRunner): class GeneralizedRCNNRunner(Detectron2GoRunner):
@staticmethod @classmethod
def get_default_cfg(): def get_default_cfg(cls):
_C = super(GeneralizedRCNNRunner, GeneralizedRCNNRunner).get_default_cfg() _C = super(GeneralizedRCNNRunner, GeneralizedRCNNRunner).get_default_cfg()
_add_rcnn_default_config(_C) _add_rcnn_default_config(_C)
return _C return _C
...@@ -29,7 +29,8 @@ class DETRDatasetMapper(DetrDatasetMapper, D2GoDatasetMapper): ...@@ -29,7 +29,8 @@ class DETRDatasetMapper(DetrDatasetMapper, D2GoDatasetMapper):
class DETRRunner(GeneralizedRCNNRunner): class DETRRunner(GeneralizedRCNNRunner):
def get_default_cfg(self): @classmethod
def get_default_cfg(cls):
_C = super().get_default_cfg() _C = super().get_default_cfg()
add_detr_config(_C) add_detr_config(_C)
add_deit_backbone_config(_C) add_deit_backbone_config(_C)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment