"""Smoke tests for the ``dgl`` command-line tool.

Every test shells out to ``dgl`` and then checks that the file(s) the command
is expected to produce exist in the current working directory.
"""
import os

import pytest

# Dataset names used to parametrize the nodepred and nodepred-ns tests.
_NODE_DATASETS = [
    "cora",
    "citeseer",
    "pubmed",
    "csv",
    "reddit",
    "co-buy-computer",
    "ogbn-arxiv",
    "ogbn-products",
]

# The linkpred tests use the node datasets plus two link-prediction ones.
_LINK_DATASETS = _NODE_DATASETS + ["ogbl-collab", "ogbl-citation2"]


def _run(command):
    """Run ``command`` through the shell and fail if it exits non-zero.

    Checking the status matters because several tests expect the same output
    file name in the shared working directory; without it, a failing command
    could go unnoticed when an earlier test already left that file behind.
    """
    status = os.system(command)
    assert status == 0, f"command exited with status {status}: {command}"


def _export(cfg, script):
    """Export the config file ``cfg`` to ``script`` and check it was written."""
    _run(f"dgl export --cfg {cfg} --output {script}")
    assert os.path.exists(script)


def _check_configure_and_export(pipeline, data, model, prefix=""):
    """Configure ``pipeline`` with default and custom cfg names, then export.

    ``prefix`` is prepended to the custom config and script file names so
    that different pipelines do not reuse each other's custom artifacts.
    """
    # Without --cfg the tool is expected to pick the config file name itself.
    _run(f"dgl configure {pipeline} --data {data} --model {model}")
    assert os.path.exists(f"{pipeline}_{data}_{model}.yaml")

    custom_cfg = f"{prefix}custom_{data}_{model}.yaml"
    _run(
        f"dgl configure {pipeline} --data {data} --model {model}"
        f" --cfg {custom_cfg}"
    )
    assert os.path.exists(custom_cfg)

    _export(custom_cfg, f"{prefix}{data}_{model}.py")


def _check_linkpred(data, node_model, edge_model, neg_sampler):
    """Configure a linkpred pipeline into a custom cfg file, then export it.

    An empty ``neg_sampler`` omits the ``--neg-sampler`` option altogether;
    the file names then end in a trailing underscore before the extension.
    """
    stem = "_".join([data, node_model, edge_model, neg_sampler])
    custom_cfg = stem + ".yaml"

    command = (
        f"dgl configure linkpred --data {data} --node-model {node_model}"
        f" --edge-model {edge_model}"
    )
    if neg_sampler:
        command += f" --neg-sampler {neg_sampler}"
    command += f" --cfg {custom_cfg}"

    _run(command)
    assert os.path.exists(custom_cfg)

    _export(custom_cfg, stem + ".py")


@pytest.mark.parametrize("data", _NODE_DATASETS)
def test_nodepred_data(data):
    """nodepred configure/export works for each dataset with the gcn model."""
    _check_configure_and_export("nodepred", data, "gcn")


@pytest.mark.parametrize("model", ["gcn", "gat", "sage", "sgc", "gin"])
def test_nodepred_model(model):
    """nodepred configure/export works for each model on cora."""
    _check_configure_and_export("nodepred", "cora", model)


@pytest.mark.parametrize("data", _NODE_DATASETS)
def test_nodepred_ns_data(data):
    """nodepred-ns configure/export works for each dataset with gcn."""
    _check_configure_and_export("nodepred-ns", data, "gcn", prefix="ns-")


@pytest.mark.parametrize("model", ["gcn", "gat", "sage"])
def test_nodepred_ns_model(model):
    """nodepred-ns configure/export works for each model on cora."""
    _check_configure_and_export("nodepred-ns", "cora", model, prefix="ns-")


@pytest.mark.parametrize("data", _LINK_DATASETS)
def test_linkpred_data(data):
    """linkpred configure/export works for each dataset."""
    _check_linkpred(data, "gcn", "ele", "global")


@pytest.mark.parametrize("node_model", ["gcn", "gat", "sage", "sgc", "gin"])
def test_linkpred_node_model(node_model):
    """linkpred configure/export works for each node model on cora."""
    _check_linkpred("cora", node_model, "ele", "global")


@pytest.mark.parametrize("edge_model", ["ele", "bilinear"])
def test_linkpred_edge_model(edge_model):
    """linkpred configure/export works for each edge model on cora."""
    _check_linkpred("cora", "gcn", edge_model, "global")


@pytest.mark.parametrize("neg_sampler", ["global", "persource", ""])
def test_linkpred_neg_sampler(neg_sampler):
    """linkpred configure/export works with and without --neg-sampler."""
    _check_linkpred("cora", "gcn", "ele", neg_sampler)


@pytest.mark.parametrize("data", ["csv", "ogbg-molhiv", "ogbg-molpcba"])
@pytest.mark.parametrize("model", ["gin", "pna"])
def test_graphpred(data, model):
    """graphpred configure/export works for each dataset/model pair."""
    _check_configure_and_export("graphpred", data, model)


@pytest.mark.parametrize(
    "recipe",
    [
        "graphpred_hiv_gin.yaml",
        "graphpred_hiv_pna.yaml",
        "graphpred_pcba_gin.yaml",
        "linkpred_cora_sage.yaml",
        "linkpred_citation2_sage.yaml",
        "linkpred_collab_sage.yaml",
        "nodepred_citeseer_gat.yaml",
        "nodepred_citeseer_gcn.yaml",
        "nodepred_citeseer_sage.yaml",
        "nodepred_cora_gat.yaml",
        "nodepred_cora_gcn.yaml",
        "nodepred_cora_sage.yaml",
        "nodepred_pubmed_gat.yaml",
        "nodepred_pubmed_gcn.yaml",
        "nodepred_pubmed_sage.yaml",
        "nodepred-ns_arxiv_gcn.yaml",
        "nodepred-ns_product_sage.yaml",
    ],
)
def test_recipe(recipe):
    """``dgl recipe get`` writes the requested recipe file."""
    # Remove all generated yaml files so a stale copy of the recipe cannot
    # satisfy the existence check below.
    for item in os.listdir("./"):
        if item.endswith(".yaml"):
            os.remove(item)

    _run(f"dgl recipe get {recipe}")
    assert os.path.exists(recipe)


def test_node_cora():
    """Run configure, train, configure-apply, apply and export in sequence."""
    _run("dgl configure nodepred --data cora --model gcn")
    _run("dgl train --cfg nodepred_cora_gcn.yaml")
    assert os.path.exists("results")
    assert os.path.exists("results/run_0.pth")

    # Without --cfg the apply config name is chosen by the tool.
    _run("dgl configure-apply nodepred --cpt results/run_0.pth")
    assert os.path.exists("apply_nodepred_cora_gcn.yaml")

    _run(
        "dgl configure-apply nodepred --data cora --cpt results/run_0.pth"
        " --cfg apply.yaml"
    )
    assert os.path.exists("apply.yaml")

    _run("dgl apply --cfg apply.yaml")
    assert os.path.exists("apply_results/output.csv")

    _export("apply.yaml", "apply.py")