Unverified commit dce89919, authored by Hongzhi (Steve), Chen and committed by GitHub

[Misc] Auto-reformat multiple python folders. (#5325)



* auto-reformat

* lintrunner

---------
Co-authored-by: Ubuntu <ubuntu@ip-172-31-28-63.ap-northeast-1.compute.internal>
parent ab812179
from ...utils.factory import NodeModelFactory
from .gat import GAT
from .gcn import GCN
from .gin import GIN
from .sage import GraphSAGE
from .sgc import SGC

NodeModelFactory.register("gcn")(GCN)
NodeModelFactory.register("gat")(GAT)
......
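The two `register` calls above use the decorator-factory pattern: `NodeModelFactory.register(name)` returns a callable that stores the class under `name` and hands it back unchanged. A minimal self-contained sketch of that mechanic (illustrative only; not DGL's actual factory code):

class ToyModelFactory:
    """Stand-in illustrating the registration mechanics used above."""

    def __init__(self):
        self.registry = {}

    def register(self, name):
        def decorator(cls):
            self.registry[name] = cls  # remember the class under `name`
            return cls  # return the class unchanged

        return decorator


factory = ToyModelFactory()


class DummyModel:  # placeholder class for the sketch
    pass


factory.register("dummy")(DummyModel)  # same call shape as the lines above
assert factory.registry["dummy"] is DummyModel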
from typing import List

import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl.base import dgl_warning
from dgl.nn import GATConv


class GAT(nn.Module):
    def __init__(
        self,
        data_info: dict,
        embed_size: int = -1,
        num_layers: int = 2,
@@ -17,7 +19,8 @@ class GAT(nn.Module):
        feat_drop: float = 0.6,
        attn_drop: float = 0.6,
        negative_slope: float = 0.2,
        residual: bool = False,
    ):
        """Graph Attention Networks

        Parameters
@@ -57,19 +60,30 @@ class GAT(nn.Module):
        in_size = data_info["in_size"]
        for i in range(num_layers):
            in_hidden = hidden_size * heads[i - 1] if i > 0 else in_size
            out_hidden = (
                hidden_size if i < num_layers - 1 else data_info["out_size"]
            )
            activation = None if i == num_layers - 1 else self.activation
            self.gat_layers.append(
                GATConv(
                    in_hidden,
                    out_hidden,
                    heads[i],
                    feat_drop,
                    attn_drop,
                    negative_slope,
                    residual,
                    activation,
                )
            )

    def forward(self, graph, node_feat, edge_feat=None):
        if self.embed_size > 0:
            dgl_warning(
                "The embedding for node feature is used, and input node_feat is ignored, due to the provided embed_size."
            )
            h = self.embed.weight
        else:
            h = node_feat
......
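As a usage illustration (hypothetical; not part of this diff), the GAT model above can be exercised roughly as follows, assuming the elided constructor arguments all have defaults like the visible ones. Self-loops are added because GATConv rejects zero-in-degree nodes by default:

import dgl
import torch

g = dgl.add_self_loop(dgl.rand_graph(10, 20))  # toy graph with 10 nodes
feat = torch.randn(10, 16)  # 16-dim node features
model = GAT(data_info={"in_size": 16, "out_size": 3})  # embed_size=-1: use feat
logits = model(g, feat)  # expected per-node output of size 3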
import dgl
import torch
import torch.nn as nn
from dgl.base import dgl_warning


class GCN(nn.Module):
    def __init__(
        self,
        data_info: dict,
        embed_size: int = -1,
        hidden_size: int = 16,
@@ -12,7 +14,8 @@ class GCN(nn.Module):
        norm: str = "both",
        activation: str = "relu",
        dropout: float = 0.5,
        use_edge_weight: bool = False,
    ):
        """Graph Convolutional Networks

        Parameters
@@ -47,28 +50,36 @@ class GCN(nn.Module):
        for i in range(num_layers):
            in_hidden = hidden_size if i > 0 else in_size
            out_hidden = (
                hidden_size if i < num_layers - 1 else data_info["out_size"]
            )
            self.layers.append(
                dgl.nn.GraphConv(
                    in_hidden, out_hidden, norm=norm, allow_zero_in_degree=True
                )
            )
        self.dropout = nn.Dropout(p=dropout)
        self.act = getattr(torch, activation)

    def forward(self, g, node_feat, edge_feat=None):
        if self.embed_size > 0:
            dgl_warning(
                "The embedding for node feature is used, and input node_feat is ignored, due to the provided embed_size."
            )
            h = self.embed.weight
        else:
            h = node_feat
        edge_weight = edge_feat if self.use_edge_weight else None
        for l, layer in enumerate(self.layers):
            h = layer(g, h, edge_weight=edge_weight)
            if l != len(self.layers) - 1:
                h = self.act(h)
                h = self.dropout(h)
        return h

    def forward_block(self, blocks, node_feat, edge_feat=None):
        h = node_feat
        edge_weight = edge_feat if self.use_edge_weight else None
        for l, (layer, block) in enumerate(zip(self.layers, blocks)):
......
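A similar hypothetical sketch for GCN, showing the `use_edge_weight` path in the forward above (when enabled, `edge_feat` is forwarded to every GraphConv as `edge_weight`):

import dgl
import torch

g = dgl.rand_graph(10, 40)  # allow_zero_in_degree=True, so no self-loops needed
feat = torch.randn(10, 8)
weights = torch.rand(g.num_edges())  # one scalar weight per edge
model = GCN(data_info={"in_size": 8, "out_size": 3}, use_edge_weight=True)
logits = model(g, feat, edge_feat=weights)  # expected shape (10, 3)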
import torch.nn as nn
from dgl.base import dgl_warning
from dgl.nn import GINConv


class GIN(nn.Module):
    def __init__(
        self,
        data_info: dict,
        embed_size: int = -1,
        hidden_size=64,
        num_layers=3,
        aggregator_type="sum",
    ):
        """Graph Isomorphism Networks

        Edge feature is ignored in this model.
@@ -39,9 +41,13 @@ class GIN(nn.Module):
        in_size = data_info["in_size"]
        for i in range(num_layers):
            input_dim = in_size if i == 0 else hidden_size
            mlp = nn.Sequential(
                nn.Linear(input_dim, hidden_size),
                nn.BatchNorm1d(hidden_size),
                nn.ReLU(),
                nn.Linear(hidden_size, hidden_size),
                nn.ReLU(),
            )
            self.conv_list.append(GINConv(mlp, aggregator_type, 1e-5, True))
        self.out_mlp = nn.Linear(hidden_size, data_info["out_size"])
@@ -49,7 +55,8 @@ class GIN(nn.Module):
    def forward(self, graph, node_feat, edge_feat=None):
        if self.embed_size > 0:
            dgl_warning(
                "The embedding for node feature is used, and input node_feat is ignored, due to the provided embed_size."
            )
            h = self.embed.weight
        else:
            h = node_feat
......
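A hypothetical sketch for GIN; per the docstring, `edge_feat` is ignored, and each layer is a GINConv wrapping the two-layer MLP built above:

import dgl
import torch

g = dgl.rand_graph(8, 30)
feat = torch.randn(8, 5)
model = GIN(data_info={"in_size": 5, "out_size": 3}, aggregator_type="mean")
out = model(g, feat)  # expected per-node output of size 3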
import dgl
import torch.nn as nn
from dgl.base import dgl_warning


class GraphSAGE(nn.Module):
    def __init__(
        self,
        data_info: dict,
        embed_size: int = -1,
        hidden_size: int = 16,
        num_layers: int = 1,
        activation: str = "relu",
        dropout: float = 0.5,
        aggregator_type: str = "gcn",
    ):
        """GraphSAGE model

        Parameters
@@ -44,12 +47,18 @@ class GraphSAGE(nn.Module):
        for i in range(num_layers):
            in_hidden = hidden_size if i > 0 else in_size
            out_hidden = (
                hidden_size if i < num_layers - 1 else data_info["out_size"]
            )
            self.layers.append(
                dgl.nn.SAGEConv(in_hidden, out_hidden, aggregator_type)
            )

    def forward(self, graph, node_feat, edge_feat=None):
        if self.embed_size > 0:
            dgl_warning(
                "The embedding for node feature is used, and input node_feat is ignored, due to the provided embed_size."
            )
            h = self.embed.weight
        else:
            h = node_feat
@@ -61,7 +70,7 @@ class GraphSAGE(nn.Module):
            h = self.dropout(h)
        return h

    def forward_block(self, blocks, node_feat, edge_feat=None):
        h = node_feat
        for l, (layer, block) in enumerate(zip(self.layers, blocks)):
            h = layer(block, h, edge_feat)
......
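forward_block above is the minibatch entry point. A hypothetical sketch pairing it with DGL's neighbor sampler (one fan-out entry per layer; num_layers defaults to 1 above), assuming the elided loop body mirrors the full-graph forward:

import dgl
import torch

g = dgl.rand_graph(100, 500)
feat = torch.randn(100, 10)
model = GraphSAGE(data_info={"in_size": 10, "out_size": 4})
sampler = dgl.dataloading.NeighborSampler([5])  # one fanout for one layer
loader = dgl.dataloading.DataLoader(
    g, torch.arange(100), sampler, batch_size=32, shuffle=True
)
for input_nodes, output_nodes, blocks in loader:
    logits = model.forward_block(blocks, feat[input_nodes])
    break  # one minibatch is enough for the illustration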
import dgl.function as fn
import torch.nn as nn
import torch.nn.functional as F
from dgl.base import dgl_warning
from dgl.nn import SGConv


class SGC(nn.Module):
    def __init__(self, data_info: dict, embed_size: int = -1, bias=True, k=2):
        """Simplifying Graph Convolutional Networks

        Edge feature is ignored in this model.
@@ -34,12 +31,20 @@ class SGC(nn.Module):
            in_size = embed_size
        else:
            in_size = data_info["in_size"]
        self.sgc = SGConv(
            in_size,
            self.out_size,
            k=k,
            cached=True,
            bias=bias,
            norm=self.normalize,
        )

    def forward(self, g, node_feat, edge_feat=None):
        if self.embed_size > 0:
            dgl_warning(
                "The embedding for node feature is used, and input node_feat is ignored, due to the provided embed_size."
            )
            h = self.embed.weight
        else:
            h = node_feat
@@ -47,4 +52,4 @@ class SGC(nn.Module):
    @staticmethod
    def normalize(h):
        return (h - h.mean(0)) / (h.std(0) + 1e-5)
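The `normalize` hook passed to SGConv above standardizes each feature column; concretely:

import torch

h = torch.tensor([[1.0, 10.0], [3.0, 30.0]])
out = (h - h.mean(0)) / (h.std(0) + 1e-5)
# each column of `out` now has zero mean and roughly unit variance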
from .graphpred import GraphpredPipeline
from .linkpred import LinkpredPipeline
from .nodepred import NodepredPipeline
from .nodepred_sample import NodepredNsPipeline
\ No newline at end of file
import copy
from pathlib import Path
from typing import Optional

import ruamel.yaml
import typer
from jinja2 import Template
from pydantic import BaseModel, Field

from ...utils.factory import (
    DataFactory,
    GraphModelFactory,
    PipelineBase,
    PipelineFactory,
)
from ...utils.yaml_dump import deep_convert_dict, merge_comment

pipeline_comments = {
    "num_runs": "Number of experiments to run",
@@ -14,9 +21,10 @@ pipeline_comments = {
    "eval_batch_size": "Graph batch size when evaluating",
    "num_workers": "Number of workers for data loading",
    "num_epochs": "Number of training epochs",
    "save_path": "Directory to save the experiment results",
}


class GraphpredPipelineCfg(BaseModel):
    num_runs: int = 1
    train_batch_size: int = 32
@@ -30,20 +38,23 @@ class GraphpredPipelineCfg(BaseModel):
    num_epochs: int = 100
    save_path: str = "results"


@PipelineFactory.register("graphpred")
class GraphpredPipeline(PipelineBase):
    def __init__(self):
        self.pipeline = {"name": "graphpred", "mode": "train"}

    @classmethod
    def setup_user_cfg_cls(cls):
        from ...utils.enter_config import UserConfig

        class GraphPredUserConfig(UserConfig):
            data: DataFactory.filter("graphpred").get_pydantic_config() = Field(
                ..., discriminator="name"
            )
            model: GraphModelFactory.get_pydantic_model_config() = Field(
                ..., discriminator="name"
            )
            general_pipeline: GraphpredPipelineCfg = GraphpredPipelineCfg()

        cls.user_cfg_cls = GraphPredUserConfig
@@ -54,10 +65,15 @@ class GraphpredPipeline(PipelineBase):
    def get_cfg_func(self):
        def config(
            data: DataFactory.filter(
                "graphpred"
            ).get_dataset_enum() = typer.Option(..., help="input data name"),
            cfg: Optional[str] = typer.Option(
                None, help="output configuration path"
            ),
            model: GraphModelFactory.get_model_enum() = typer.Option(
                ..., help="Model name"
            ),
        ):
            self.__class__.setup_user_cfg_cls()
            generated_cfg = {
@@ -66,17 +82,19 @@ class GraphpredPipeline(PipelineBase):
                "device": "cpu",
                "data": {"name": data.name},
                "model": {"name": model.value},
                "general_pipeline": {},
            }
            output_cfg = self.user_cfg_cls(**generated_cfg).dict()
            output_cfg = deep_convert_dict(output_cfg)
            comment_dict = {
                "device": "Torch device name, e.g., cpu or cuda or cuda:0",
                "data": {
                    "split_ratio": "Ratio to generate data split, for example set to [0.8, 0.1, 0.1] for 80% train/10% val/10% test. Leave blank to use builtin split in original dataset"
                },
                "general_pipeline": pipeline_comments,
                "model": GraphModelFactory.get_constructor_doc_dict(
                    model.value
                ),
            }
            comment_dict = merge_comment(output_cfg, comment_dict)
@@ -84,7 +102,11 @@ class GraphpredPipeline(PipelineBase):
            if cfg is None:
                cfg = "_".join(["graphpred", data.value, model.value]) + ".yaml"
            yaml.dump(comment_dict, Path(cfg).open("w"))
            print(
                "Configuration file is generated at {}".format(
                    Path(cfg).absolute()
                )
            )

        return config
@@ -97,11 +119,17 @@ class GraphpredPipeline(PipelineBase):
        render_cfg = copy.deepcopy(user_cfg_dict)
        model_code = GraphModelFactory.get_source_code(
            user_cfg_dict["model"]["name"]
        )
        render_cfg["model_code"] = model_code
        render_cfg["model_class_name"] = GraphModelFactory.get_model_class_name(
            user_cfg_dict["model"]["name"]
        )
        render_cfg.update(
            DataFactory.get_generated_code_dict(
                user_cfg_dict["data"]["name"], '**cfg["data"]'
            )
        )

        generated_user_cfg = copy.deepcopy(user_cfg_dict)
        if "split_ratio" in generated_user_cfg["data"]:
@@ -109,7 +137,9 @@ class GraphpredPipeline(PipelineBase):
        generated_user_cfg["data_name"] = generated_user_cfg["data"].pop("name")
        generated_user_cfg.pop("pipeline_name")
        generated_user_cfg.pop("pipeline_mode")
        generated_user_cfg["model_name"] = generated_user_cfg["model"].pop(
            "name"
        )
        generated_user_cfg["general_pipeline"]["optimizer"].pop("name")
        generated_user_cfg["general_pipeline"]["lr_scheduler"].pop("name")
@@ -118,7 +148,10 @@ class GraphpredPipeline(PipelineBase):
        generated_train_cfg["lr_scheduler"].pop("name")

        if user_cfg_dict["data"].get("split_ratio", None) is not None:
            render_cfg["data_initialize_code"] = "{}, split_ratio={}".format(
                render_cfg["data_initialize_code"],
                user_cfg_dict["data"]["split_ratio"],
            )
        render_cfg["user_cfg_str"] = f"cfg = {str(generated_user_cfg)}"
        render_cfg["user_cfg"] = user_cfg_dict
        return template.render(**render_cfg)
......
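The default output filename produced by the `if cfg is None` branch above follows a fixed pattern; worked through for hypothetical choices (`csv` is a dataset registered for graphpred further below, the model name is made up):

data_value, model_value = "csv", "gin"
cfg = "_".join(["graphpred", data_value, model_value]) + ".yaml"
assert cfg == "graphpred_csv_gin.yaml"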
import copy
from pathlib import Path
from typing import Optional

import ruamel.yaml
import typer
import yaml
from jinja2 import Template
from pydantic import BaseModel, Field
from ruamel.yaml.comments import CommentedMap

from ...utils.base_model import DeviceEnum, EarlyStopConfig
from ...utils.factory import (
    DataFactory,
    EdgeModelFactory,
    NegativeSamplerFactory,
    NodeModelFactory,
    PipelineBase,
    PipelineFactory,
)
from ...utils.yaml_dump import deep_convert_dict, merge_comment


class LinkpredPipelineCfg(BaseModel):
    hidden_size: int = 256
    eval_batch_size: int = 32769
@@ -42,23 +52,25 @@ class LinkpredPipeline(PipelineBase):
    pipeline_name = "linkpred"

    def __init__(self):
        self.pipeline = {"name": "linkpred", "mode": "train"}

    @classmethod
    def setup_user_cfg_cls(cls):
        from ...utils.enter_config import UserConfig

        class LinkPredUserConfig(UserConfig):
            data: DataFactory.filter("linkpred").get_pydantic_config() = Field(
                ..., discriminator="name"
            )
            node_model: NodeModelFactory.get_pydantic_model_config() = Field(
                ..., discriminator="name"
            )
            edge_model: EdgeModelFactory.get_pydantic_model_config() = Field(
                ..., discriminator="name"
            )
            neg_sampler: NegativeSamplerFactory.get_pydantic_model_config() = (
                Field(..., discriminator="name")
            )
            general_pipeline: LinkpredPipelineCfg = LinkpredPipelineCfg()

        cls.user_cfg_cls = LinkPredUserConfig
@@ -69,15 +81,21 @@ class LinkpredPipeline(PipelineBase):
    def get_cfg_func(self):
        def config(
            data: DataFactory.filter(
                "linkpred"
            ).get_dataset_enum() = typer.Option(..., help="input data name"),
            cfg: str = typer.Option(
                "cfg.yaml", help="output configuration path"
            ),
            node_model: NodeModelFactory.get_model_enum() = typer.Option(
                ..., help="Model name"
            ),
            edge_model: EdgeModelFactory.get_model_enum() = typer.Option(
                ..., help="Model name"
            ),
            neg_sampler: NegativeSamplerFactory.get_model_enum() = typer.Option(
                "persource", help="Negative sampler name"
            ),
        ):
            self.__class__.setup_user_cfg_cls()
            generated_cfg = {
@@ -94,21 +112,41 @@ class LinkpredPipeline(PipelineBase):
            comment_dict = {
                "device": "Torch device name, e.g., cpu or cuda or cuda:0",
                "general_pipeline": pipeline_comments,
                "node_model": NodeModelFactory.get_constructor_doc_dict(
                    node_model.value
                ),
                "edge_model": EdgeModelFactory.get_constructor_doc_dict(
                    edge_model.value
                ),
                "neg_sampler": NegativeSamplerFactory.get_constructor_doc_dict(
                    neg_sampler.value
                ),
                "data": {
                    "split_ratio": "List of float, e.g. [0.8, 0.1, 0.1]. Split ratios for training, validation and test sets. Must sum to one. Leave blank to use builtin split in original dataset",
                    "neg_ratio": "Int, e.g. 2. Number of negative samples to draw per positive sample. Leave blank to use builtin split in original dataset",
                },
            }
            comment_dict = merge_comment(output_cfg, comment_dict)

            if cfg is None:
                cfg = (
                    "_".join(
                        [
                            "linkpred",
                            data.value,
                            node_model.value,
                            edge_model.value,
                        ]
                    )
                    + ".yaml"
                )
            yaml = ruamel.yaml.YAML()
            yaml.dump(comment_dict, Path(cfg).open("w"))
            print(
                "Configuration file is generated at {}".format(
                    Path(cfg).absolute()
                )
            )

        return config
@@ -123,18 +161,33 @@ class LinkpredPipeline(PipelineBase):
        render_cfg = copy.deepcopy(user_cfg_dict)
        render_cfg["node_model_code"] = NodeModelFactory.get_source_code(
            user_cfg_dict["node_model"]["name"]
        )
        render_cfg["edge_model_code"] = EdgeModelFactory.get_source_code(
            user_cfg_dict["edge_model"]["name"]
        )
        render_cfg[
            "node_model_class_name"
        ] = NodeModelFactory.get_model_class_name(
            user_cfg_dict["node_model"]["name"]
        )
        render_cfg[
            "edge_model_class_name"
        ] = EdgeModelFactory.get_model_class_name(
            user_cfg_dict["edge_model"]["name"]
        )
        render_cfg[
            "neg_sampler_name"
        ] = NegativeSamplerFactory.get_model_class_name(
            user_cfg_dict["neg_sampler"]["name"]
        )
        render_cfg["loss"] = user_cfg_dict["general_pipeline"]["loss"]
        # update import and initialization code
        render_cfg.update(
            DataFactory.get_generated_code_dict(
                user_cfg_dict["data"]["name"], '**cfg["data"]'
            )
        )
        generated_user_cfg = copy.deepcopy(user_cfg_dict)
        if len(generated_user_cfg["data"]) == 1:
            generated_user_cfg.pop("data")
@@ -150,10 +203,17 @@ class LinkpredPipeline(PipelineBase):
        generated_train_cfg = copy.deepcopy(user_cfg_dict["general_pipeline"])
        generated_train_cfg["optimizer"].pop("name")

        if user_cfg_dict["data"].get("split_ratio", None) is not None:
            assert (
                user_cfg_dict["data"].get("neg_ratio", None) is not None
            ), "Please specify both split_ratio and neg_ratio"
            render_cfg[
                "data_initialize_code"
            ] = "{}, split_ratio={}, neg_ratio={}".format(
                render_cfg["data_initialize_code"],
                user_cfg_dict["data"]["split_ratio"],
                user_cfg_dict["data"]["neg_ratio"],
            )
            generated_user_cfg["data"].pop("split_ratio")
            generated_user_cfg["data"].pop("neg_ratio")
......
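The split_ratio/neg_ratio branch above splices keyword arguments into the generated initialization code as plain strings; an illustration with hypothetical values (`CoraGraphDataset()` is the class_name registered for `cora` further below):

data_initialize_code = "CoraGraphDataset()"
code = "{}, split_ratio={}, neg_ratio={}".format(
    data_initialize_code, [0.8, 0.1, 0.1], 2
)
assert code == "CoraGraphDataset(), split_ratio=[0.8, 0.1, 0.1], neg_ratio=2"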
import copy
from pathlib import Path
from typing import Optional

import ruamel.yaml
import typer
import yaml
from jinja2 import Template
from pydantic import BaseModel, Field
from ruamel.yaml.comments import CommentedMap

from ...utils.base_model import DeviceEnum, EarlyStopConfig
from ...utils.factory import (
    DataFactory,
    NodeModelFactory,
    PipelineBase,
    PipelineFactory,
)
from ...utils.yaml_dump import deep_convert_dict, merge_comment

pipeline_comments = {
    "num_epochs": "Number of training epochs",
    "eval_period": "Interval epochs between evaluations",
    "early_stop": {
        "patience": "Steps before early stop",
        "checkpoint_path": "Early stop checkpoint model file path",
    },
    "save_path": "Directory to save the experiment results",
    "num_runs": "Number of experiments to run",
}


class NodepredPipelineCfg(BaseModel):
    early_stop: Optional[EarlyStopConfig] = EarlyStopConfig()
    num_epochs: int = 200
@@ -31,23 +39,26 @@ class NodepredPipelineCfg(BaseModel):
    save_path: str = "results"
    num_runs: int = 1


@PipelineFactory.register("nodepred")
class NodepredPipeline(PipelineBase):
    user_cfg_cls = None

    def __init__(self):
        self.pipeline = {"name": "nodepred", "mode": "train"}

    @classmethod
    def setup_user_cfg_cls(cls):
        from ...utils.enter_config import UserConfig

        class NodePredUserConfig(UserConfig):
            data: DataFactory.filter("nodepred").get_pydantic_config() = Field(
                ..., discriminator="name"
            )
            model: NodeModelFactory.get_pydantic_model_config() = Field(
                ..., discriminator="name"
            )
            general_pipeline: NodepredPipelineCfg = NodepredPipelineCfg()

        cls.user_cfg_cls = NodePredUserConfig
@@ -58,10 +69,15 @@ class NodepredPipeline(PipelineBase):
    def get_cfg_func(self):
        def config(
            data: DataFactory.filter(
                "nodepred"
            ).get_dataset_enum() = typer.Option(..., help="input data name"),
            cfg: Optional[str] = typer.Option(
                None, help="output configuration path"
            ),
            model: NodeModelFactory.get_model_enum() = typer.Option(
                ..., help="Model name"
            ),
        ):
            self.__class__.setup_user_cfg_cls()
            generated_cfg = {
@@ -70,17 +86,17 @@ class NodepredPipeline(PipelineBase):
                "device": "cpu",
                "data": {"name": data.name},
                "model": {"name": model.value},
                "general_pipeline": {},
            }
            output_cfg = self.user_cfg_cls(**generated_cfg).dict()
            output_cfg = deep_convert_dict(output_cfg)
            comment_dict = {
                "device": "Torch device name, e.g., cpu or cuda or cuda:0",
                "data": {
                    "split_ratio": "Ratio to generate split masks, for example set to [0.8, 0.1, 0.1] for 80% train/10% val/10% test. Leave blank to use builtin split in original dataset"
                },
                "general_pipeline": pipeline_comments,
                "model": NodeModelFactory.get_constructor_doc_dict(model.value),
            }
            comment_dict = merge_comment(output_cfg, comment_dict)
@@ -88,7 +104,11 @@ class NodepredPipeline(PipelineBase):
            if cfg is None:
                cfg = "_".join(["nodepred", data.value, model.value]) + ".yaml"
            yaml.dump(comment_dict, Path(cfg).open("w"))
            print(
                "Configuration file is generated at {}".format(
                    Path(cfg).absolute()
                )
            )

        return config
@@ -103,11 +123,17 @@ class NodepredPipeline(PipelineBase):
        render_cfg = copy.deepcopy(user_cfg_dict)
        model_code = NodeModelFactory.get_source_code(
            user_cfg_dict["model"]["name"]
        )
        render_cfg["model_code"] = model_code
        render_cfg["model_class_name"] = NodeModelFactory.get_model_class_name(
            user_cfg_dict["model"]["name"]
        )
        render_cfg.update(
            DataFactory.get_generated_code_dict(
                user_cfg_dict["data"]["name"], '**cfg["data"]'
            )
        )

        generated_user_cfg = copy.deepcopy(user_cfg_dict)
        if "split_ratio" in generated_user_cfg["data"]:
@@ -115,15 +141,19 @@ class NodepredPipeline(PipelineBase):
        generated_user_cfg["data_name"] = generated_user_cfg["data"].pop("name")
        generated_user_cfg.pop("pipeline_name")
        generated_user_cfg.pop("pipeline_mode")
        generated_user_cfg["model_name"] = generated_user_cfg["model"].pop(
            "name"
        )
        generated_user_cfg["general_pipeline"]["optimizer"].pop("name")

        generated_train_cfg = copy.deepcopy(user_cfg_dict["general_pipeline"])
        generated_train_cfg["optimizer"].pop("name")

        if user_cfg_dict["data"].get("split_ratio", None) is not None:
            render_cfg["data_initialize_code"] = "{}, split_ratio={}".format(
                render_cfg["data_initialize_code"],
                user_cfg_dict["data"]["split_ratio"],
            )
        render_cfg["user_cfg_str"] = f"cfg = {str(generated_user_cfg)}"
        render_cfg["user_cfg"] = user_cfg_dict
        return template.render(**render_cfg)
......
import copy
from enum import Enum
from pathlib import Path
from typing import List, Optional, Union

import ruamel.yaml
import typer
import yaml
from jinja2 import ext, Template
from pydantic import BaseModel, Field
from ruamel.yaml.comments import CommentedMap
from typing_extensions import Literal

from ...utils.base_model import DeviceEnum, EarlyStopConfig, extract_name
from ...utils.factory import (
    DataFactory,
    NodeModelFactory,
    PipelineBase,
    PipelineFactory,
)
from ...utils.yaml_dump import deep_convert_dict, merge_comment


class SamplerConfig(BaseModel):
@@ -25,8 +32,7 @@ class SamplerConfig(BaseModel):
    eval_num_workers: int = 4

    class Config:
        extra = "forbid"


pipeline_comments = {
@@ -34,19 +40,20 @@ pipeline_comments = {
    "eval_period": "Interval epochs between evaluations",
    "early_stop": {
        "patience": "Steps before early stop",
        "checkpoint_path": "Early stop checkpoint model file path",
    },
    "sampler": {
        "fan_out": "List of neighbors to sample per edge type for each GNN layer, with the i-th element being the fanout for the i-th GNN layer. Length should be the same as num_layers in model setting",
        "batch_size": "Batch size of seed nodes in training stage",
        "num_workers": "Number of workers to accelerate the graph data processing step",
        "eval_batch_size": "Batch size of seed nodes in evaluation stage",
        "eval_num_workers": "Number of workers to accelerate the graph data processing step in evaluation stage",
    },
    "save_path": "Directory to save the experiment results",
    "num_runs": "Number of experiments to run",
}


class NodepredNSPipelineCfg(BaseModel):
    sampler: SamplerConfig = Field("neighbor")
    early_stop: Optional[EarlyStopConfig] = EarlyStopConfig()
@@ -57,22 +64,25 @@ class NodepredNSPipelineCfg(BaseModel):
    num_runs: int = 1
    save_path: str = "results"


@PipelineFactory.register("nodepred-ns")
class NodepredNsPipeline(PipelineBase):
    def __init__(self):
        self.pipeline = {"name": "nodepred-ns", "mode": "train"}
        self.default_cfg = None

    @classmethod
    def setup_user_cfg_cls(cls):
        from ...utils.enter_config import UserConfig

        class NodePredUserConfig(UserConfig):
            eval_device: DeviceEnum = Field("cpu")
            data: DataFactory.filter(
                "nodepred-ns"
            ).get_pydantic_config() = Field(..., discriminator="name")
            model: NodeModelFactory.filter(
                lambda cls: hasattr(cls, "forward_block")
            ).get_pydantic_model_config() = Field(..., discriminator="name")
            general_pipeline: NodepredNSPipelineCfg

        cls.user_cfg_cls = NodePredUserConfig
@@ -83,10 +93,15 @@ class NodepredNsPipeline(PipelineBase):
    def get_cfg_func(self):
        def config(
            data: DataFactory.filter(
                "nodepred-ns"
            ).get_dataset_enum() = typer.Option(..., help="input data name"),
            cfg: Optional[str] = typer.Option(
                None, help="output configuration path"
            ),
            model: NodeModelFactory.filter(
                lambda cls: hasattr(cls, "forward_block")
            ).get_model_enum() = typer.Option(..., help="Model name"),
        ):
            self.__class__.setup_user_cfg_cls()
            generated_cfg = {
@@ -95,14 +110,14 @@ class NodepredNsPipeline(PipelineBase):
                "device": "cpu",
                "data": {"name": data.name},
                "model": {"name": model.value},
                "general_pipeline": {"sampler": {"name": "neighbor"}},
            }
            output_cfg = self.user_cfg_cls(**generated_cfg).dict()
            output_cfg = deep_convert_dict(output_cfg)
            comment_dict = {
                "device": "Torch device name, e.g., cpu or cuda or cuda:0",
                "data": {
                    "split_ratio": "Ratio to generate split masks, for example set to [0.8, 0.1, 0.1] for 80% train/10% val/10% test. Leave blank to use builtin split in original dataset"
                },
                "general_pipeline": pipeline_comments,
                "model": NodeModelFactory.get_constructor_doc_dict(model.value),
@@ -111,14 +126,25 @@ class NodepredNsPipeline(PipelineBase):
            # truncate length fan_out to be the same as num_layers in model
            if "num_layers" in comment_dict["model"]:
                comment_dict["general_pipeline"]["sampler"]["fan_out"] = [
                    5,
                    10,
                    15,
                    15,
                    15,
                ][: int(comment_dict["model"]["num_layers"])]

            if cfg is None:
                cfg = (
                    "_".join(["nodepred-ns", data.value, model.value]) + ".yaml"
                )
            yaml = ruamel.yaml.YAML()
            yaml.dump(comment_dict, Path(cfg).open("w"))
            print(
                "Configuration file is generated at {}".format(
                    Path(cfg).absolute()
                )
            )

        return config
@@ -129,20 +155,27 @@ class NodepredNsPipeline(PipelineBase):
        with open(template_filename, "r") as f:
            template = Template(f.read())
        pipeline_cfg = NodepredNSPipelineCfg(
            **user_cfg_dict["general_pipeline"]
        )
        if "num_layers" in user_cfg_dict["model"]:
            assert user_cfg_dict["model"]["num_layers"] == len(
                user_cfg_dict["general_pipeline"]["sampler"]["fan_out"]
            ), "The num_layers in model config should be the same as the length of fan_out in sampler. For example, if num_layers is 1, the fan_out cannot be [5, 10]"
        render_cfg = copy.deepcopy(user_cfg_dict)
        model_code = NodeModelFactory.get_source_code(
            user_cfg_dict["model"]["name"]
        )
        render_cfg["model_code"] = model_code
        render_cfg["model_class_name"] = NodeModelFactory.get_model_class_name(
            user_cfg_dict["model"]["name"]
        )
        render_cfg.update(
            DataFactory.get_generated_code_dict(
                user_cfg_dict["data"]["name"], '**cfg["data"]'
            )
        )

        generated_user_cfg = copy.deepcopy(user_cfg_dict)
        if "split_ratio" in generated_user_cfg["data"]:
@@ -150,12 +183,16 @@ class NodepredNsPipeline(PipelineBase):
        generated_user_cfg["data_name"] = generated_user_cfg["data"].pop("name")
        generated_user_cfg.pop("pipeline_name")
        generated_user_cfg.pop("pipeline_mode")
        generated_user_cfg["model_name"] = generated_user_cfg["model"].pop(
            "name"
        )
        generated_user_cfg["general_pipeline"]["optimizer"].pop("name")

        if user_cfg_dict["data"].get("split_ratio", None) is not None:
            render_cfg["data_initialize_code"] = "{}, split_ratio={}".format(
                render_cfg["data_initialize_code"],
                user_cfg_dict["data"]["split_ratio"],
            )
        render_cfg["user_cfg_str"] = f"cfg = {str(generated_user_cfg)}"
        render_cfg["user_cfg"] = user_cfg_dict
......
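The fan-out handling in this pipeline is two-sided: config generation truncates the default list to num_layers, and code generation asserts the lengths match. The truncation, worked through for a hypothetical 2-layer model:

fan_out = [5, 10, 15, 15, 15][:2]  # default list cut to num_layers == 2
assert fan_out == [5, 10]  # satisfies the length assert above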
import copy
import enum
from enum import Enum, IntEnum
from typing import Optional

from jinja2 import Template
from pydantic import (
    BaseModel as PydanticBaseModel,
    create_model,
    create_model,
    Field,
)


class DeviceEnum(str, Enum):
    cpu = "cpu"
    cuda = "cuda"


class DGLBaseModel(PydanticBaseModel):
    class Config:
        extra = "allow"
@@ -27,14 +34,16 @@ def get_literal_value(type_):
    name = type_.__args__[0]
    return name


def extract_name(union_type):
    name_dict = {}
    for t in union_type.__args__:
        type_ = t.__fields__["name"].type_
        name = get_literal_value(type_)
        name_dict[name] = name
    return enum.Enum("Choice", name_dict)


class EarlyStopConfig(DGLBaseModel):
    patience: int = 20
    checkpoint_path: str = "checkpoint.pth"
import torch


class EarlyStopping:
    def __init__(
        self, patience: int = -1, checkpoint_path: str = "checkpoint.pth"
    ):
        self.patience = patience
        self.checkpoint_path = checkpoint_path
        self.counter = 0
@@ -17,7 +18,9 @@ class EarlyStopping:
            self.save_checkpoint(model)
        elif score < self.best_score:
            self.counter += 1
            print(
                f"EarlyStopping counter: {self.counter} out of {self.patience}"
            )
            if self.counter >= self.patience:
                self.early_stop = True
        else:
@@ -27,7 +30,7 @@ class EarlyStopping:
        return self.early_stop

    def save_checkpoint(self, model):
        """Save model when validation loss decreases."""
        torch.save(model.state_dict(), self.checkpoint_path)

    def load_checkpoint(self, model):
......
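A hypothetical usage sketch for EarlyStopping. The scoring method's name is elided in this hunk; it is assumed here to be `step(score, model)` returning `self.early_stop`, consistent with the visible counter logic and `return self.early_stop`:

import torch.nn as nn

model = nn.Linear(4, 2)  # dummy model for the sketch
stopper = EarlyStopping(patience=2, checkpoint_path="checkpoint.pth")
for score in [0.5, 0.6, 0.55, 0.54]:  # per-epoch validation scores
    if stopper.step(score, model):  # assumed method name
        break  # two non-improving epochs exhaust the patience
stopper.load_checkpoint(model)  # reload the best-scoring weights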
import copy
from enum import Enum, IntEnum
from typing import Optional

import jinja2
import yaml
from jinja2 import Template
from pydantic import BaseModel as PydanticBaseModel, create_model, Field

from .base_model import DGLBaseModel

# from ..pipeline import nodepred, nodepred_sample
from .factory import DataFactory, ModelFactory, PipelineFactory


class PipelineConfig(DGLBaseModel):
@@ -22,6 +21,7 @@ class PipelineConfig(DGLBaseModel):
    optimizer: dict = {"name": "Adam", "lr": 0.005}
    loss: str = "CrossEntropyLoss"


class UserConfig(DGLBaseModel):
    version: Optional[str] = "0.0.2"
    pipeline_name: PipelineFactory.get_pipeline_enum()
......
import enum import enum
import inspect
import logging import logging
from typing import Callable, Dict, Union, List, Tuple, Optional
from typing_extensions import Literal
from pathlib import Path
from abc import ABC, abstractmethod, abstractstaticmethod from abc import ABC, abstractmethod, abstractstaticmethod
from .base_model import DGLBaseModel from pathlib import Path
from typing import Callable, Dict, List, Optional, Tuple, Union
import yaml import yaml
from pydantic import create_model_from_typeddict, create_model, Field
from dgl.dataloading.negative_sampler import GlobalUniform, PerSourceUniform from dgl.dataloading.negative_sampler import GlobalUniform, PerSourceUniform
import inspect
from numpydoc import docscrape from numpydoc import docscrape
from pydantic import create_model, create_model_from_typeddict, Field
from typing_extensions import Literal
from .base_model import DGLBaseModel
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
ALL_PIPELINE = ["nodepred", "nodepred-ns", "linkpred", "graphpred"] ALL_PIPELINE = ["nodepred", "nodepred-ns", "linkpred", "graphpred"]
class PipelineBase(ABC):
class PipelineBase(ABC):
@abstractmethod @abstractmethod
def __init__(self) -> None: def __init__(self) -> None:
super().__init__() super().__init__()
...@@ -34,23 +37,24 @@ class PipelineBase(ABC): ...@@ -34,23 +37,24 @@ class PipelineBase(ABC):
class DataFactoryClass: class DataFactoryClass:
def __init__(self): def __init__(self):
self.registry = {} self.registry = {}
self.pipeline_name = None self.pipeline_name = None
self.pipeline_allowed = {} self.pipeline_allowed = {}
def register(self, def register(
self,
name: str, name: str,
import_code: str, import_code: str,
class_name: str, class_name: str,
allowed_pipeline: List[str], allowed_pipeline: List[str],
extra_args={}): extra_args={},
):
self.registry[name] = { self.registry[name] = {
"name": name, "name": name,
"import_code": import_code, "import_code": import_code,
"class_name": class_name, "class_name": class_name,
"extra_args": extra_args "extra_args": extra_args,
} }
for pipeline in allowed_pipeline: for pipeline in allowed_pipeline:
if pipeline in self.pipeline_allowed: if pipeline in self.pipeline_allowed:
...@@ -61,7 +65,8 @@ class DataFactoryClass: ...@@ -61,7 +65,8 @@ class DataFactoryClass:
def get_dataset_enum(self): def get_dataset_enum(self):
enum_class = enum.Enum( enum_class = enum.Enum(
"DatasetName", {v["name"]: k for k, v in self.registry.items()}) "DatasetName", {v["name"]: k for k, v in self.registry.items()}
)
return enum_class return enum_class
def get_dataset_classname(self, name): def get_dataset_classname(self, name):
...@@ -84,8 +89,13 @@ class DataFactoryClass: ...@@ -84,8 +89,13 @@ class DataFactoryClass:
if "name" in type_annotation_dict: if "name" in type_annotation_dict:
del type_annotation_dict["name"] del type_annotation_dict["name"]
base = self.get_base_class(dataset_name, self.pipeline_name) base = self.get_base_class(dataset_name, self.pipeline_name)
dataset_list.append(create_model( dataset_list.append(
f'{dataset_name}Config', **type_annotation_dict, __base__=base)) create_model(
f"{dataset_name}Config",
**type_annotation_dict,
__base__=base,
)
)
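        # Illustrative sketch (not in the source): for dataset_name "cora"
        # the append above builds a pydantic model roughly equivalent to
        #     create_model("coraConfig", **fields, __base__=NodeBase)
        # the per-dataset models are presumably combined into a single
        # config type in the elided loop body below.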
        output = dataset_list[0]
        for d in dataset_list[1:]:
@@ -116,7 +126,9 @@ class DataFactoryClass:
    def filter(self, pipeline_name):
        allowed_name = self.pipeline_allowed[pipeline_name]
        new_registry = {
            k: v for k, v in self.registry.items() if k in allowed_name
        }
        d = DataFactoryClass()
        d.registry = new_registry
        d.pipeline_name = pipeline_name
@@ -125,18 +137,20 @@ class DataFactoryClass:
    @staticmethod
    def get_base_class(dataset_name, pipeline_name):
        if pipeline_name == "linkpred":

            class EdgeBase(DGLBaseModel):
                name: Literal[dataset_name]
                split_ratio: Optional[Tuple[float, float, float]] = None
                neg_ratio: Optional[int] = None

            return EdgeBase
        else:

            class NodeBase(DGLBaseModel):
                name: Literal[dataset_name]
                split_ratio: Optional[Tuple[float, float, float]] = None

            return NodeBase
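# Illustrative sketch (not in the source): for a non-linkpred pipeline and
# dataset_name "cora", the generated base validates configs such as
#     NodeBase(name="cora", split_ratio=(0.8, 0.1, 0.1))
# while the linkpred variant additionally accepts an optional neg_ratio.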
DataFactory = DataFactoryClass()
@@ -145,83 +159,96 @@ DataFactory.register(
    "cora",
    import_code="from dgl.data import CoraGraphDataset",
    class_name="CoraGraphDataset()",
    allowed_pipeline=["nodepred", "nodepred-ns", "linkpred"],
)
DataFactory.register(
    "citeseer",
    import_code="from dgl.data import CiteseerGraphDataset",
    class_name="CiteseerGraphDataset()",
    allowed_pipeline=["nodepred", "nodepred-ns", "linkpred"],
)
DataFactory.register(
    "pubmed",
    import_code="from dgl.data import PubmedGraphDataset",
    class_name="PubmedGraphDataset()",
    allowed_pipeline=["nodepred", "nodepred-ns", "linkpred"],
)
DataFactory.register(
    "csv",
    import_code="from dgl.data import CSVDataset",
    extra_args={"data_path": "./"},
    class_name="CSVDataset({})",
    allowed_pipeline=["nodepred", "nodepred-ns", "linkpred", "graphpred"],
)
DataFactory.register(
    "reddit",
    import_code="from dgl.data import RedditDataset",
    class_name="RedditDataset()",
    allowed_pipeline=["nodepred", "nodepred-ns", "linkpred"],
)
DataFactory.register(
    "co-buy-computer",
    import_code="from dgl.data import AmazonCoBuyComputerDataset",
    class_name="AmazonCoBuyComputerDataset()",
    allowed_pipeline=["nodepred", "nodepred-ns", "linkpred"],
)
DataFactory.register(
    "ogbn-arxiv",
    import_code="from ogb.nodeproppred import DglNodePropPredDataset",
    extra_args={},
    class_name="DglNodePropPredDataset('ogbn-arxiv')",
    allowed_pipeline=["nodepred", "nodepred-ns", "linkpred"],
)
DataFactory.register(
    "ogbn-products",
    import_code="from ogb.nodeproppred import DglNodePropPredDataset",
    extra_args={},
    class_name="DglNodePropPredDataset('ogbn-products')",
    allowed_pipeline=["nodepred", "nodepred-ns", "linkpred"],
)
DataFactory.register(
    "ogbl-collab",
    import_code="from ogb.linkproppred import DglLinkPropPredDataset",
    extra_args={},
    class_name="DglLinkPropPredDataset('ogbl-collab')",
    allowed_pipeline=["linkpred"],
)
DataFactory.register(
    "ogbl-citation2",
    import_code="from ogb.linkproppred import DglLinkPropPredDataset",
    extra_args={},
    class_name="DglLinkPropPredDataset('ogbl-citation2')",
    allowed_pipeline=["linkpred"],
)
DataFactory.register(
    "ogbg-molhiv",
    import_code="from ogb.graphproppred import DglGraphPropPredDataset",
    extra_args={},
    class_name="DglGraphPropPredDataset(name='ogbg-molhiv')",
    allowed_pipeline=["graphpred"],
)
DataFactory.register(
    "ogbg-molpcba",
    import_code="from ogb.graphproppred import DglGraphPropPredDataset",
    extra_args={},
    class_name="DglGraphPropPredDataset(name='ogbg-molpcba')",
    allowed_pipeline=["graphpred"],
)
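# Illustrative sketch (not in the source): third-party datasets can be
# registered with the same signature; the name and import below are
# hypothetical.
# DataFactory.register(
#     "my-dataset",
#     import_code="from my_pkg import MyDataset",
#     class_name="MyDataset()",
#     allowed_pipeline=["nodepred"],
# )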
class PipelineFactory:
    """The factory class for creating executors"""

    registry: Dict[str, PipelineBase] = {}
    default_config_registry = {}
@@ -229,11 +256,11 @@ class PipelineFactory:
    @classmethod
    def register(cls, name: str) -> Callable:
        def inner_wrapper(wrapped_class) -> Callable:
            if name in cls.registry:
                logger.warning(
                    "Executor %s already exists. Will replace it", name
                )
            cls.registry[name] = wrapped_class()
            return wrapped_class
@@ -241,19 +268,23 @@ class PipelineFactory:
    @classmethod
    def register_default_config_generator(cls, name: str) -> Callable:
        def inner_wrapper(wrapped_class) -> Callable:
            if name in cls.registry:
                logger.warning(
                    "Executor %s already exists. Will replace it", name
                )
            cls.default_config_registry[name] = wrapped_class
            return wrapped_class

        return inner_wrapper

    @classmethod
    def call_default_config_generator(
        cls, generator_name, model_name, dataset_name
    ):
        return cls.default_config_registry[generator_name](
            model_name, dataset_name
        )

    @classmethod
    def call_generator(cls, generator_name, cfg):
@@ -262,9 +293,11 @@ class PipelineFactory:
    @classmethod
    def get_pipeline_enum(cls):
        enum_class = enum.Enum(
            "PipelineName", {k: k for k, v in cls.registry.items()}
        )
        return enum_class
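# Illustrative sketch (not in the source): pipelines self-register through
# the decorator, which stores an *instance* of the wrapped class; the
# pipeline class below is hypothetical.
# @PipelineFactory.register("nodepred")
# class NodepredPipeline(PipelineBase):
#     def __init__(self) -> None:
#         super().__init__()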
class ApplyPipelineFactory:
    """The factory class for creating executors for inference"""
@@ -273,38 +306,41 @@ class ApplyPipelineFactory:
    @classmethod
    def register(cls, name: str) -> Callable:
        def inner_wrapper(wrapped_class) -> Callable:
            if name in cls.registry:
                logger.warning(
                    "Executor %s already exists. Will replace it", name
                )
            cls.registry[name] = wrapped_class()
            return wrapped_class

        return inner_wrapper
model_dir = Path(__file__).parent.parent / "model"


class ModelFactory:
    """The factory class for creating executors"""

    def __init__(self):
        self.registry = {}
        self.code_registry = {}
        """ Internal registry for available executors """

    def get_model_enum(self):
        enum_class = enum.Enum(
            "ModelName", {k: k for k, v in self.registry.items()}
        )
        return enum_class

    def register(self, model_name: str) -> Callable:
        def inner_wrapper(wrapped_class) -> Callable:
            if model_name in self.registry:
                logger.warning(
                    "Executor %s already exists. Will replace it", model_name
                )
            self.registry[model_name] = wrapped_class
            # code_filename = model_dir / filename
            code_filename = Path(inspect.getfile(wrapped_class))
@@ -328,14 +364,19 @@ class ModelFactory:
        arg_dict = self.get_constructor_default_args(model_name)
        type_annotation_dict = {}
        # type_annotation_dict["name"] = Literal[""]
        exempt_keys = ["self", "in_size", "out_size", "data_info"]
        for k, param in arg_dict.items():
            if k not in exempt_keys:
                type_annotation_dict[k] = arg_dict[k]

        class Base(DGLBaseModel):
            name: Literal[model_name]

        return create_model(
            f"{model_name.upper()}ModelConfig",
            **type_annotation_dict,
            __base__=Base,
        )
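    # Illustrative sketch (not in the source): for a model registered as
    # "gcn", the call above yields a "GCNModelConfig" pydantic model whose
    # `name` field only accepts the literal "gcn", with one field per
    # remaining constructor argument, its type inferred from the default.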
    def get_constructor_doc_dict(self, name):
        model_class = self.registry[name]
@@ -375,22 +416,23 @@ class ModelFactory:
class SamplerFactory:
    """The factory class for creating executors"""

    def __init__(self):
        self.registry = {}

    def get_model_enum(self):
        enum_class = enum.Enum(
            "NegativeSamplerName", {k: k for k, v in self.registry.items()}
        )
        return enum_class

    def register(self, sampler_name: str) -> Callable:
        def inner_wrapper(wrapped_class) -> Callable:
            if sampler_name in self.registry:
                logger.warning(
                    "Sampler %s already exists. Will replace it", sampler_name
                )
            self.registry[sampler_name] = wrapped_class
            return wrapped_class
@@ -408,17 +450,22 @@ class SamplerFactory:
        arg_dict = self.get_constructor_default_args(sampler_name)
        type_annotation_dict = {}
        # type_annotation_dict["name"] = Literal[""]
        exempt_keys = ["self", "in_size", "out_size", "redundancy"]
        for k, param in arg_dict.items():
            if k not in exempt_keys or param is None:
                if k == "k" or k == "redundancy":
                    type_annotation_dict[k] = 3
                else:
                    type_annotation_dict[k] = arg_dict[k]

        class Base(DGLBaseModel):
            name: Literal[sampler_name]

        return create_model(
            f"{sampler_name.upper()}SamplerConfig",
            **type_annotation_dict,
            __base__=Base,
        )
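    # Illustrative sketch (not in the source): for a sampler registered as
    # "uniform" (hypothetical name), the call above yields a
    # "UNIFORMSamplerConfig" model whose `name` must equal "uniform" and
    # whose `k` field defaults to 3 per the special-casing above.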
    def get_pydantic_model_config(self):
        model_list = []
......