# test_pipeline.py
import os
import pytest

@pytest.mark.parametrize('data', ['cora', 'citeseer', 'pubmed', 'csv', 'reddit',
                                  'co-buy-computer', 'ogbn-arxiv', 'ogbn-products'])
@pytest.mark.parametrize('model', ['gcn', 'gat', 'sage', 'sgc', 'gin'])
def test_nodepred(data, model):
    """Check that the nodepred pipeline writes its config files and exported script.

    Runs ``dgl configure nodepred`` twice (once with the default output name,
    once with an explicit ``--cfg`` path) and then ``dgl export`` on the custom
    config, asserting after each command that the expected file exists.

    Parameters
    ----------
    data : str
        Value passed to ``--data``.
    model : str
        Value passed to ``--model``.
    """
    default_config_file = 'nodepred_{}_{}.yaml'.format(data, model)
    custom_config_file = 'custom_{}_{}.yaml'.format(data, model)
    custom_script = '_'.join([data, model]) + '.py'

    # The assertions below only check for file existence, and these names are
    # not unique to this test. Drop anything left behind by another test or an
    # earlier run so a failing command cannot be masked by a stale file.
    for stale in (default_config_file, custom_config_file, custom_script):
        try:
            os.remove(stale)
        except FileNotFoundError:
            pass

    os.system('dgl configure nodepred --data {} --model {}'.format(data, model))
    assert os.path.exists(default_config_file)

    os.system('dgl configure nodepred --data {} --model {} --cfg {}'.format(data, model,
                                                                            custom_config_file))
    assert os.path.exists(custom_config_file)

    os.system('dgl export --cfg {} --output {}'.format(custom_config_file, custom_script))
    assert os.path.exists(custom_script)

@pytest.mark.parametrize('data', ['cora', 'citeseer', 'pubmed', 'csv', 'reddit',
                                  'co-buy-computer', 'ogbn-arxiv', 'ogbn-products'])
@pytest.mark.parametrize('model', ['gcn', 'gat', 'sage'])
def test_nodepred_ns(data, model):
    """Check that the nodepred-ns pipeline writes its config files and exported script.

    Runs ``dgl configure nodepred-ns`` twice (default output name, then an
    explicit ``--cfg`` path) and ``dgl export`` on the custom config, asserting
    after each command that the expected file exists.

    Parameters
    ----------
    data : str
        Value passed to ``--data``.
    model : str
        Value passed to ``--model``.
    """
    default_config_file = 'nodepred-ns_{}_{}.yaml'.format(data, model)
    custom_config_file = 'custom_{}_{}.yaml'.format(data, model)
    custom_script = '_'.join([data, model]) + '.py'

    # ``custom_<data>_<model>.yaml`` and ``<data>_<model>.py`` are also written
    # by other tests in this file. Remove leftovers first so the existence
    # checks below really verify the commands run here.
    for stale in (default_config_file, custom_config_file, custom_script):
        try:
            os.remove(stale)
        except FileNotFoundError:
            pass

    os.system('dgl configure nodepred-ns --data {} --model {}'.format(data, model))
    assert os.path.exists(default_config_file)

    os.system('dgl configure nodepred-ns --data {} --model {} --cfg {}'.format(data, model,
                                                                               custom_config_file))
    assert os.path.exists(custom_config_file)

    os.system('dgl export --cfg {} --output {}'.format(custom_config_file, custom_script))
    assert os.path.exists(custom_script)

@pytest.mark.parametrize('data', ['cora', 'citeseer', 'pubmed', 'csv', 'reddit',
                                  'co-buy-computer', 'ogbn-arxiv', 'ogbn-products', 'ogbl-collab',
                                  'ogbl-citation2'])
@pytest.mark.parametrize('node_model', ['gcn', 'gat', 'sage', 'sgc', 'gin'])
@pytest.mark.parametrize('edge_model', ['ele', 'bilinear'])
@pytest.mark.parametrize('neg_sampler', ['global', 'persource'])
def test_linkpred(data, node_model, edge_model, neg_sampler):
    """Check that the linkpred pipeline writes a config file and an exported script.

    Runs ``dgl configure linkpred`` with an explicit ``--cfg`` path and then
    ``dgl export`` on that config, asserting after each command that the
    expected file exists.

    Parameters
    ----------
    data : str
        Value passed to ``--data``.
    node_model : str
        Value passed to ``--node-model``.
    edge_model : str
        Value passed to ``--edge-model``.
    neg_sampler : str
        Value passed to ``--neg-sampler``.
    """
    stem = '_'.join([data, node_model, edge_model, neg_sampler])
    custom_config_file = stem + '.yaml'
    custom_script = stem + '.py'

    # Only file existence is asserted, so outputs from a previous run in the
    # same directory would hide a failing command. Clear them up front.
    for stale in (custom_config_file, custom_script):
        try:
            os.remove(stale)
        except FileNotFoundError:
            pass

    os.system('dgl configure linkpred --data {} --node-model {} --edge-model {} --neg-sampler {} --cfg {}'.format(
        data, node_model, edge_model, neg_sampler, custom_config_file))
    assert os.path.exists(custom_config_file)

    os.system('dgl export --cfg {} --output {}'.format(custom_config_file, custom_script))
    assert os.path.exists(custom_script)

@pytest.mark.parametrize('data', ['cora', 'citeseer', 'pubmed', 'csv', 'reddit',
                                  'co-buy-computer', 'ogbn-arxiv', 'ogbn-products', 'ogbl-collab',
                                  'ogbl-citation2'])
@pytest.mark.parametrize('node_model', ['gcn', 'gat', 'sage', 'sgc', 'gin'])
@pytest.mark.parametrize('edge_model', ['ele', 'bilinear'])
def test_linkpred_default_neg_sampler(data, node_model, edge_model):
    """Check that ``dgl configure linkpred`` writes a config without ``--neg-sampler``.

    Parameters
    ----------
    data : str
        Value passed to ``--data``.
    node_model : str
        Value passed to ``--node-model``.
    edge_model : str
        Value passed to ``--edge-model``.
    """
    custom_config_file = '_'.join([data, node_model, edge_model]) + '.yaml'

    # A config left over from an earlier run would satisfy the existence check
    # even if the command below failed, so remove it before running.
    try:
        os.remove(custom_config_file)
    except FileNotFoundError:
        pass

    os.system('dgl configure linkpred --data {} --node-model {} --edge-model {} --cfg {}'.format(
        data, node_model, edge_model, custom_config_file))
    assert os.path.exists(custom_config_file)

63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
@pytest.mark.parametrize('data', ['csv', 'ogbg-molhiv', 'ogbg-molpcba'])
@pytest.mark.parametrize('model', ['gin', 'pna'])
def test_graphpred(data, model):
    """Check that the graphpred pipeline writes its config files and exported script.

    Runs ``dgl configure graphpred`` twice (default output name, then an
    explicit ``--cfg`` path) and ``dgl export`` on the custom config, asserting
    after each command that the expected file exists.

    Parameters
    ----------
    data : str
        Value passed to ``--data``.
    model : str
        Value passed to ``--model``.
    """
    default_config_file = 'graphpred_{}_{}.yaml'.format(data, model)
    custom_config_file = 'custom_{}_{}.yaml'.format(data, model)
    custom_script = '_'.join([data, model]) + '.py'

    # The custom config and script names can coincide with files written by
    # other tests in this file (e.g. data='csv', model='gin'). Remove leftovers
    # so the existence checks cannot be satisfied by a stale file.
    for stale in (default_config_file, custom_config_file, custom_script):
        try:
            os.remove(stale)
        except FileNotFoundError:
            pass

    os.system('dgl configure graphpred --data {} --model {}'.format(data, model))
    assert os.path.exists(default_config_file)

    os.system('dgl configure graphpred --data {} --model {} --cfg {}'.format(data, model,
                                                                             custom_config_file))
    assert os.path.exists(custom_config_file)

    os.system('dgl export --cfg {} --output {}'.format(custom_config_file, custom_script))
    assert os.path.exists(custom_script)

Mufei Li's avatar
Mufei Li committed
78
@pytest.mark.parametrize('recipe',
                         ['graphpred_hiv_gin.yaml',
                          'graphpred_hiv_pna.yaml',
                          'graphpred_pcba_gin.yaml',
                          'linkpred_cora_sage.yaml',
                          'linkpred_citation2_sage.yaml',
                          'linkpred_collab_sage.yaml',
                          'nodepred_citeseer_gat.yaml',
                          'nodepred_citeseer_gcn.yaml',
                          'nodepred_citeseer_sage.yaml',
                          'nodepred_cora_gat.yaml',
                          'nodepred_cora_gcn.yaml',
                          'nodepred_cora_sage.yaml',
                          'nodepred_pubmed_gat.yaml',
                          'nodepred_pubmed_gcn.yaml',
                          'nodepred_pubmed_sage.yaml',
                          'nodepred-ns_arxiv_gcn.yaml',
                          'nodepred-ns_product_sage.yaml'])
def test_recipe(recipe):
    """Check that ``dgl recipe get <recipe>`` leaves a file named ``recipe`` in the cwd.

    Parameters
    ----------
    recipe : str
        Recipe file name passed to ``dgl recipe get``.
    """
    # Start from a directory with no yaml files so the existence check below
    # can only be satisfied by the command under test.
    for entry in os.listdir("./"):
        if not entry.endswith(".yaml"):
            continue
        os.remove(entry)

    command = 'dgl recipe get ' + recipe
    os.system(command)
    assert os.path.exists(recipe)

def test_node_cora():
    """Check that training from a generated nodepred config writes model artifacts.

    Generates the default nodepred config for data ``cora`` and model ``gcn``,
    runs ``dgl train`` on it, and asserts that ``checkpoint.pth`` and
    ``model.pth`` exist afterwards.
    """
    config_file = 'nodepred_cora_gcn.yaml'
    artifacts = ('checkpoint.pth', 'model.pth')

    # Remove outputs of earlier runs/tests so the existence checks below
    # reflect this invocation rather than stale files.
    for stale in (config_file,) + artifacts:
        try:
            os.remove(stale)
        except FileNotFoundError:
            pass

    os.system('dgl configure nodepred --data cora --model gcn')
    # Fail here, not after training, if the config was never generated.
    assert os.path.exists(config_file)

    os.system('dgl train --cfg nodepred_cora_gcn.yaml')
    assert os.path.exists('checkpoint.pth')
    assert os.path.exists('model.pth')