# bench_gcn.py
# asv (airspeed velocity) benchmarks that run the DGL GCN example script and
# track the training time and test accuracy it reports.
# See "Writing benchmarks" in the asv docs for more information.

import subprocess
import os
from pathlib import Path
import numpy as np
import tempfile

# Root directory containing examples/<backend>/gcn/train.py. The leading "~"
# is left unexpanded here; GCNBenchmark.setup() calls expanduser() on it.
base_path = Path("~") / "regression" / "dgl"


class GCNBenchmark:
    """asv benchmarks that run the DGL GCN example and parse its output.

    ``setup`` runs ``examples/<backend>/gcn/train.py`` (under ``base_path``)
    once per parameter combination and caches the decoded stdout; the
    ``track_*`` methods then extract numbers from that cached log.
    """

    params = [['pytorch'], ['cora', 'pubmed'], ['0', '-1']]
    param_names = ['backend', 'dataset', 'gpu_id']
    timeout = 120

    def __init__(self):
        # Maps "<backend>_<dataset>_<gpu_id>" -> decoded stdout of train.py.
        self.std_log = {}

    @staticmethod
    def _key(backend, dataset, gpu_id):
        """Return the ``std_log`` cache key for one parameter combination."""
        return "{}_{}_{}".format(backend, dataset, gpu_id)

    def _lines(self, backend, dataset, gpu_id):
        """Return the cached training log for these parameters as lines."""
        return self.std_log[self._key(backend, dataset, gpu_id)].splitlines()

    def setup(self, backend, dataset, gpu_id):
        """Run the GCN example for this parameter set and cache its stdout.

        Does nothing if the log for these parameters is already cached.
        Anything the script writes to stderr is printed after it exits.
        """
        key_name = self._key(backend, dataset, gpu_id)
        if key_name in self.std_log:
            return
        gcn_path = base_path / "examples/{}/gcn/train.py".format(backend)
        # Build argv as a list so paths containing spaces are not split apart.
        command = [
            "/opt/conda/envs/{}-ci/bin/python".format(backend),
            str(gcn_path.expanduser()),
            "--dataset", str(dataset),
            "--gpu", str(gpu_id),
            "--n-epochs", "50",
        ]
        process = subprocess.Popen(
            command,
            stdout=subprocess.PIPE,
            # Pipe stderr too; without this `error` below is always None.
            stderr=subprocess.PIPE,
            env=dict(os.environ, DGLBACKEND=backend),
        )
        output, error = process.communicate()
        if error:
            print(error.decode(errors="replace"))
        # Store real text (not the repr of a bytes object) so it can be
        # split on actual newlines.
        self.std_log[key_name] = output.decode(errors="replace")

    def track_gcn_time(self, backend, dataset, gpu_id):
        """Return the mean of the last 10 times reported in the training log.

        For every line containing ``Time``, the value parsed is the last
        whitespace-separated token of the second ``|``-separated field.
        Lines that mention ``Time`` but do not have that shape are skipped.
        Returns ``nan`` if no time value could be parsed.
        """
        times = []
        for line in self._lines(backend, dataset, gpu_id):
            if 'Time' not in line:
                continue
            fields = line.strip().split('|')
            try:
                times.append(float(fields[1].split()[-1]))
            except (IndexError, ValueError):
                # Not a per-epoch timing line; ignore it.
                continue
        if not times:
            return float('nan')
        return float(np.mean(times[-10:]))

    def track_gcn_accuracy(self, backend, dataset, gpu_id):
        """Return the test accuracy from the last ``Test accuracy`` log line.

        The value is the line's last token with any trailing ``%`` removed.
        Returns ``-1`` when no parseable ``Test accuracy`` line is present.
        """
        test_acc = -1
        for line in self._lines(backend, dataset, gpu_id):
            if 'Test accuracy' not in line:
                continue
            try:
                # Later matches overwrite earlier ones: the last line wins.
                test_acc = float(line.split()[-1].rstrip('%'))
            except (IndexError, ValueError):
                continue
        return test_acc


# Units attached to the track_* benchmarks: seconds for time, percent for accuracy.
GCNBenchmark.track_gcn_time.unit, GCNBenchmark.track_gcn_accuracy.unit = 's', '%'