Unverified Commit dce89919 authored by Hongzhi (Steve), Chen, committed by GitHub

[Misc] Auto-reformat multiple python folders. (#5325)



* auto-reformat

* lintrunner

---------
Co-authored-by: Ubuntu <ubuntu@ip-172-31-28-63.ap-northeast-1.compute.internal>
parent ab812179
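Note: the diff below is purely mechanical reformatting (line wrapping, quote normalization, import sorting). As a rough, hypothetical illustration of this kind of pass, here is how the same autopep8 and isort helpers that this repo's CLI already calls (see apply_cli/train_cli below) can be applied to a source string; the commit itself was produced by lintrunner, whose exact formatter configuration is not shown on this page:

# Hypothetical sketch of an auto-reformat pass, reusing the autopep8 and
# isort calls that appear in apply_cli/train_cli below; not the exact
# lintrunner configuration used for this commit.
import autopep8
import isort

src = "import typer\nimport torch\nfrom pathlib import Path\nx=[1,2 ,3]\n"
f_code = autopep8.fix_code(src, options={"aggressive": 1})  # fix PEP 8 issues
f_code = isort.code(f_code)  # sort and group imports
print(f_code)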
from .graphpred import ApplyGraphpredPipeline
from .nodepred import ApplyNodepredPipeline
from .nodepred_sample import ApplyNodepredNsPipeline
from copy import deepcopy
from pathlib import Path
from typing import Optional

import ruamel.yaml
import torch
import typer
from jinja2 import Template
from pydantic import BaseModel, Field

from ...utils.factory import (
    ApplyPipelineFactory,
    DataFactory,
    GraphModelFactory,
    PipelineBase,
)
from ...utils.yaml_dump import deep_convert_dict, merge_comment

pipeline_comments = {
    "batch_size": "Graph batch size",
    "num_workers": "Number of workers for data loading",
    "save_path": "Directory to save the inference results",
}


class ApplyGraphpredPipelineCfg(BaseModel):
    batch_size: int = 32
    num_workers: int = 4
    save_path: str = "apply_results"


@ApplyPipelineFactory.register("graphpred")
class ApplyGraphpredPipeline(PipelineBase):
    def __init__(self):
        self.pipeline = {"name": "graphpred", "mode": "apply"}

    @classmethod
    def setup_user_cfg_cls(cls):
        from ...utils.enter_config import UserConfig

        class ApplyGraphPredUserConfig(UserConfig):
            data: DataFactory.filter("graphpred").get_pydantic_config() = Field(
                ..., discriminator="name"
            )
            general_pipeline: ApplyGraphpredPipelineCfg = (
                ApplyGraphpredPipelineCfg()
            )

        cls.user_cfg_cls = ApplyGraphPredUserConfig

@@ -45,9 +54,13 @@ class ApplyGraphpredPipeline(PipelineBase):

    def get_cfg_func(self):
        def config(
            data: DataFactory.filter(
                "graphpred"
            ).get_dataset_enum() = typer.Option(None, help="input data name"),
            cfg: Optional[str] = typer.Option(
                None, help="output configuration file path"
            ),
            cpt: str = typer.Option(..., help="input checkpoint file path"),
        ):
            # Training configuration
            train_cfg = torch.load(cpt)["cfg"]

@@ -57,7 +70,12 @@ class ApplyGraphpredPipeline(PipelineBase):

            else:
                data = data.name
            if cfg is None:
                cfg = (
                    "_".join(
                        ["apply", "graphpred", data, train_cfg["model_name"]]
                    )
                    + ".yaml"
                )
            self.__class__.setup_user_cfg_cls()
            generated_cfg = {

@@ -66,23 +84,31 @@ class ApplyGraphpredPipeline(PipelineBase):

                "device": train_cfg["device"],
                "data": {"name": data},
                "cpt_path": cpt,
                "general_pipeline": {
                    "batch_size": train_cfg["general_pipeline"][
                        "eval_batch_size"
                    ],
                    "num_workers": train_cfg["general_pipeline"]["num_workers"],
                },
            }
            output_cfg = self.user_cfg_cls(**generated_cfg).dict()
            output_cfg = deep_convert_dict(output_cfg)
            # Not applicable for inference
            output_cfg["data"].pop("split_ratio")
            comment_dict = {
                "device": "Torch device name, e.g., cpu or cuda or cuda:0",
                "cpt_path": "Path to the checkpoint file",
                "general_pipeline": pipeline_comments,
            }
            comment_dict = merge_comment(output_cfg, comment_dict)
            yaml = ruamel.yaml.YAML()
            yaml.dump(comment_dict, Path(cfg).open("w"))
            print(
                "Configuration file is generated at {}".format(
                    Path(cfg).absolute()
                )
            )

        return config

@@ -100,8 +126,12 @@ class ApplyGraphpredPipeline(PipelineBase):

        model_name = train_cfg["model_name"]
        model_code = GraphModelFactory.get_source_code(model_name)
        render_cfg["model_code"] = model_code
        render_cfg["model_class_name"] = GraphModelFactory.get_model_class_name(
            model_name
        )
        render_cfg.update(
            DataFactory.get_generated_code_dict(user_cfg_dict["data"]["name"])
        )
        # Dict for defining cfg in the rendered code
        generated_user_cfg = deepcopy(user_cfg_dict)
...
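For orientation, the config() command above only assembles a plain dict and writes it out with ruamel.yaml. A minimal standalone sketch of that path, with hypothetical dataset and checkpoint names standing in for the real CLI inputs:

# Minimal sketch of the YAML-generation path in config() above;
# "ogbg-molhiv" and "model.pth" are hypothetical placeholders.
from pathlib import Path

import ruamel.yaml

generated_cfg = {
    "device": "cpu",  # normally copied from train_cfg["device"]
    "data": {"name": "ogbg-molhiv"},
    "cpt_path": "model.pth",
    "general_pipeline": {"batch_size": 32, "num_workers": 4},
}
yaml = ruamel.yaml.YAML()
yaml.dump(generated_cfg, Path("apply_graphpred.yaml").open("w"))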
from copy import deepcopy
from pathlib import Path
from typing import Optional

import ruamel.yaml
import torch
import typer
from jinja2 import Template
from pydantic import Field

from ...utils.factory import (
    ApplyPipelineFactory,
    DataFactory,
    NodeModelFactory,
    PipelineBase,
)
from ...utils.yaml_dump import deep_convert_dict, merge_comment


@ApplyPipelineFactory.register("nodepred")
class ApplyNodepredPipeline(PipelineBase):
    def __init__(self):
        self.pipeline = {"name": "nodepred", "mode": "apply"}

    @classmethod
    def setup_user_cfg_cls(cls):
        from ...utils.enter_config import UserConfig

        class ApplyNodePredUserConfig(UserConfig):
            data: DataFactory.filter("nodepred").get_pydantic_config() = Field(
                ..., discriminator="name"
            )

        cls.user_cfg_cls = ApplyNodePredUserConfig

@@ -34,9 +39,13 @@ class ApplyNodepredPipeline(PipelineBase):

    def get_cfg_func(self):
        def config(
            data: DataFactory.filter(
                "nodepred"
            ).get_dataset_enum() = typer.Option(None, help="input data name"),
            cfg: Optional[str] = typer.Option(
                None, help="output configuration file path"
            ),
            cpt: str = typer.Option(..., help="input checkpoint file path"),
        ):
            # Training configuration
            train_cfg = torch.load(cpt)["cfg"]

@@ -46,7 +55,12 @@ class ApplyNodepredPipeline(PipelineBase):

            else:
                data = data.name
            if cfg is None:
                cfg = (
                    "_".join(
                        ["apply", "nodepred", data, train_cfg["model_name"]]
                    )
                    + ".yaml"
                )
            self.__class__.setup_user_cfg_cls()
            generated_cfg = {

@@ -55,22 +69,28 @@ class ApplyNodepredPipeline(PipelineBase):

                "device": train_cfg["device"],
                "data": {"name": data},
                "cpt_path": cpt,
                "general_pipeline": {"save_path": "apply_results"},
            }
            output_cfg = self.user_cfg_cls(**generated_cfg).dict()
            output_cfg = deep_convert_dict(output_cfg)
            # Not applicable for inference
            output_cfg["data"].pop("split_ratio")
            comment_dict = {
                "device": "Torch device name, e.g., cpu or cuda or cuda:0",
                "cpt_path": "Path to the checkpoint file",
                "general_pipeline": {
                    "save_path": "Directory to save the inference results"
                },
            }
            comment_dict = merge_comment(output_cfg, comment_dict)
            yaml = ruamel.yaml.YAML()
            yaml.dump(comment_dict, Path(cfg).open("w"))
            print(
                "Configuration file is generated at {}".format(
                    Path(cfg).absolute()
                )
            )

        return config

@@ -88,8 +108,12 @@ class ApplyNodepredPipeline(PipelineBase):

        model_name = train_cfg["model_name"]
        model_code = NodeModelFactory.get_source_code(model_name)
        render_cfg["model_code"] = model_code
        render_cfg["model_class_name"] = NodeModelFactory.get_model_class_name(
            model_name
        )
        render_cfg.update(
            DataFactory.get_generated_code_dict(user_cfg_dict["data"]["name"])
        )
        # Dict for defining cfg in the rendered code
        generated_user_cfg = deepcopy(user_cfg_dict)
...
from copy import deepcopy
from pathlib import Path
from typing import Optional

import ruamel.yaml
import torch
import typer
from jinja2 import Template
from pydantic import Field

from ...utils.factory import (
    ApplyPipelineFactory,
    DataFactory,
    NodeModelFactory,
    PipelineBase,
)
from ...utils.yaml_dump import deep_convert_dict, merge_comment


@ApplyPipelineFactory.register("nodepred-ns")
class ApplyNodepredNsPipeline(PipelineBase):
    def __init__(self):
        self.pipeline = {"name": "nodepred-ns", "mode": "apply"}

    @classmethod
    def setup_user_cfg_cls(cls):
        from ...utils.enter_config import UserConfig

        class ApplyNodePredUserConfig(UserConfig):
            data: DataFactory.filter(
                "nodepred-ns"
            ).get_pydantic_config() = Field(..., discriminator="name")

        cls.user_cfg_cls = ApplyNodePredUserConfig

@@ -34,9 +39,13 @@ class ApplyNodepredNsPipeline(PipelineBase):

    def get_cfg_func(self):
        def config(
            data: DataFactory.filter(
                "nodepred-ns"
            ).get_dataset_enum() = typer.Option(None, help="input data name"),
            cfg: Optional[str] = typer.Option(
                None, help="output configuration file path"
            ),
            cpt: str = typer.Option(..., help="input checkpoint file path"),
        ):
            # Training configuration
            train_cfg = torch.load(cpt)["cfg"]

@@ -46,7 +55,12 @@ class ApplyNodepredNsPipeline(PipelineBase):

            else:
                data = data.name
            if cfg is None:
                cfg = (
                    "_".join(
                        ["apply", "nodepred-ns", data, train_cfg["model_name"]]
                    )
                    + ".yaml"
                )
            self.__class__.setup_user_cfg_cls()
            generated_cfg = {

@@ -55,22 +69,28 @@ class ApplyNodepredNsPipeline(PipelineBase):

                "device": train_cfg["device"],
                "data": {"name": data},
                "cpt_path": cpt,
                "general_pipeline": {"save_path": "apply_results"},
            }
            output_cfg = self.user_cfg_cls(**generated_cfg).dict()
            output_cfg = deep_convert_dict(output_cfg)
            # Not applicable for inference
            output_cfg["data"].pop("split_ratio")
            comment_dict = {
                "device": "Torch device name, e.g., cpu or cuda or cuda:0",
                "cpt_path": "Path to the checkpoint file",
                "general_pipeline": {
                    "save_path": "Directory to save the inference results"
                },
            }
            comment_dict = merge_comment(output_cfg, comment_dict)
            yaml = ruamel.yaml.YAML()
            yaml.dump(comment_dict, Path(cfg).open("w"))
            print(
                "Configuration file is generated at {}".format(
                    Path(cfg).absolute()
                )
            )

        return config

@@ -88,8 +108,12 @@ class ApplyNodepredNsPipeline(PipelineBase):

        model_name = train_cfg["model_name"]
        model_code = NodeModelFactory.get_source_code(model_name)
        render_cfg["model_code"] = model_code
        render_cfg["model_class_name"] = NodeModelFactory.get_model_class_name(
            model_name
        )
        render_cfg.update(
            DataFactory.get_generated_code_dict(user_cfg_dict["data"]["name"])
        )
        # Dict for defining cfg in the rendered code
        generated_user_cfg = deepcopy(user_cfg_dict)
...
from pathlib import Path

import autopep8
import isort
import typer
import yaml

from ..utils.factory import ApplyPipelineFactory


def apply(cfg: str = typer.Option(..., help="config yaml file name")):
    user_cfg = yaml.safe_load(Path(cfg).open("r"))
    pipeline_name = user_cfg["pipeline_name"]
    output_file_content = ApplyPipelineFactory.registry[
        pipeline_name
    ].gen_script(user_cfg)
    f_code = autopep8.fix_code(output_file_content, options={"aggressive": 1})
    f_code = isort.code(f_code)
    code = compile(f_code, "dglgo_tmp.py", "exec")
    exec(code, {"__name__": "__main__"})
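As a side note, apply() runs the generated script in-process via compile/exec, with __name__ set so the script's main guard fires. A minimal sketch of that pattern, using a stand-in source string:

# Minimal sketch of the compile/exec pattern used by apply() above;
# the source string here is a hypothetical stand-in for a generated script.
src = 'print("running as:", __name__)'
code = compile(src, "dglgo_tmp.py", "exec")
exec(code, {"__name__": "__main__"})  # prints: running as: __main__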
import typer

from ..pipeline import *
from ..model import *
from .apply_cli import apply
from .config_apply_cli import config_apply_app
from .config_cli import config_app
from .export_cli import export
from .recipe_cli import recipe_app
from .train_cli import train

no_args_is_help = False

app = typer.Typer(no_args_is_help=True, add_completion=False)
app.add_typer(config_app, name="configure", no_args_is_help=no_args_is_help)
app.add_typer(recipe_app, name="recipe", no_args_is_help=True)
app.command(help="Launch training", no_args_is_help=no_args_is_help)(train)
app.command(
    help="Export a runnable python script", no_args_is_help=no_args_is_help
)(export)
app.add_typer(
    config_apply_app, name="configure-apply", no_args_is_help=no_args_is_help
)
app.command(help="Launch inference", no_args_is_help=no_args_is_help)(apply)


def main():
    app()


if __name__ == "__main__":
    app()
from ..apply_pipeline import *

import typer

from ..utils.factory import ApplyPipelineFactory

config_apply_app = typer.Typer(
    help="Generate a configuration file for inference"
)

for key, pipeline in ApplyPipelineFactory.registry.items():
    config_apply_app.command(key, help=pipeline.get_description())(
        pipeline.get_cfg_func()
    )
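The loop above registers one sub-command per registered pipeline at import time. A self-contained sketch of the same pattern, with a hypothetical registry in place of ApplyPipelineFactory.registry:

# Stand-alone sketch of the dynamic command-registration pattern above;
# `registry` is a hypothetical stand-in for ApplyPipelineFactory.registry.
import typer

demo_app = typer.Typer(help="demo")


def hello(name: str = typer.Option("world", help="who to greet")):
    print("hello, {}".format(name))


registry = {"hello": hello}
for key, func in registry.items():
    demo_app.command(key, help="say hello")(func)

if __name__ == "__main__":
    demo_app()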
from ..pipeline import *

import typing
from enum import Enum
from pathlib import Path

import typer
import yaml

from ..utils.factory import ModelFactory, PipelineFactory

config_app = typer.Typer(help="Generate a configuration file")

for key, pipeline in PipelineFactory.registry.items():
    config_app.command(key, help=pipeline.get_description())(
        pipeline.get_cfg_func()
    )

if __name__ == "__main__":
    config_app()
import typing
from enum import Enum
from pathlib import Path

import autopep8
import isort
import typer
import yaml

from ..utils.factory import ApplyPipelineFactory, ModelFactory, PipelineFactory


def export(
    cfg: str = typer.Option("cfg.yaml", help="config yaml file name"),
    output: str = typer.Option("script.py", help="output python file name"),
):
    user_cfg = yaml.safe_load(Path(cfg).open("r"))
    pipeline_name = user_cfg["pipeline_name"]
    pipeline_mode = user_cfg["pipeline_mode"]
    if pipeline_mode == "train":
        output_file_content = PipelineFactory.registry[
            pipeline_name
        ].gen_script(user_cfg)
    else:
        output_file_content = ApplyPipelineFactory.registry[
            pipeline_name
        ].gen_script(user_cfg)
    f_code = autopep8.fix_code(output_file_content, options={"aggressive": 1})
    f_code = isort.code(f_code)
    with open(output, "w") as f:
        f.write(f_code)
    print(
        "The python script is generated at {}, based on config file {}".format(
            Path(output).absolute(), Path(cfg).absolute()
        )
    )


if __name__ == "__main__":
    export_app = typer.Typer()
...
import os
import shutil
from pathlib import Path
from typing import Optional

import typer
import yaml


def list_recipes():
    file_current_dir = Path(__file__).resolve().parent
    recipe_dir = file_current_dir.parent.parent / "recipes"
    file_list = list(recipe_dir.glob("*.yaml"))
    header = "| {:<30} | {:<18} | {:<20} |".format(
        "Filename", "Pipeline", "Dataset"
    )
    typer.echo("=" * len(header))
    typer.echo(header)
    typer.echo("=" * len(header))
    output_list = []
    for file in file_list:
        cfg = yaml.safe_load(Path(file).open("r"))
        output_list.append(
            {
                "file_name": file.name,
                "pipeline_name": cfg["pipeline_name"],
                "dataset_name": cfg["data"]["name"],
            }
        )
    # sort by pipeline, if same sort by dataset, if same sort by file name
    output_list.sort(
        key=lambda f: (f["pipeline_name"], f["dataset_name"], f["file_name"])
    )
    for f in output_list:
        typer.echo(
            "| {:<30} | {:<18} | {:<20} |".format(
                f["file_name"], f["pipeline_name"], f["dataset_name"]
            )
        )
    typer.echo("=" * len(header))


def get_recipe(
    recipe_name: Optional[str] = typer.Argument(
        None, help="The recipe filename to get, e.g. nodepred_citeseer_gcn.yaml"
    )
):
    if recipe_name is None:
        typer.echo("Usage: dgl recipe get [RECIPE_NAME] \n")
        typer.echo("  Copy the recipe to current directory \n")
        typer.echo("  Arguments:")
        typer.echo(
            "  [RECIPE_NAME]  The recipe filename to get, e.g. nodepred_citeseer_gcn.yaml\n"
        )
        typer.echo("Here are all available recipe filenames")
        list_recipes()
    else:

@@ -41,12 +60,20 @@ def get_recipe(recipe_name: Optional[str] = typer.Argument(None, help="The recip

        current_dir = Path(os.getcwd())
        recipe_path = recipe_dir / recipe_name
        shutil.copy(recipe_path, current_dir)
        print(
            "Recipe {} is copied to {}".format(
                recipe_path.absolute(), current_dir.absolute()
            )
        )


recipe_app = typer.Typer(help="Get example recipes")
recipe_app.command(name="list", help="List all available example recipes")(
    list_recipes
)
recipe_app.command(name="get", help="Copy the recipe to current directory")(
    get_recipe
)

if __name__ == "__main__":
    recipe_app()
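The table layout in list_recipes relies on left-aligned, fixed-width format fields. A quick illustration (the recipe name is taken from the help text above):

# "{:<30}" left-aligns and pads to 30 characters, which is what keeps the
# recipe table columns flush.
row = "| {:<30} | {:<18} | {:<20} |"
header = row.format("Filename", "Pipeline", "Dataset")
print("=" * len(header))
print(header)
print(row.format("nodepred_citeseer_gcn.yaml", "nodepred", "citeseer"))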
import typing
from enum import Enum
from pathlib import Path

import autopep8
import isort
import typer
import yaml

from ..utils.factory import ModelFactory, PipelineFactory


def train(
    cfg: str = typer.Option("cfg.yaml", help="config yaml file name"),
):
    user_cfg = yaml.safe_load(Path(cfg).open("r"))
    pipeline_name = user_cfg["pipeline_name"]
    output_file_content = PipelineFactory.registry[pipeline_name].gen_script(
        user_cfg
    )
    f_code = autopep8.fix_code(output_file_content, options={"aggressive": 1})
    f_code = isort.code(f_code)
    code = compile(f_code, "dglgo_tmp.py", "exec")
    exec(code, {"__name__": "__main__"})


if __name__ == "__main__":
    train_app = typer.Typer()
...
from ...utils.factory import EdgeModelFactory
from .bilinear import BilinearPredictor
from .ele import ElementWiseProductPredictor

EdgeModelFactory.register("ele")(ElementWiseProductPredictor)
EdgeModelFactory.register("bilinear")(BilinearPredictor)
@@ -4,11 +4,13 @@ import torch.nn.functional as F

class BilinearPredictor(nn.Module):
    def __init__(
        self,
        data_info: dict,
        hidden_size: int = 32,
        num_layers: int = 1,
        bias: bool = True,
    ):
        """Bilinear product model for edge scores

        Parameters

@@ -26,7 +28,7 @@ class BilinearPredictor(nn.Module):

        in_size, out_size = data_info["in_size"], data_info["out_size"]
        self.bilinear = nn.Bilinear(in_size, in_size, hidden_size, bias=bias)
        lins_list = []
        for _ in range(num_layers - 2):
            lins_list.append(nn.Linear(hidden_size, hidden_size, bias=bias))
            lins_list.append(nn.ReLU())
        lins_list.append(nn.Linear(hidden_size, out_size, bias=bias))
...
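A quick illustration of the bilinear scoring idea above: nn.Bilinear combines two endpoint embeddings into one edge representation. The sizes below are hypothetical.

# nn.Bilinear(in1_features, in2_features, out_features), as used by
# BilinearPredictor; sizes here are hypothetical.
import torch
import torch.nn as nn

bilinear = nn.Bilinear(16, 16, 32)  # in_size, in_size, hidden_size
h_src = torch.randn(8, 16)  # 8 candidate edges, source embeddings
h_dst = torch.randn(8, 16)  # destination embeddings
h_edge = bilinear(h_src, h_dst)
print(h_edge.shape)  # torch.Size([8, 32])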
@@ -2,16 +2,19 @@ import torch

import torch.nn as nn
import torch.nn.functional as F


class DotPredictor(nn.Module):
    def __init__(
        self,
        in_size: int = -1,
        out_size: int = 1,
        hidden_size: int = 256,
        num_layers: int = 3,
        bias: bool = False,
    ):
        super(DotPredictor, self).__init__()
        lins_list = []
        for _ in range(num_layers - 2):
            lins_list.append(nn.Linear(in_size, hidden_size, bias=bias))
            lins_list.append(nn.ReLU())
        lins_list.append(nn.Linear(hidden_size, out_size, bias=bias))
...
@@ -4,11 +4,13 @@ import torch.nn.functional as F

class ElementWiseProductPredictor(nn.Module):
    def __init__(
        self,
        data_info: dict,
        hidden_size: int = 64,
        num_layers: int = 2,
        bias: bool = True,
    ):
        """Elementwise product model for edge scores

        Parameters
...
@@ -2,12 +2,12 @@ import dgl

import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl.nn import AvgPooling, GINEConv, SumPooling
from ogb.graphproppred.mol_encoder import AtomEncoder, BondEncoder


class MLP(nn.Module):
    def __init__(self, feat_size: int):
        """Multilayer Perceptron (MLP)"""
        super(MLP, self).__init__()
        self.mlp = nn.Sequential(

@@ -15,19 +15,22 @@ class MLP(nn.Module):

            nn.BatchNorm1d(2 * feat_size),
            nn.ReLU(),
            nn.Linear(2 * feat_size, feat_size),
            nn.BatchNorm1d(feat_size),
        )

    def forward(self, h):
        return self.mlp(h)


class OGBGGIN(nn.Module):
    def __init__(
        self,
        data_info: dict,
        embed_size: int = 300,
        num_layers: int = 5,
        dropout: float = 0.5,
        virtual_node: bool = False,
    ):
        """Graph Isomorphism Network (GIN) variant introduced in baselines
        for OGB graph property prediction datasets

@@ -50,21 +53,30 @@ class OGBGGIN(nn.Module):

        self.num_layers = num_layers
        self.virtual_node = virtual_node
        if data_info["name"] in ["ogbg-molhiv", "ogbg-molpcba"]:
            self.node_encoder = AtomEncoder(embed_size)
            self.edge_encoders = nn.ModuleList(
                [BondEncoder(embed_size) for _ in range(num_layers)]
            )
        else:
            # Handle other datasets
            self.node_encoder = nn.Linear(
                data_info["node_feat_size"], embed_size
            )
            self.edge_encoders = nn.ModuleList(
                [
                    nn.Linear(data_info["edge_feat_size"], embed_size)
                    for _ in range(num_layers)
                ]
            )
        self.conv_layers = nn.ModuleList(
            [GINEConv(MLP(embed_size)) for _ in range(num_layers)]
        )
        self.dropout = nn.Dropout(dropout)
        self.pool = AvgPooling()
        self.pred = nn.Linear(embed_size, data_info["out_size"])
        if virtual_node:
            self.virtual_emb = nn.Embedding(1, embed_size)
...
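For context, each GINEConv layer above consumes node features plus per-layer encoded edge features. A minimal sketch of a single GINEConv step (the model's forward pass is in the collapsed part of this diff; sizes and the nn.Linear apply function below are hypothetical stand-ins for MLP):

# Minimal, hypothetical sketch of one GINEConv message-passing step,
# mirroring the conv_layers built in OGBGGIN above (assumes dgl installed).
import dgl
import torch
import torch.nn as nn
from dgl.nn import GINEConv

graph = dgl.graph(([0, 1, 2], [1, 2, 0]))  # toy 3-node, 3-edge graph
node_feat = torch.randn(3, 8)
edge_feat = torch.randn(3, 8)  # must match the node feature size
conv = GINEConv(nn.Linear(8, 8))  # nn.Linear stands in for MLP(embed_size)
out = conv(graph, node_feat, edge_feat)
print(out.shape)  # torch.Size([3, 8])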
from typing import List

import dgl.function as fn
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl.nn import AvgPooling, SumPooling
from ogb.graphproppred.mol_encoder import AtomEncoder


def aggregate_mean(h):
    """mean aggregation"""
    return torch.mean(h, dim=1)


def aggregate_max(h):
    """max aggregation"""
    return torch.max(h, dim=1)[0]


def aggregate_min(h):
    """min aggregation"""
    return torch.min(h, dim=1)[0]


def aggregate_sum(h):
    """sum aggregation"""
    return torch.sum(h, dim=1)


def aggregate_var(h):
    """variance aggregation"""
    h_mean_squares = torch.mean(h * h, dim=1)

@@ -30,47 +36,66 @@ def aggregate_var(h):

    var = torch.relu(h_mean_squares - h_mean * h_mean)
    return var


def aggregate_std(h):
    """standard deviation aggregation"""
    return torch.sqrt(aggregate_var(h) + 1e-5)


AGGREGATORS = {
    "mean": aggregate_mean,
    "sum": aggregate_sum,
    "max": aggregate_max,
    "min": aggregate_min,
    "std": aggregate_std,
    "var": aggregate_var,
}


def scale_identity(h, D, delta):
    """identity scaling (no scaling operation)"""
    return h


def scale_amplification(h, D, delta):
    """amplification scaling"""
    return h * (np.log(D + 1) / delta)


def scale_attenuation(h, D, delta):
    """attenuation scaling"""
    return h * (delta / np.log(D + 1))


SCALERS = {
    "identity": scale_identity,
    "amplification": scale_amplification,
    "attenuation": scale_attenuation,
}
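To see what these aggregators operate on: in SimplePNAConv.reduce below, h is the mailbox tensor of shape (num_dst_nodes, num_neighbors, feat_size), and the results of all aggregators (then all scalers) are concatenated along the feature dimension. A toy check with hypothetical shapes:

# Toy check of the aggregate/concatenate step, mirroring
# SimplePNAConv.reduce below; shapes are hypothetical.
import torch

h = torch.randn(4, 5, 8)  # 4 destination nodes, 5 neighbors, 8 features
mean = torch.mean(h, dim=1)  # aggregate_mean -> (4, 8)
maxv = torch.max(h, dim=1)[0]  # aggregate_max -> (4, 8)
combined = torch.cat([mean, maxv], dim=1)  # one block per aggregator
print(combined.shape)  # torch.Size([4, 16])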
class MLP(nn.Module):
    def __init__(
        self,
        in_feat_size: int,
        out_feat_size: int,
        num_layers: int = 3,
        decreasing_hidden_size=False,
    ):
        """Multilayer Perceptron (MLP)"""
        super(MLP, self).__init__()
        self.layers = nn.ModuleList()
        if decreasing_hidden_size:
            for i in range(num_layers - 1):
                self.layers.append(
                    nn.Linear(
                        in_feat_size // 2**i, in_feat_size // 2 ** (i + 1)
                    )
                )
            self.layers.append(
                nn.Linear(in_feat_size // 2 ** (num_layers - 1), out_feat_size)
            )
        else:
            self.layers.append(nn.Linear(in_feat_size, out_feat_size))
            for _ in range(num_layers - 1):

@@ -84,9 +109,12 @@ class MLP(nn.Module):

            h = F.relu(h)
        return h


class SimplePNAConv(nn.Module):
    r"""A simplified PNAConv variant used in OGB submissions"""

    def __init__(
        self,
        feat_size: int,
        aggregators: List[str],
        scalers: List[str],

@@ -94,14 +122,18 @@ class SimplePNAConv(nn.Module):

        dropout: float,
        batch_norm: bool,
        residual: bool,
        num_mlp_layers: int,
    ):
        super(SimplePNAConv, self).__init__()
        self.aggregators = [AGGREGATORS[aggr] for aggr in aggregators]
        self.scalers = [SCALERS[scale] for scale in scalers]
        self.delta = delta
        self.mlp = MLP(
            in_feat_size=(len(aggregators) * len(scalers)) * feat_size,
            out_feat_size=feat_size,
            num_layers=num_mlp_layers,
        )
        self.dropout = nn.Dropout(dropout)
        self.residual = residual

@@ -111,17 +143,19 @@ class SimplePNAConv(nn.Module):

            self.bn = None

    def reduce(self, nodes):
        h = nodes.mailbox["m"]
        D = h.shape[-2]
        h = torch.cat([aggregate(h) for aggregate in self.aggregators], dim=1)
        h = torch.cat(
            [scale(h, D=D, delta=self.delta) for scale in self.scalers], dim=1
        )
        return {"h": h}

    def forward(self, g, h):
        with g.local_scope():
            g.ndata["h"] = h
            g.update_all(fn.copy_u("h", "m"), self.reduce)
            h_new = g.ndata["h"]
            h_new = self.mlp(h_new)

            if self.bn is not None:

@@ -134,18 +168,21 @@ class SimplePNAConv(nn.Module):

        return h_new


class PNA(nn.Module):
    def __init__(
        self,
        data_info: dict,
        embed_size: int = 80,
        aggregators: str = "mean max min std",
        scalers: str = "identity amplification attenuation",
        dropout: float = 0.3,
        batch_norm: bool = True,
        residual: bool = True,
        num_mlp_layers: int = 1,
        num_layers: int = 4,
        readout: str = "mean",
    ):
        """Principal Neighbourhood Aggregation

        Parameters

@@ -182,44 +219,62 @@ class PNA(nn.Module):

        self.readout = readout
        if aggregators is None:
            aggregators = ["mean", "max", "min", "std"]
        else:
            aggregators = [agg.strip() for agg in aggregators.split(" ")]
        assert set(aggregators).issubset(
            {"mean", "max", "min", "std", "sum"}
        ), "Expect aggregators to be a subset of ['mean', 'max', 'min', 'std', 'sum'], \
            got {}".format(
            aggregators
        )

        if scalers is None:
            scalers = ["identity", "amplification", "attenuation"]
        else:
            scalers = [scl.strip() for scl in scalers.split(" ")]
        assert set(scalers).issubset(
            {"identity", "amplification", "attenuation"}
        ), "Expect scalers to be a subset of ['identity', 'amplification', 'attenuation'], \
            got {}".format(
            scalers
        )

        self.aggregators = aggregators
        self.scalers = scalers

        if data_info["name"] in ["ogbg-molhiv", "ogbg-molpcba"]:
            self.node_encoder = AtomEncoder(embed_size)
        else:
            # Handle other datasets
            self.node_encoder = nn.Linear(
                data_info["node_feat_size"], embed_size
            )

        self.conv_layers = nn.ModuleList(
            [
                SimplePNAConv(
                    feat_size=embed_size,
                    aggregators=aggregators,
                    scalers=scalers,
                    delta=data_info["delta"],
                    dropout=dropout,
                    batch_norm=batch_norm,
                    residual=residual,
                    num_mlp_layers=num_mlp_layers,
                )
                for _ in range(num_layers)
            ]
        )

        if readout == "sum":
            self.pool = SumPooling()
        elif readout == "mean":
            self.pool = AvgPooling()
        else:
            raise ValueError(
                "Expect readout to be 'sum' or 'mean', got {}".format(readout)
            )
        self.pred = MLP(
            embed_size, data_info["out_size"], decreasing_hidden_size=True
        )

    def forward(self, graph, node_feat, edge_feat=None):
        hn = self.node_encoder(node_feat)
...