test_pipeline.py 6.17 KB
Newer Older
Mufei Li's avatar
Mufei Li committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import os
import pytest

@pytest.mark.parametrize('data', ['cora', 'citeseer', 'pubmed', 'csv', 'reddit',
                                  'co-buy-computer', 'ogbn-arxiv', 'ogbn-products'])
@pytest.mark.parametrize('model', ['gcn', 'gat', 'sage', 'sgc', 'gin'])
def test_nodepred(data, model):
    """Check ``dgl configure nodepred`` and ``dgl export`` for one data/model pair.

    Generates the default config, generates a config at a custom path, then
    exports the custom config to a Python script. Each CLI call must exit with
    status 0 and produce its expected output file.
    """
    default_config_file = 'nodepred_{}_{}.yaml'.format(data, model)
    custom_config_file = 'custom_{}_{}.yaml'.format(data, model)
    custom_script = '_'.join([data, model]) + '.py'

    # Other tests in this module write files with the same names; remove any
    # leftovers so the existence checks can only pass if these commands
    # actually produced the files.
    for path in (default_config_file, custom_config_file, custom_script):
        if os.path.exists(path):
            os.remove(path)

    cmd = 'dgl configure nodepred --data {} --model {}'.format(data, model)
    assert os.system(cmd) == 0, cmd
    assert os.path.exists(default_config_file)

    cmd = 'dgl configure nodepred --data {} --model {} --cfg {}'.format(data, model,
                                                                        custom_config_file)
    assert os.system(cmd) == 0, cmd
    assert os.path.exists(custom_config_file)

    cmd = 'dgl export --cfg {} --output {}'.format(custom_config_file, custom_script)
    assert os.system(cmd) == 0, cmd
    assert os.path.exists(custom_script)

@pytest.mark.parametrize('data', ['cora', 'citeseer', 'pubmed', 'csv', 'reddit',
                                  'co-buy-computer', 'ogbn-arxiv', 'ogbn-products'])
@pytest.mark.parametrize('model', ['gcn', 'gat', 'sage'])
def test_nodepred_ns(data, model):
    """Check ``dgl configure nodepred-ns`` and ``dgl export`` for one data/model pair.

    Generates the default config, generates a config at a custom path, then
    exports the custom config to a Python script. Each CLI call must exit with
    status 0 and produce its expected output file.
    """
    default_config_file = 'nodepred-ns_{}_{}.yaml'.format(data, model)
    custom_config_file = 'custom_{}_{}.yaml'.format(data, model)
    custom_script = '_'.join([data, model]) + '.py'

    # test_nodepred uses the same custom config/script names; remove leftovers
    # so a stale file cannot hide a failing command.
    for path in (default_config_file, custom_config_file, custom_script):
        if os.path.exists(path):
            os.remove(path)

    cmd = 'dgl configure nodepred-ns --data {} --model {}'.format(data, model)
    assert os.system(cmd) == 0, cmd
    assert os.path.exists(default_config_file)

    cmd = 'dgl configure nodepred-ns --data {} --model {} --cfg {}'.format(data, model,
                                                                           custom_config_file)
    assert os.system(cmd) == 0, cmd
    assert os.path.exists(custom_config_file)

    cmd = 'dgl export --cfg {} --output {}'.format(custom_config_file, custom_script)
    assert os.system(cmd) == 0, cmd
    assert os.path.exists(custom_script)

@pytest.mark.parametrize('data', ['cora', 'citeseer', 'pubmed', 'csv', 'reddit',
                                  'co-buy-computer', 'ogbn-arxiv', 'ogbn-products', 'ogbl-collab',
                                  'ogbl-citation2'])
@pytest.mark.parametrize('node_model', ['gcn', 'gat', 'sage', 'sgc', 'gin'])
@pytest.mark.parametrize('edge_model', ['ele', 'bilinear'])
@pytest.mark.parametrize('neg_sampler', ['global', 'persource'])
def test_linkpred(data, node_model, edge_model, neg_sampler):
    """Check ``dgl configure linkpred`` with an explicit negative sampler, then export.

    Each CLI call must exit with status 0 and produce its expected output file.
    """
    base_name = '_'.join([data, node_model, edge_model, neg_sampler])
    custom_config_file = base_name + '.yaml'
    custom_script = base_name + '.py'

    # Remove outputs from earlier runs so the existence checks reflect this run.
    for path in (custom_config_file, custom_script):
        if os.path.exists(path):
            os.remove(path)

    cmd = ('dgl configure linkpred --data {} --node-model {} --edge-model {} '
           '--neg-sampler {} --cfg {}').format(data, node_model, edge_model, neg_sampler,
                                               custom_config_file)
    assert os.system(cmd) == 0, cmd
    assert os.path.exists(custom_config_file)

    cmd = 'dgl export --cfg {} --output {}'.format(custom_config_file, custom_script)
    assert os.system(cmd) == 0, cmd
    assert os.path.exists(custom_script)

@pytest.mark.parametrize('data', ['cora', 'citeseer', 'pubmed', 'csv', 'reddit',
                                  'co-buy-computer', 'ogbn-arxiv', 'ogbn-products', 'ogbl-collab',
                                  'ogbl-citation2'])
@pytest.mark.parametrize('node_model', ['gcn', 'gat', 'sage', 'sgc', 'gin'])
@pytest.mark.parametrize('edge_model', ['ele', 'bilinear'])
def test_linkpred_default_neg_sampler(data, node_model, edge_model):
    """Check ``dgl configure linkpred`` without ``--neg-sampler`` (CLI default).

    The CLI call must exit with status 0 and write the requested config file.
    """
    custom_config_file = '_'.join([data, node_model, edge_model]) + '.yaml'

    # Remove a config left by an earlier run so the existence check reflects this run.
    if os.path.exists(custom_config_file):
        os.remove(custom_config_file)

    cmd = 'dgl configure linkpred --data {} --node-model {} --edge-model {} --cfg {}'.format(
        data, node_model, edge_model, custom_config_file)
    assert os.system(cmd) == 0, cmd
    assert os.path.exists(custom_config_file)

63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
@pytest.mark.parametrize('data', ['csv', 'ogbg-molhiv', 'ogbg-molpcba'])
@pytest.mark.parametrize('model', ['gin', 'pna'])
def test_graphpred(data, model):
    """Check ``dgl configure graphpred`` and ``dgl export`` for one data/model pair.

    Generates the default config, generates a config at a custom path, then
    exports the custom config to a Python script. Each CLI call must exit with
    status 0 and produce its expected output file.
    """
    default_config_file = 'graphpred_{}_{}.yaml'.format(data, model)
    custom_config_file = 'custom_{}_{}.yaml'.format(data, model)
    custom_script = '_'.join([data, model]) + '.py'

    # test_nodepred also writes custom_csv_gin.yaml / csv_gin.py; remove
    # leftovers so a stale file cannot hide a failing command.
    for path in (default_config_file, custom_config_file, custom_script):
        if os.path.exists(path):
            os.remove(path)

    cmd = 'dgl configure graphpred --data {} --model {}'.format(data, model)
    assert os.system(cmd) == 0, cmd
    assert os.path.exists(default_config_file)

    cmd = 'dgl configure graphpred --data {} --model {} --cfg {}'.format(data, model,
                                                                         custom_config_file)
    assert os.system(cmd) == 0, cmd
    assert os.path.exists(custom_config_file)

    cmd = 'dgl export --cfg {} --output {}'.format(custom_config_file, custom_script)
    assert os.system(cmd) == 0, cmd
    assert os.path.exists(custom_script)

Mufei Li's avatar
Mufei Li committed
78
@pytest.mark.parametrize('recipe',
                         ['graphpred_hiv_gin.yaml',
                          'graphpred_hiv_pna.yaml',
                          'graphpred_pcba_gin.yaml',
                          'linkpred_cora_sage.yaml',
                          'linkpred_citation2_sage.yaml',
                          'linkpred_collab_sage.yaml',
                          'nodepred_citeseer_gat.yaml',
                          'nodepred_citeseer_gcn.yaml',
                          'nodepred_citeseer_sage.yaml',
                          'nodepred_cora_gat.yaml',
                          'nodepred_cora_gcn.yaml',
                          'nodepred_cora_sage.yaml',
                          'nodepred_pubmed_gat.yaml',
                          'nodepred_pubmed_gcn.yaml',
                          'nodepred_pubmed_sage.yaml',
                          'nodepred-ns_arxiv_gcn.yaml',
                          'nodepred-ns_product_sage.yaml'])
def test_recipe(recipe):
    """Check that ``dgl recipe get`` copies the named recipe into the working directory."""
    # Remove all YAML files in the working directory so the existence check
    # below can only succeed if ``dgl recipe get`` wrote the file itself.
    for item in os.listdir('./'):
        if item.endswith('.yaml') and os.path.isfile(item):
            os.remove(item)

    cmd = 'dgl recipe get {}'.format(recipe)
    assert os.system(cmd) == 0, cmd
    assert os.path.exists(recipe)

def test_node_cora():
    """Run the end-to-end nodepred flow on Cora with GCN.

    Runs configure, train, configure-apply (default and custom output path),
    apply, and export in order. Each CLI call must exit with status 0 and
    produce its expected artifact.
    """
    # Remove artifacts from earlier runs so each existence check reflects this run.
    stale_outputs = ['nodepred_cora_gcn.yaml', 'results/run_0.pth',
                     'apply_nodepred_cora_gcn.yaml', 'apply.yaml',
                     'apply_results/output.csv', 'apply.py']
    for path in stale_outputs:
        if os.path.isfile(path):
            os.remove(path)

    def run(cmd):
        # Fail with the offending command line if the CLI exits non-zero.
        assert os.system(cmd) == 0, cmd

    run('dgl configure nodepred --data cora --model gcn')
    assert os.path.exists('nodepred_cora_gcn.yaml')
    run('dgl train --cfg nodepred_cora_gcn.yaml')
    assert os.path.exists('results')
    assert os.path.exists('results/run_0.pth')
    run('dgl configure-apply nodepred --cpt results/run_0.pth')
    assert os.path.exists('apply_nodepred_cora_gcn.yaml')
    run('dgl configure-apply nodepred --data cora --cpt results/run_0.pth --cfg apply.yaml')
    assert os.path.exists('apply.yaml')
    run('dgl apply --cfg apply.yaml')
    assert os.path.exists('apply_results/output.csv')
    run('dgl export --cfg apply.yaml --output apply.py')
    assert os.path.exists('apply.py')