Unverified Commit dce89919 authored by Hongzhi (Steve), Chen, committed by GitHub

[Misc] Auto-reformat multiple python folders. (#5325)



* auto-reformat

* lintrunner

---------
Co-authored-by: Ubuntu <ubuntu@ip-172-31-28-63.ap-northeast-1.compute.internal>
parent ab812179
from ...utils.factory import NodeModelFactory
from .gat import GAT
from .gcn import GCN
from .gin import GIN
from .sage import GraphSAGE
from .sgc import SGC

NodeModelFactory.register("gcn")(GCN)
NodeModelFactory.register("gat")(GAT)
......
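The registrations above are how each built-in model becomes addressable by name. A hedged sketch of extending the registry with a custom model (MyGNN and its constructor convention are illustrative assumptions, not part of this commit):

import torch.nn as nn

class MyGNN(nn.Module):
    # Follows the same assumed convention as the built-in models:
    # data_info carries "in_size" and "out_size".
    def __init__(self, data_info: dict, embed_size: int = -1):
        super().__init__()
        self.linear = nn.Linear(data_info["in_size"], data_info["out_size"])

    def forward(self, graph, node_feat, edge_feat=None):
        # Graph structure is ignored in this toy model.
        return self.linear(node_feat)

# Same call pattern as the registrations above.
NodeModelFactory.register("my_gnn")(MyGNN)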
from typing import List

import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl.base import dgl_warning
from dgl.nn import GATConv


class GAT(nn.Module):
    def __init__(
        self,
        data_info: dict,
        embed_size: int = -1,
        num_layers: int = 2,
@@ -17,7 +19,8 @@ class GAT(nn.Module):
        feat_drop: float = 0.6,
        attn_drop: float = 0.6,
        negative_slope: float = 0.2,
        residual: bool = False,
    ):
        """Graph Attention Networks

        Parameters
@@ -57,19 +60,30 @@ class GAT(nn.Module):
        in_size = data_info["in_size"]
        for i in range(num_layers):
            in_hidden = hidden_size * heads[i - 1] if i > 0 else in_size
            out_hidden = (
                hidden_size if i < num_layers - 1 else data_info["out_size"]
            )
            activation = None if i == num_layers - 1 else self.activation
            self.gat_layers.append(
                GATConv(
                    in_hidden,
                    out_hidden,
                    heads[i],
                    feat_drop,
                    attn_drop,
                    negative_slope,
                    residual,
                    activation,
                )
            )

    def forward(self, graph, node_feat, edge_feat=None):
        if self.embed_size > 0:
            dgl_warning(
                "The embedding for node feature is used, and input node_feat is ignored, due to the provided embed_size."
            )
            h = self.embed.weight
        else:
            h = node_feat
......
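A hedged usage sketch for the GAT model above (hidden_size and heads are elided from this hunk, so their names and the output shape are assumptions):

import dgl
import torch

g = dgl.add_self_loop(dgl.rand_graph(100, 500))  # self-loops avoid zero-in-degree nodes
feat = torch.randn(100, 16)
model = GAT(
    data_info={"in_size": 16, "out_size": 7},
    num_layers=2,
    hidden_size=8,
    heads=[8, 1],
)
logits = model(g, feat)  # per-node class scores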
import torch
import torch.nn as nn

import dgl
from dgl.base import dgl_warning


class GCN(nn.Module):
    def __init__(
        self,
        data_info: dict,
        embed_size: int = -1,
        hidden_size: int = 16,
@@ -12,7 +14,8 @@ class GCN(nn.Module):
        norm: str = "both",
        activation: str = "relu",
        dropout: float = 0.5,
        use_edge_weight: bool = False,
    ):
        """Graph Convolutional Networks

        Parameters
@@ -47,28 +50,36 @@ class GCN(nn.Module):
        for i in range(num_layers):
            in_hidden = hidden_size if i > 0 else in_size
            out_hidden = (
                hidden_size if i < num_layers - 1 else data_info["out_size"]
            )
            self.layers.append(
                dgl.nn.GraphConv(
                    in_hidden, out_hidden, norm=norm, allow_zero_in_degree=True
                )
            )
        self.dropout = nn.Dropout(p=dropout)
        self.act = getattr(torch, activation)

    def forward(self, g, node_feat, edge_feat=None):
        if self.embed_size > 0:
            dgl_warning(
                "The embedding for node feature is used, and input node_feat is ignored, due to the provided embed_size."
            )
            h = self.embed.weight
        else:
            h = node_feat
        edge_weight = edge_feat if self.use_edge_weight else None
        for l, layer in enumerate(self.layers):
            h = layer(g, h, edge_weight=edge_weight)
            if l != len(self.layers) - 1:
                h = self.act(h)
                h = self.dropout(h)
        return h

    def forward_block(self, blocks, node_feat, edge_feat=None):
        h = node_feat
        edge_weight = edge_feat if self.use_edge_weight else None
        for l, (layer, block) in enumerate(zip(self.layers, blocks)):
......
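A similar sketch for the GCN above, exercising the use_edge_weight path (sizes are illustrative; num_layers is elided from this hunk and left at its default):

import dgl
import torch

g = dgl.rand_graph(50, 200)  # allow_zero_in_degree=True makes self-loops optional
feat = torch.randn(50, 8)
eweight = torch.rand(200)  # one scalar weight per edge
model = GCN(
    data_info={"in_size": 8, "out_size": 3},
    hidden_size=16,
    use_edge_weight=True,
)
out = model(g, feat, edge_feat=eweight)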
import torch.nn as nn
from dgl.base import dgl_warning
from dgl.nn import GINConv


class GIN(nn.Module):
    def __init__(
        self,
        data_info: dict,
        embed_size: int = -1,
        hidden_size=64,
        num_layers=3,
        aggregator_type="sum",
    ):
        """Graph Isomorphism Networks

        Edge feature is ignored in this model.
@@ -39,9 +41,13 @@ class GIN(nn.Module):
        in_size = data_info["in_size"]
        for i in range(num_layers):
            input_dim = in_size if i == 0 else hidden_size
            mlp = nn.Sequential(
                nn.Linear(input_dim, hidden_size),
                nn.BatchNorm1d(hidden_size),
                nn.ReLU(),
                nn.Linear(hidden_size, hidden_size),
                nn.ReLU(),
            )
            self.conv_list.append(GINConv(mlp, aggregator_type, 1e-5, True))
        self.out_mlp = nn.Linear(hidden_size, data_info["out_size"])
@@ -49,7 +55,8 @@ class GIN(nn.Module):
    def forward(self, graph, node_feat, edge_feat=None):
        if self.embed_size > 0:
            dgl_warning(
                "The embedding for node feature is used, and input node_feat is ignored, due to the provided embed_size."
            )
            h = self.embed.weight
        else:
            h = node_feat
......
import torch.nn as nn

import dgl
from dgl.base import dgl_warning


class GraphSAGE(nn.Module):
    def __init__(
        self,
        data_info: dict,
        embed_size: int = -1,
        hidden_size: int = 16,
        num_layers: int = 1,
        activation: str = "relu",
        dropout: float = 0.5,
        aggregator_type: str = "gcn",
    ):
        """GraphSAGE model

        Parameters
@@ -44,12 +47,18 @@ class GraphSAGE(nn.Module):
        for i in range(num_layers):
            in_hidden = hidden_size if i > 0 else in_size
            out_hidden = (
                hidden_size if i < num_layers - 1 else data_info["out_size"]
            )
            self.layers.append(
                dgl.nn.SAGEConv(in_hidden, out_hidden, aggregator_type)
            )

    def forward(self, graph, node_feat, edge_feat=None):
        if self.embed_size > 0:
            dgl_warning(
                "The embedding for node feature is used, and input node_feat is ignored, due to the provided embed_size."
            )
            h = self.embed.weight
        else:
            h = node_feat
@@ -61,7 +70,7 @@ class GraphSAGE(nn.Module):
            h = self.dropout(h)
        return h

    def forward_block(self, blocks, node_feat, edge_feat=None):
        h = node_feat
        for l, (layer, block) in enumerate(zip(self.layers, blocks)):
            h = layer(block, h, edge_feat)
......
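The forward_block variant above consumes message-flow-graph blocks from a neighbor sampler instead of the full graph. A hedged mini-batch sketch using the standard dgl.dataloading API (sizes illustrative; one layer means one block per batch):

import dgl
import torch

g = dgl.rand_graph(1000, 5000)
feat = torch.randn(1000, 8)
model = GraphSAGE(data_info={"in_size": 8, "out_size": 3}, num_layers=1)
sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1)
dataloader = dgl.dataloading.DataLoader(
    g, torch.arange(1000), sampler, batch_size=64, shuffle=True
)
for input_nodes, output_nodes, blocks in dataloader:
    h = model.forward_block(blocks, feat[input_nodes])  # predictions for output_nodes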
import dgl.function as fn
import torch.nn as nn
import torch.nn.functional as F
from dgl.base import dgl_warning
from dgl.nn import SGConv


class SGC(nn.Module):
    def __init__(self, data_info: dict, embed_size: int = -1, bias=True, k=2):
        """Simplifying Graph Convolutional Networks

        Edge feature is ignored in this model.
@@ -34,12 +31,20 @@ class SGC(nn.Module):
            in_size = embed_size
        else:
            in_size = data_info["in_size"]
        self.sgc = SGConv(
            in_size,
            self.out_size,
            k=k,
            cached=True,
            bias=bias,
            norm=self.normalize,
        )

    def forward(self, g, node_feat, edge_feat=None):
        if self.embed_size > 0:
            dgl_warning(
                "The embedding for node feature is used, and input node_feat is ignored, due to the provided embed_size."
            )
            h = self.embed.weight
        else:
            h = node_feat
@@ -47,4 +52,4 @@ class SGC(nn.Module):
    @staticmethod
    def normalize(h):
        return (h - h.mean(0)) / (h.std(0) + 1e-5)
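The normalize hook above standardizes each feature dimension. A quick numeric check of the arithmetic (values illustrative; note torch.std defaults to the unbiased estimator):

import torch

h = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
z = (h - h.mean(0)) / (h.std(0) + 1e-5)
print(z)  # tensor([[-0.7071, -0.7071], [0.7071, 0.7071]]) -- zero mean, ~unit scale per column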
from .graphpred import GraphpredPipeline
from .linkpred import LinkpredPipeline
from .nodepred import NodepredPipeline
from .nodepred_sample import NodepredNsPipeline
import copy
from pathlib import Path
from typing import Optional

import ruamel.yaml
import typer
from jinja2 import Template
from pydantic import BaseModel, Field

from ...utils.factory import (
    DataFactory,
    GraphModelFactory,
    PipelineBase,
    PipelineFactory,
)
from ...utils.yaml_dump import deep_convert_dict, merge_comment

pipeline_comments = {
    "num_runs": "Number of experiments to run",
@@ -14,9 +21,10 @@ pipeline_comments = {
    "eval_batch_size": "Graph batch size when evaluating",
    "num_workers": "Number of workers for data loading",
    "num_epochs": "Number of training epochs",
    "save_path": "Directory to save the experiment results",
}


class GraphpredPipelineCfg(BaseModel):
    num_runs: int = 1
    train_batch_size: int = 32
@@ -30,20 +38,23 @@ class GraphpredPipelineCfg(BaseModel):
    num_epochs: int = 100
    save_path: str = "results"


@PipelineFactory.register("graphpred")
class GraphpredPipeline(PipelineBase):
    def __init__(self):
        self.pipeline = {"name": "graphpred", "mode": "train"}

    @classmethod
    def setup_user_cfg_cls(cls):
        from ...utils.enter_config import UserConfig

        class GraphPredUserConfig(UserConfig):
            data: DataFactory.filter("graphpred").get_pydantic_config() = Field(
                ..., discriminator="name"
            )
            model: GraphModelFactory.get_pydantic_model_config() = Field(
                ..., discriminator="name"
            )
            general_pipeline: GraphpredPipelineCfg = GraphpredPipelineCfg()

        cls.user_cfg_cls = GraphPredUserConfig
@@ -54,10 +65,15 @@ class GraphpredPipeline(PipelineBase):
    def get_cfg_func(self):
        def config(
            data: DataFactory.filter(
                "graphpred"
            ).get_dataset_enum() = typer.Option(..., help="input data name"),
            cfg: Optional[str] = typer.Option(
                None, help="output configuration path"
            ),
            model: GraphModelFactory.get_model_enum() = typer.Option(
                ..., help="Model name"
            ),
        ):
            self.__class__.setup_user_cfg_cls()
            generated_cfg = {
@@ -66,17 +82,19 @@ class GraphpredPipeline(PipelineBase):
                "device": "cpu",
                "data": {"name": data.name},
                "model": {"name": model.value},
                "general_pipeline": {},
            }
            output_cfg = self.user_cfg_cls(**generated_cfg).dict()
            output_cfg = deep_convert_dict(output_cfg)
            comment_dict = {
                "device": "Torch device name, e.g., cpu or cuda or cuda:0",
                "data": {
                    "split_ratio": "Ratio to generate data split, for example set to [0.8, 0.1, 0.1] for 80% train/10% val/10% test. Leave blank to use builtin split in original dataset"
                },
                "general_pipeline": pipeline_comments,
                "model": GraphModelFactory.get_constructor_doc_dict(
                    model.value
                ),
            }
            comment_dict = merge_comment(output_cfg, comment_dict)
@@ -84,7 +102,11 @@ class GraphpredPipeline(PipelineBase):
            if cfg is None:
                cfg = "_".join(["graphpred", data.value, model.value]) + ".yaml"
            yaml.dump(comment_dict, Path(cfg).open("w"))
            print(
                "Configuration file is generated at {}".format(
                    Path(cfg).absolute()
                )
            )

        return config
@@ -97,11 +119,17 @@ class GraphpredPipeline(PipelineBase):
        render_cfg = copy.deepcopy(user_cfg_dict)
        model_code = GraphModelFactory.get_source_code(
            user_cfg_dict["model"]["name"]
        )
        render_cfg["model_code"] = model_code
        render_cfg["model_class_name"] = GraphModelFactory.get_model_class_name(
            user_cfg_dict["model"]["name"]
        )
        render_cfg.update(
            DataFactory.get_generated_code_dict(
                user_cfg_dict["data"]["name"], '**cfg["data"]'
            )
        )

        generated_user_cfg = copy.deepcopy(user_cfg_dict)
        if "split_ratio" in generated_user_cfg["data"]:
@@ -109,7 +137,9 @@ class GraphpredPipeline(PipelineBase):
        generated_user_cfg["data_name"] = generated_user_cfg["data"].pop("name")
        generated_user_cfg.pop("pipeline_name")
        generated_user_cfg.pop("pipeline_mode")
        generated_user_cfg["model_name"] = generated_user_cfg["model"].pop(
            "name"
        )
        generated_user_cfg["general_pipeline"]["optimizer"].pop("name")
        generated_user_cfg["general_pipeline"]["lr_scheduler"].pop("name")
@@ -118,7 +148,10 @@ class GraphpredPipeline(PipelineBase):
        generated_train_cfg["lr_scheduler"].pop("name")

        if user_cfg_dict["data"].get("split_ratio", None) is not None:
            render_cfg["data_initialize_code"] = "{}, split_ratio={}".format(
                render_cfg["data_initialize_code"],
                user_cfg_dict["data"]["split_ratio"],
            )
        render_cfg["user_cfg_str"] = f"cfg = {str(generated_user_cfg)}"
        render_cfg["user_cfg"] = user_cfg_dict
        return template.render(**render_cfg)
......
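For reference, the gen_script path above is ordinary Jinja2 templating: render_cfg is splatted into Template.render. A minimal sketch (the template string is illustrative, not the pipeline's actual template file):

from jinja2 import Template

t = Template("model = {{ model_class_name }}(**cfg['model'])")
print(t.render(model_class_name="GCN"))  # model = GCN(**cfg['model'])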
import copy
from pathlib import Path
from typing import Optional

import ruamel.yaml
import typer
import yaml
from jinja2 import Template
from pydantic import BaseModel, Field
from ruamel.yaml.comments import CommentedMap

from ...utils.base_model import DeviceEnum, EarlyStopConfig
from ...utils.factory import (
    DataFactory,
    EdgeModelFactory,
    NegativeSamplerFactory,
    NodeModelFactory,
    PipelineBase,
    PipelineFactory,
)
from ...utils.yaml_dump import deep_convert_dict, merge_comment


class LinkpredPipelineCfg(BaseModel):
    hidden_size: int = 256
    eval_batch_size: int = 32769
@@ -42,23 +52,25 @@ class LinkpredPipeline(PipelineBase):
    pipeline_name = "linkpred"

    def __init__(self):
        self.pipeline = {"name": "linkpred", "mode": "train"}

    @classmethod
    def setup_user_cfg_cls(cls):
        from ...utils.enter_config import UserConfig

        class LinkPredUserConfig(UserConfig):
            data: DataFactory.filter("linkpred").get_pydantic_config() = Field(
                ..., discriminator="name"
            )
            node_model: NodeModelFactory.get_pydantic_model_config() = Field(
                ..., discriminator="name"
            )
            edge_model: EdgeModelFactory.get_pydantic_model_config() = Field(
                ..., discriminator="name"
            )
            neg_sampler: NegativeSamplerFactory.get_pydantic_model_config() = (
                Field(..., discriminator="name")
            )
            general_pipeline: LinkpredPipelineCfg = LinkpredPipelineCfg()

        cls.user_cfg_cls = LinkPredUserConfig
@@ -69,15 +81,21 @@ class LinkpredPipeline(PipelineBase):
    def get_cfg_func(self):
        def config(
            data: DataFactory.filter(
                "linkpred"
            ).get_dataset_enum() = typer.Option(..., help="input data name"),
            cfg: str = typer.Option(
                "cfg.yaml", help="output configuration path"
            ),
            node_model: NodeModelFactory.get_model_enum() = typer.Option(
                ..., help="Model name"
            ),
            edge_model: EdgeModelFactory.get_model_enum() = typer.Option(
                ..., help="Model name"
            ),
            neg_sampler: NegativeSamplerFactory.get_model_enum() = typer.Option(
                "persource", help="Negative sampler name"
            ),
        ):
            self.__class__.setup_user_cfg_cls()
            generated_cfg = {
@@ -94,21 +112,41 @@ class LinkpredPipeline(PipelineBase):
            comment_dict = {
                "device": "Torch device name, e.g., cpu or cuda or cuda:0",
                "general_pipeline": pipeline_comments,
                "node_model": NodeModelFactory.get_constructor_doc_dict(
                    node_model.value
                ),
                "edge_model": EdgeModelFactory.get_constructor_doc_dict(
                    edge_model.value
                ),
                "neg_sampler": NegativeSamplerFactory.get_constructor_doc_dict(
                    neg_sampler.value
                ),
                "data": {
                    "split_ratio": "List of float, e.g. [0.8, 0.1, 0.1]. Split ratios for training, validation and test sets. Must sum to one. Leave blank to use builtin split in original dataset",
                    "neg_ratio": "Int, e.g. 2. Indicates how many negative samples to draw per positive sample. Leave blank to use builtin split in original dataset",
                },
            }
            comment_dict = merge_comment(output_cfg, comment_dict)
            if cfg is None:
                cfg = (
                    "_".join(
                        [
                            "linkpred",
                            data.value,
                            node_model.value,
                            edge_model.value,
                        ]
                    )
                    + ".yaml"
                )
            yaml = ruamel.yaml.YAML()
            yaml.dump(comment_dict, Path(cfg).open("w"))
            print(
                "Configuration file is generated at {}".format(
                    Path(cfg).absolute()
                )
            )

        return config
@@ -123,18 +161,33 @@ class LinkpredPipeline(PipelineBase):
        render_cfg = copy.deepcopy(user_cfg_dict)
        render_cfg["node_model_code"] = NodeModelFactory.get_source_code(
            user_cfg_dict["node_model"]["name"]
        )
        render_cfg["edge_model_code"] = EdgeModelFactory.get_source_code(
            user_cfg_dict["edge_model"]["name"]
        )
        render_cfg[
            "node_model_class_name"
        ] = NodeModelFactory.get_model_class_name(
            user_cfg_dict["node_model"]["name"]
        )
        render_cfg[
            "edge_model_class_name"
        ] = EdgeModelFactory.get_model_class_name(
            user_cfg_dict["edge_model"]["name"]
        )
        render_cfg[
            "neg_sampler_name"
        ] = NegativeSamplerFactory.get_model_class_name(
            user_cfg_dict["neg_sampler"]["name"]
        )
        render_cfg["loss"] = user_cfg_dict["general_pipeline"]["loss"]

        # update import and initialization code
        render_cfg.update(
            DataFactory.get_generated_code_dict(
                user_cfg_dict["data"]["name"], '**cfg["data"]'
            )
        )
        generated_user_cfg = copy.deepcopy(user_cfg_dict)
        if len(generated_user_cfg["data"]) == 1:
            generated_user_cfg.pop("data")
@@ -150,10 +203,17 @@ class LinkpredPipeline(PipelineBase):
        generated_train_cfg = copy.deepcopy(user_cfg_dict["general_pipeline"])
        generated_train_cfg["optimizer"].pop("name")

        if user_cfg_dict["data"].get("split_ratio", None) is not None:
            assert (
                user_cfg_dict["data"].get("neg_ratio", None) is not None
            ), "Please specify both split_ratio and neg_ratio"
            render_cfg[
                "data_initialize_code"
            ] = "{}, split_ratio={}, neg_ratio={}".format(
                render_cfg["data_initialize_code"],
                user_cfg_dict["data"]["split_ratio"],
                user_cfg_dict["data"]["neg_ratio"],
            )
            generated_user_cfg["data"].pop("split_ratio")
            generated_user_cfg["data"].pop("neg_ratio")
......
import copy
from pathlib import Path
from typing import Optional

import ruamel.yaml
import typer
import yaml
from jinja2 import Template
from pydantic import BaseModel, Field
from ruamel.yaml.comments import CommentedMap

from ...utils.base_model import DeviceEnum, EarlyStopConfig
from ...utils.factory import (
    DataFactory,
    NodeModelFactory,
    PipelineBase,
    PipelineFactory,
)
from ...utils.yaml_dump import deep_convert_dict, merge_comment

pipeline_comments = {
    "num_epochs": "Number of training epochs",
    "eval_period": "Interval epochs between evaluations",
    "early_stop": {
        "patience": "Steps before early stop",
        "checkpoint_path": "Early stop checkpoint model file path",
    },
    "save_path": "Directory to save the experiment results",
    "num_runs": "Number of experiments to run",
}


class NodepredPipelineCfg(BaseModel):
    early_stop: Optional[EarlyStopConfig] = EarlyStopConfig()
    num_epochs: int = 200
@@ -31,23 +39,26 @@ class NodepredPipelineCfg(BaseModel):
    save_path: str = "results"
    num_runs: int = 1


@PipelineFactory.register("nodepred")
class NodepredPipeline(PipelineBase):
    user_cfg_cls = None

    def __init__(self):
        self.pipeline = {"name": "nodepred", "mode": "train"}

    @classmethod
    def setup_user_cfg_cls(cls):
        from ...utils.enter_config import UserConfig

        class NodePredUserConfig(UserConfig):
            data: DataFactory.filter("nodepred").get_pydantic_config() = Field(
                ..., discriminator="name"
            )
            model: NodeModelFactory.get_pydantic_model_config() = Field(
                ..., discriminator="name"
            )
            general_pipeline: NodepredPipelineCfg = NodepredPipelineCfg()

        cls.user_cfg_cls = NodePredUserConfig
@@ -58,10 +69,15 @@ class NodepredPipeline(PipelineBase):
    def get_cfg_func(self):
        def config(
            data: DataFactory.filter(
                "nodepred"
            ).get_dataset_enum() = typer.Option(..., help="input data name"),
            cfg: Optional[str] = typer.Option(
                None, help="output configuration path"
            ),
            model: NodeModelFactory.get_model_enum() = typer.Option(
                ..., help="Model name"
            ),
        ):
            self.__class__.setup_user_cfg_cls()
            generated_cfg = {
@@ -70,17 +86,17 @@ class NodepredPipeline(PipelineBase):
                "device": "cpu",
                "data": {"name": data.name},
                "model": {"name": model.value},
                "general_pipeline": {},
            }
            output_cfg = self.user_cfg_cls(**generated_cfg).dict()
            output_cfg = deep_convert_dict(output_cfg)
            comment_dict = {
                "device": "Torch device name, e.g., cpu or cuda or cuda:0",
                "data": {
                    "split_ratio": "Ratio to generate split masks, for example set to [0.8, 0.1, 0.1] for 80% train/10% val/10% test. Leave blank to use builtin split in original dataset"
                },
                "general_pipeline": pipeline_comments,
                "model": NodeModelFactory.get_constructor_doc_dict(model.value),
            }
            comment_dict = merge_comment(output_cfg, comment_dict)
@@ -88,7 +104,11 @@ class NodepredPipeline(PipelineBase):
            if cfg is None:
                cfg = "_".join(["nodepred", data.value, model.value]) + ".yaml"
            yaml.dump(comment_dict, Path(cfg).open("w"))
            print(
                "Configuration file is generated at {}".format(
                    Path(cfg).absolute()
                )
            )

        return config
@@ -103,11 +123,17 @@ class NodepredPipeline(PipelineBase):
        render_cfg = copy.deepcopy(user_cfg_dict)
        model_code = NodeModelFactory.get_source_code(
            user_cfg_dict["model"]["name"]
        )
        render_cfg["model_code"] = model_code
        render_cfg["model_class_name"] = NodeModelFactory.get_model_class_name(
            user_cfg_dict["model"]["name"]
        )
        render_cfg.update(
            DataFactory.get_generated_code_dict(
                user_cfg_dict["data"]["name"], '**cfg["data"]'
            )
        )

        generated_user_cfg = copy.deepcopy(user_cfg_dict)
        if "split_ratio" in generated_user_cfg["data"]:
@@ -115,15 +141,19 @@ class NodepredPipeline(PipelineBase):
        generated_user_cfg["data_name"] = generated_user_cfg["data"].pop("name")
        generated_user_cfg.pop("pipeline_name")
        generated_user_cfg.pop("pipeline_mode")
        generated_user_cfg["model_name"] = generated_user_cfg["model"].pop(
            "name"
        )
        generated_user_cfg["general_pipeline"]["optimizer"].pop("name")

        generated_train_cfg = copy.deepcopy(user_cfg_dict["general_pipeline"])
        generated_train_cfg["optimizer"].pop("name")

        if user_cfg_dict["data"].get("split_ratio", None) is not None:
            render_cfg["data_initialize_code"] = "{}, split_ratio={}".format(
                render_cfg["data_initialize_code"],
                user_cfg_dict["data"]["split_ratio"],
            )
        render_cfg["user_cfg_str"] = f"cfg = {str(generated_user_cfg)}"
        render_cfg["user_cfg"] = user_cfg_dict
        return template.render(**render_cfg)
......
import copy
from enum import Enum
from pathlib import Path
from typing import List, Optional, Union

import ruamel.yaml
import typer
import yaml
from jinja2 import ext, Template
from pydantic import BaseModel, Field
from ruamel.yaml.comments import CommentedMap
from typing_extensions import Literal

from ...utils.base_model import DeviceEnum, EarlyStopConfig, extract_name
from ...utils.factory import (
    DataFactory,
    NodeModelFactory,
    PipelineBase,
    PipelineFactory,
)
from ...utils.yaml_dump import deep_convert_dict, merge_comment


class SamplerConfig(BaseModel):
@@ -25,8 +32,7 @@ class SamplerConfig(BaseModel):
    eval_num_workers: int = 4

    class Config:
        extra = "forbid"
pipeline_comments = {
@@ -34,19 +40,20 @@ pipeline_comments = {
    "eval_period": "Interval epochs between evaluations",
    "early_stop": {
        "patience": "Steps before early stop",
        "checkpoint_path": "Early stop checkpoint model file path",
    },
    "sampler": {
        "fan_out": "List of neighbors to sample per edge type for each GNN layer, with the i-th element being the fanout for the i-th GNN layer. Length should be the same as num_layers in model setting",
        "batch_size": "Batch size of seed nodes in training stage",
        "num_workers": "Number of workers to accelerate the graph data processing step",
        "eval_batch_size": "Batch size of seed nodes in evaluation stage",
        "eval_num_workers": "Number of workers to accelerate the graph data processing step in evaluation stage",
    },
    "save_path": "Directory to save the experiment results",
    "num_runs": "Number of experiments to run",
}
class NodepredNSPipelineCfg(BaseModel):
    sampler: SamplerConfig = Field("neighbor")
    early_stop: Optional[EarlyStopConfig] = EarlyStopConfig()
@@ -57,22 +64,25 @@ class NodepredNSPipelineCfg(BaseModel):
    num_runs: int = 1
    save_path: str = "results"


@PipelineFactory.register("nodepred-ns")
class NodepredNsPipeline(PipelineBase):
    def __init__(self):
        self.pipeline = {"name": "nodepred-ns", "mode": "train"}
        self.default_cfg = None

    @classmethod
    def setup_user_cfg_cls(cls):
        from ...utils.enter_config import UserConfig

        class NodePredUserConfig(UserConfig):
            eval_device: DeviceEnum = Field("cpu")
            data: DataFactory.filter(
                "nodepred-ns"
            ).get_pydantic_config() = Field(..., discriminator="name")
            model: NodeModelFactory.filter(
                lambda cls: hasattr(cls, "forward_block")
            ).get_pydantic_model_config() = Field(..., discriminator="name")
            general_pipeline: NodepredNSPipelineCfg

        cls.user_cfg_cls = NodePredUserConfig
@@ -83,10 +93,15 @@ class NodepredNsPipeline(PipelineBase):
    def get_cfg_func(self):
        def config(
            data: DataFactory.filter(
                "nodepred-ns"
            ).get_dataset_enum() = typer.Option(..., help="input data name"),
            cfg: Optional[str] = typer.Option(
                None, help="output configuration path"
            ),
            model: NodeModelFactory.filter(
                lambda cls: hasattr(cls, "forward_block")
            ).get_model_enum() = typer.Option(..., help="Model name"),
        ):
            self.__class__.setup_user_cfg_cls()
            generated_cfg = {
@@ -95,14 +110,14 @@ class NodepredNsPipeline(PipelineBase):
                "device": "cpu",
                "data": {"name": data.name},
                "model": {"name": model.value},
                "general_pipeline": {"sampler": {"name": "neighbor"}},
            }
            output_cfg = self.user_cfg_cls(**generated_cfg).dict()
            output_cfg = deep_convert_dict(output_cfg)
            comment_dict = {
                "device": "Torch device name, e.g., cpu or cuda or cuda:0",
                "data": {
                    "split_ratio": "Ratio to generate split masks, for example set to [0.8, 0.1, 0.1] for 80% train/10% val/10% test. Leave blank to use builtin split in original dataset"
                },
                "general_pipeline": pipeline_comments,
                "model": NodeModelFactory.get_constructor_doc_dict(model.value),
@@ -111,14 +126,25 @@ class NodepredNsPipeline(PipelineBase):
            # truncate length fan_out to be the same as num_layers in model
            if "num_layers" in comment_dict["model"]:
                comment_dict["general_pipeline"]["sampler"]["fan_out"] = [
                    5,
                    10,
                    15,
                    15,
                    15,
                ][: int(comment_dict["model"]["num_layers"])]

            if cfg is None:
                cfg = (
                    "_".join(["nodepred-ns", data.value, model.value]) + ".yaml"
                )
            yaml = ruamel.yaml.YAML()
            yaml.dump(comment_dict, Path(cfg).open("w"))
            print(
                "Configuration file is generated at {}".format(
                    Path(cfg).absolute()
                )
            )

        return config
@@ -129,20 +155,27 @@ class NodepredNsPipeline(PipelineBase):
        with open(template_filename, "r") as f:
            template = Template(f.read())
        pipeline_cfg = NodepredNSPipelineCfg(
            **user_cfg_dict["general_pipeline"]
        )
        if "num_layers" in user_cfg_dict["model"]:
            assert user_cfg_dict["model"]["num_layers"] == len(
                user_cfg_dict["general_pipeline"]["sampler"]["fan_out"]
            ), "The num_layers in model config should be the same as the length of fan_out in sampler. For example, if num_layers is 1, the fan_out cannot be [5, 10]"
        render_cfg = copy.deepcopy(user_cfg_dict)
        model_code = NodeModelFactory.get_source_code(
            user_cfg_dict["model"]["name"]
        )
        render_cfg["model_code"] = model_code
        render_cfg["model_class_name"] = NodeModelFactory.get_model_class_name(
            user_cfg_dict["model"]["name"]
        )
        render_cfg.update(
            DataFactory.get_generated_code_dict(
                user_cfg_dict["data"]["name"], '**cfg["data"]'
            )
        )

        generated_user_cfg = copy.deepcopy(user_cfg_dict)
        if "split_ratio" in generated_user_cfg["data"]:
@@ -150,12 +183,16 @@ class NodepredNsPipeline(PipelineBase):
        generated_user_cfg["data_name"] = generated_user_cfg["data"].pop("name")
        generated_user_cfg.pop("pipeline_name")
        generated_user_cfg.pop("pipeline_mode")
        generated_user_cfg["model_name"] = generated_user_cfg["model"].pop(
            "name"
        )
        generated_user_cfg["general_pipeline"]["optimizer"].pop("name")

        if user_cfg_dict["data"].get("split_ratio", None) is not None:
            render_cfg["data_initialize_code"] = "{}, split_ratio={}".format(
                render_cfg["data_initialize_code"],
                user_cfg_dict["data"]["split_ratio"],
            )
        render_cfg["user_cfg_str"] = f"cfg = {str(generated_user_cfg)}"
        render_cfg["user_cfg"] = user_cfg_dict
......
import copy
import enum
from enum import Enum, IntEnum
from typing import Optional

from jinja2 import Template
from pydantic import (
    BaseModel as PydanticBaseModel,
    create_model,
    Field,
)
class DeviceEnum(str, Enum):
    cpu = "cpu"
    cuda = "cuda"


class DGLBaseModel(PydanticBaseModel):
    class Config:
        extra = "allow"
@@ -27,14 +34,16 @@ def get_literal_value(type_):
    name = type_.__args__[0]
    return name


def extract_name(union_type):
    name_dict = {}
    for t in union_type.__args__:
        type_ = t.__fields__["name"].type_
        name = get_literal_value(type_)
        name_dict[name] = name
    return enum.Enum("Choice", name_dict)
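A hedged sketch of what extract_name does: given a Union of pydantic models whose name field is a Literal, it builds an enum of the literal values (A and B are illustrative; this relies on pydantic v1's __fields__ API, as the code above does):

from typing import Union

from pydantic import BaseModel
from typing_extensions import Literal

class A(BaseModel):
    name: Literal["gcn"]

class B(BaseModel):
    name: Literal["gat"]

Choice = extract_name(Union[A, B])
print([c.value for c in Choice])  # ['gcn', 'gat']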
class EarlyStopConfig(DGLBaseModel):
    patience: int = 20
    checkpoint_path: str = "checkpoint.pth"
import torch


class EarlyStopping:
    def __init__(
        self, patience: int = -1, checkpoint_path: str = "checkpoint.pth"
    ):
        self.patience = patience
        self.checkpoint_path = checkpoint_path
        self.counter = 0
@@ -17,7 +18,9 @@ class EarlyStopping:
            self.save_checkpoint(model)
        elif score < self.best_score:
            self.counter += 1
            print(
                f"EarlyStopping counter: {self.counter} out of {self.patience}"
            )
            if self.counter >= self.patience:
                self.early_stop = True
        else:
@@ -27,7 +30,7 @@ class EarlyStopping:
        return self.early_stop

    def save_checkpoint(self, model):
        """Save model when validation loss decreases."""
        torch.save(model.state_dict(), self.checkpoint_path)

    def load_checkpoint(self, model):
......
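A hedged usage sketch for EarlyStopping (the step method name is inferred from the counter logic above, since its defining hunk is elided):

import torch.nn as nn

model = nn.Linear(4, 2)  # stand-in model
stopper = EarlyStopping(patience=5, checkpoint_path="checkpoint.pth")
for epoch in range(100):
    val_score = 1.0 / (epoch + 1)  # stand-in for a real validation score
    # step(...) is an assumed method name; returns True once patience is exhausted.
    if stopper.step(val_score, model):
        break
stopper.load_checkpoint(model)  # restore the weights written by save_checkpoint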
import copy
from enum import Enum, IntEnum
from typing import Optional

import jinja2
import yaml
from jinja2 import Template
from pydantic import BaseModel as PydanticBaseModel, create_model, Field

from .base_model import DGLBaseModel

# from ..pipeline import nodepred, nodepred_sample
from .factory import DataFactory, ModelFactory, PipelineFactory


class PipelineConfig(DGLBaseModel):
@@ -22,6 +21,7 @@ class PipelineConfig(DGLBaseModel):
    optimizer: dict = {"name": "Adam", "lr": 0.005}
    loss: str = "CrossEntropyLoss"


class UserConfig(DGLBaseModel):
    version: Optional[str] = "0.0.2"
    pipeline_name: PipelineFactory.get_pipeline_enum()
......
import enum
import inspect
import logging
from abc import ABC, abstractmethod, abstractstaticmethod
from pathlib import Path
from typing import Callable, Dict, List, Optional, Tuple, Union

import yaml
from dgl.dataloading.negative_sampler import GlobalUniform, PerSourceUniform
from numpydoc import docscrape
from pydantic import create_model, create_model_from_typeddict, Field
from typing_extensions import Literal

from .base_model import DGLBaseModel

logger = logging.getLogger(__name__)

ALL_PIPELINE = ["nodepred", "nodepred-ns", "linkpred", "graphpred"]


class PipelineBase(ABC):
    @abstractmethod
    def __init__(self) -> None:
        super().__init__()
@@ -34,23 +37,24 @@ class PipelineBase(ABC):
class DataFactoryClass:
    def __init__(self):
        self.registry = {}
        self.pipeline_name = None
        self.pipeline_allowed = {}

    def register(
        self,
        name: str,
        import_code: str,
        class_name: str,
        allowed_pipeline: List[str],
        extra_args={},
    ):
        self.registry[name] = {
            "name": name,
            "import_code": import_code,
            "class_name": class_name,
            "extra_args": extra_args,
        }
        for pipeline in allowed_pipeline:
            if pipeline in self.pipeline_allowed:
@@ -61,7 +65,8 @@ class DataFactoryClass:
    def get_dataset_enum(self):
        enum_class = enum.Enum(
            "DatasetName", {v["name"]: k for k, v in self.registry.items()}
        )
        return enum_class

    def get_dataset_classname(self, name):
@@ -84,8 +89,13 @@ class DataFactoryClass:
            if "name" in type_annotation_dict:
                del type_annotation_dict["name"]
            base = self.get_base_class(dataset_name, self.pipeline_name)
            dataset_list.append(
                create_model(
                    f"{dataset_name}Config",
                    **type_annotation_dict,
                    __base__=base,
                )
            )

        output = dataset_list[0]
        for d in dataset_list[1:]:
@@ -116,7 +126,9 @@ class DataFactoryClass:
    def filter(self, pipeline_name):
        allowed_name = self.pipeline_allowed[pipeline_name]
        new_registry = {
            k: v for k, v in self.registry.items() if k in allowed_name
        }
        d = DataFactoryClass()
        d.registry = new_registry
        d.pipeline_name = pipeline_name
@@ -125,18 +137,20 @@
    @staticmethod
    def get_base_class(dataset_name, pipeline_name):
        if pipeline_name == "linkpred":

            class EdgeBase(DGLBaseModel):
                name: Literal[dataset_name]
                split_ratio: Optional[Tuple[float, float, float]] = None
                neg_ratio: Optional[int] = None

            return EdgeBase
        else:

            class NodeBase(DGLBaseModel):
                name: Literal[dataset_name]
                split_ratio: Optional[Tuple[float, float, float]] = None

            return NodeBase
DataFactory = DataFactoryClass()
@@ -145,83 +159,96 @@ DataFactory.register(
    "cora",
    import_code="from dgl.data import CoraGraphDataset",
    class_name="CoraGraphDataset()",
    allowed_pipeline=["nodepred", "nodepred-ns", "linkpred"],
)

DataFactory.register(
    "citeseer",
    import_code="from dgl.data import CiteseerGraphDataset",
    class_name="CiteseerGraphDataset()",
    allowed_pipeline=["nodepred", "nodepred-ns", "linkpred"],
)

DataFactory.register(
    "pubmed",
    import_code="from dgl.data import PubmedGraphDataset",
    class_name="PubmedGraphDataset()",
    allowed_pipeline=["nodepred", "nodepred-ns", "linkpred"],
)

DataFactory.register(
    "csv",
    import_code="from dgl.data import CSVDataset",
    extra_args={"data_path": "./"},
    class_name="CSVDataset({})",
    allowed_pipeline=["nodepred", "nodepred-ns", "linkpred", "graphpred"],
)

DataFactory.register(
    "reddit",
    import_code="from dgl.data import RedditDataset",
    class_name="RedditDataset()",
    allowed_pipeline=["nodepred", "nodepred-ns", "linkpred"],
)

DataFactory.register(
    "co-buy-computer",
    import_code="from dgl.data import AmazonCoBuyComputerDataset",
    class_name="AmazonCoBuyComputerDataset()",
    allowed_pipeline=["nodepred", "nodepred-ns", "linkpred"],
)

DataFactory.register(
    "ogbn-arxiv",
    import_code="from ogb.nodeproppred import DglNodePropPredDataset",
    extra_args={},
    class_name="DglNodePropPredDataset('ogbn-arxiv')",
    allowed_pipeline=["nodepred", "nodepred-ns", "linkpred"],
)

DataFactory.register(
    "ogbn-products",
    import_code="from ogb.nodeproppred import DglNodePropPredDataset",
    extra_args={},
    class_name="DglNodePropPredDataset('ogbn-products')",
    allowed_pipeline=["nodepred", "nodepred-ns", "linkpred"],
)

DataFactory.register(
    "ogbl-collab",
    import_code="from ogb.linkproppred import DglLinkPropPredDataset",
    extra_args={},
    class_name="DglLinkPropPredDataset('ogbl-collab')",
    allowed_pipeline=["linkpred"],
)

DataFactory.register(
    "ogbl-citation2",
    import_code="from ogb.linkproppred import DglLinkPropPredDataset",
    extra_args={},
    class_name="DglLinkPropPredDataset('ogbl-citation2')",
    allowed_pipeline=["linkpred"],
)

DataFactory.register(
    "ogbg-molhiv",
    import_code="from ogb.graphproppred import DglGraphPropPredDataset",
    extra_args={},
    class_name="DglGraphPropPredDataset(name='ogbg-molhiv')",
    allowed_pipeline=["graphpred"],
)

DataFactory.register(
    "ogbg-molpcba",
    import_code="from ogb.graphproppred import DglGraphPropPredDataset",
    extra_args={},
    class_name="DglGraphPropPredDataset(name='ogbg-molpcba')",
    allowed_pipeline=["graphpred"],
)
class PipelineFactory:
    """The factory class for creating executors"""

    registry: Dict[str, PipelineBase] = {}
    default_config_registry = {}
@@ -229,11 +256,11 @@ class PipelineFactory:
    @classmethod
    def register(cls, name: str) -> Callable:
        def inner_wrapper(wrapped_class) -> Callable:
            if name in cls.registry:
                logger.warning(
                    "Executor %s already exists. Will replace it", name
                )
            cls.registry[name] = wrapped_class()
            return wrapped_class
@@ -241,19 +268,23 @@ class PipelineFactory:
    @classmethod
    def register_default_config_generator(cls, name: str) -> Callable:
        def inner_wrapper(wrapped_class) -> Callable:
            if name in cls.registry:
                logger.warning(
                    "Executor %s already exists. Will replace it", name
                )
            cls.default_config_registry[name] = wrapped_class
            return wrapped_class

        return inner_wrapper

    @classmethod
    def call_default_config_generator(
        cls, generator_name, model_name, dataset_name
    ):
        return cls.default_config_registry[generator_name](
            model_name, dataset_name
        )

    @classmethod
    def call_generator(cls, generator_name, cfg):
@@ -262,9 +293,11 @@ class PipelineFactory:
    @classmethod
    def get_pipeline_enum(cls):
        enum_class = enum.Enum(
            "PipelineName", {k: k for k, v in cls.registry.items()}
        )
        return enum_class


class ApplyPipelineFactory:
    """The factory class for creating executors for inference"""
@@ -273,38 +306,41 @@ class ApplyPipelineFactory:
    @classmethod
    def register(cls, name: str) -> Callable:
        def inner_wrapper(wrapped_class) -> Callable:
            if name in cls.registry:
                logger.warning(
                    "Executor %s already exists. Will replace it", name
                )
            cls.registry[name] = wrapped_class()
            return wrapped_class

        return inner_wrapper


model_dir = Path(__file__).parent.parent / "model"


class ModelFactory:
    """The factory class for creating executors"""

    def __init__(self):
        self.registry = {}
        self.code_registry = {}
        """ Internal registry for available executors """

    def get_model_enum(self):
        enum_class = enum.Enum(
            "ModelName", {k: k for k, v in self.registry.items()}
        )
        return enum_class

    def register(self, model_name: str) -> Callable:
        def inner_wrapper(wrapped_class) -> Callable:
            if model_name in self.registry:
                logger.warning(
                    "Executor %s already exists. Will replace it", model_name
                )
            self.registry[model_name] = wrapped_class
            # code_filename = model_dir / filename
            code_filename = Path(inspect.getfile(wrapped_class))
@@ -328,14 +364,19 @@ class ModelFactory:
        arg_dict = self.get_constructor_default_args(model_name)
        type_annotation_dict = {}
        # type_annotation_dict["name"] = Literal[""]
        exempt_keys = ["self", "in_size", "out_size", "data_info"]
        for k, param in arg_dict.items():
            if k not in exempt_keys:
                type_annotation_dict[k] = arg_dict[k]

        class Base(DGLBaseModel):
            name: Literal[model_name]

        return create_model(
            f"{model_name.upper()}ModelConfig",
            **type_annotation_dict,
            __base__=Base,
        )

    def get_constructor_doc_dict(self, name):
        model_class = self.registry[name]
@@ -375,22 +416,23 @@ class ModelFactory:
class SamplerFactory:
    """The factory class for creating executors"""

    def __init__(self):
        self.registry = {}

    def get_model_enum(self):
        enum_class = enum.Enum(
            "NegativeSamplerName", {k: k for k, v in self.registry.items()}
        )
        return enum_class

    def register(self, sampler_name: str) -> Callable:
        def inner_wrapper(wrapped_class) -> Callable:
            if sampler_name in self.registry:
                logger.warning(
                    "Sampler %s already exists. Will replace it", sampler_name
                )
            self.registry[sampler_name] = wrapped_class
            return wrapped_class
@@ -408,17 +450,22 @@ class SamplerFactory:
        arg_dict = self.get_constructor_default_args(sampler_name)
        type_annotation_dict = {}
        # type_annotation_dict["name"] = Literal[""]
        exempt_keys = ["self", "in_size", "out_size", "redundancy"]
        for k, param in arg_dict.items():
            if k not in exempt_keys or param is None:
                if k == "k" or k == "redundancy":
                    type_annotation_dict[k] = 3
                else:
                    type_annotation_dict[k] = arg_dict[k]

        class Base(DGLBaseModel):
            name: Literal[sampler_name]

        return create_model(
            f"{sampler_name.upper()}SamplerConfig",
            **type_annotation_dict,
            __base__=Base,
        )

    def get_pydantic_model_config(self):
        model_list = []
......