test_pipeline.py 4.78 KB
Newer Older
Mufei Li's avatar
Mufei Li committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import os
import pytest

@pytest.mark.parametrize('data', ['cora', 'citeseer', 'pubmed', 'csv', 'reddit',
                                  'co-buy-computer', 'ogbn-arxiv', 'ogbn-products'])
@pytest.mark.parametrize('model', ['gcn', 'gat', 'sage', 'sgc', 'gin'])
def test_nodepred(data, model):
    """Configure and export a full-graph node-prediction (``nodepred``) pipeline.

    Checks that ``dgl configure nodepred`` writes both its default config file
    and a config at a user-chosen path, and that ``dgl export`` turns the
    latter into a Python script.
    """
    default_config_file = 'nodepred_{}_{}.yaml'.format(data, model)
    # Prefix the custom artifacts with the pipeline name so they cannot collide
    # with the files test_nodepred_ns writes for the same (data, model) pair.
    custom_config_file = 'nodepred_custom_{}_{}.yaml'.format(data, model)
    custom_script = 'nodepred_{}_{}.py'.format(data, model)

    # Drop leftovers from earlier runs; otherwise the existence checks below
    # would pass even when a command fails.
    for path in (default_config_file, custom_config_file, custom_script):
        if os.path.exists(path):
            os.remove(path)

    status = os.system('dgl configure nodepred --data {} --model {}'.format(data, model))
    assert status == 0
    assert os.path.exists(default_config_file)

    status = os.system('dgl configure nodepred --data {} --model {} --cfg {}'.format(
        data, model, custom_config_file))
    assert status == 0
    assert os.path.exists(custom_config_file)

    status = os.system('dgl export --cfg {} --output {}'.format(custom_config_file,
                                                               custom_script))
    assert status == 0
    assert os.path.exists(custom_script)

@pytest.mark.parametrize('data', ['cora', 'citeseer', 'pubmed', 'csv', 'reddit',
                                  'co-buy-computer', 'ogbn-arxiv', 'ogbn-products'])
@pytest.mark.parametrize('model', ['gcn', 'gat', 'sage'])
def test_nodepred_ns(data, model):
    """Configure and export a neighbor-sampling node-prediction (``nodepred-ns``) pipeline.

    Checks that ``dgl configure nodepred-ns`` writes both its default config
    file and a config at a user-chosen path, and that ``dgl export`` turns the
    latter into a Python script.
    """
    default_config_file = 'nodepred-ns_{}_{}.yaml'.format(data, model)
    # Prefix the custom artifacts with the pipeline name so they cannot collide
    # with the files test_nodepred writes for the same (data, model) pair.
    custom_config_file = 'nodepred-ns_custom_{}_{}.yaml'.format(data, model)
    custom_script = 'nodepred-ns_{}_{}.py'.format(data, model)

    # Drop leftovers from earlier runs; otherwise the existence checks below
    # would pass even when a command fails.
    for path in (default_config_file, custom_config_file, custom_script):
        if os.path.exists(path):
            os.remove(path)

    status = os.system('dgl configure nodepred-ns --data {} --model {}'.format(data, model))
    assert status == 0
    assert os.path.exists(default_config_file)

    status = os.system('dgl configure nodepred-ns --data {} --model {} --cfg {}'.format(
        data, model, custom_config_file))
    assert status == 0
    assert os.path.exists(custom_config_file)

    status = os.system('dgl export --cfg {} --output {}'.format(custom_config_file,
                                                               custom_script))
    assert status == 0
    assert os.path.exists(custom_script)

@pytest.mark.parametrize('data', ['cora', 'citeseer', 'pubmed', 'csv', 'reddit',
                                  'co-buy-computer', 'ogbn-arxiv', 'ogbn-products', 'ogbl-collab',
                                  'ogbl-citation2'])
@pytest.mark.parametrize('node_model', ['gcn', 'gat', 'sage', 'sgc', 'gin'])
@pytest.mark.parametrize('edge_model', ['ele', 'bilinear'])
@pytest.mark.parametrize('neg_sampler', ['global', 'persource'])
def test_linkpred(data, node_model, edge_model, neg_sampler):
    """Configure and export a link-prediction (``linkpred``) pipeline.

    Writes a config for an explicit node model, edge model, and negative
    sampler, then checks that ``dgl export`` turns it into a Python script.
    """
    stem = '_'.join([data, node_model, edge_model, neg_sampler])
    custom_config_file = stem + '.yaml'
    custom_script = stem + '.py'

    # Drop leftovers from earlier runs; otherwise the existence checks below
    # would pass even when a command fails.
    for path in (custom_config_file, custom_script):
        if os.path.exists(path):
            os.remove(path)

    status = os.system(
        'dgl configure linkpred --data {} --node-model {} --edge-model {} '
        '--neg-sampler {} --cfg {}'.format(data, node_model, edge_model, neg_sampler,
                                           custom_config_file))
    assert status == 0
    assert os.path.exists(custom_config_file)

    status = os.system('dgl export --cfg {} --output {}'.format(custom_config_file,
                                                               custom_script))
    assert status == 0
    assert os.path.exists(custom_script)

@pytest.mark.parametrize('data', ['cora', 'citeseer', 'pubmed', 'csv', 'reddit',
                                  'co-buy-computer', 'ogbn-arxiv', 'ogbn-products', 'ogbl-collab',
                                  'ogbl-citation2'])
@pytest.mark.parametrize('node_model', ['gcn', 'gat', 'sage', 'sgc', 'gin'])
@pytest.mark.parametrize('edge_model', ['ele', 'bilinear'])
def test_linkpred_default_neg_sampler(data, node_model, edge_model):
    """Configure a link-prediction pipeline without passing ``--neg-sampler``.

    Checks only that ``dgl configure linkpred`` succeeds and writes the
    requested config file when the negative sampler is left to the CLI.
    """
    custom_config_file = '_'.join([data, node_model, edge_model]) + '.yaml'

    # Drop a leftover from an earlier run; otherwise the existence check below
    # would pass even when the command fails.
    if os.path.exists(custom_config_file):
        os.remove(custom_config_file)

    status = os.system(
        'dgl configure linkpred --data {} --node-model {} --edge-model {} --cfg {}'.format(
            data, node_model, edge_model, custom_config_file))
    assert status == 0
    assert os.path.exists(custom_config_file)

@pytest.mark.parametrize('recipe',
                         ['linkpred_cora_sage.yaml',
                          'linkpred_citation2_sage.yaml',
                          'linkpred_collab_sage.yaml',
                          'nodepred_citeseer_gat.yaml',
                          'nodepred_citeseer_gcn.yaml',
                          'nodepred_citeseer_sage.yaml',
                          'nodepred_cora_gat.yaml',
                          'nodepred_cora_gcn.yaml',
                          'nodepred_cora_sage.yaml',
                          'nodepred_pubmed_gat.yaml',
                          'nodepred_pubmed_gcn.yaml',
                          'nodepred_pubmed_sage.yaml',
                          'nodepred-ns_arxiv_gcn.yaml',
                          'nodepred-ns_product_sage.yaml'])
def test_recipe(recipe):
    """Fetch a bundled recipe with ``dgl recipe get`` into the working directory."""
    # Remove all yaml files in the working directory first, so the existence
    # check below can only pass if the recipe was freshly fetched. Only regular
    # files are deleted; os.remove would raise on a directory named *.yaml.
    for item in os.listdir("./"):
        if item.endswith(".yaml") and os.path.isfile(item):
            os.remove(item)

    status = os.system('dgl recipe get {}'.format(recipe))
    assert status == 0
    assert os.path.exists(recipe)

def test_node_cora():
    """Configure and train a GCN node-prediction pipeline on Cora end to end.

    Checks that ``dgl train`` leaves ``checkpoint.pth`` and ``model.pth`` in
    the working directory.
    """
    config_file = 'nodepred_cora_gcn.yaml'

    # Clear outputs of earlier runs so the assertions reflect this run only;
    # otherwise a stale checkpoint would hide a failed training run.
    for path in (config_file, 'checkpoint.pth', 'model.pth'):
        if os.path.exists(path):
            os.remove(path)

    assert os.system('dgl configure nodepred --data cora --model gcn') == 0
    assert os.system('dgl train --cfg {}'.format(config_file)) == 0
    assert os.path.exists('checkpoint.pth')
    assert os.path.exists('model.pth')