# bench_sage.py: asv benchmark for the DGL GraphSAGE (train_sampling.py) example.
# Write the benchmarking functions here.
# See "Writing benchmarks" in the asv docs for more information.

import subprocess
import os
from pathlib import Path
import numpy as np
import tempfile

# Root of the DGL checkout used by the benchmarks. The leading "~" is NOT
# expanded by Path(); callers must call .expanduser() before using it.
base_path = Path("~") / "regression" / "dgl"


class SAGEBenchmark:
    """asv benchmark for the DGL GraphSAGE ``train_sampling.py`` example.

    ``setup`` runs the example script once per ``(backend, gpu)`` combination
    and keeps its decoded stdout in ``self.std_log``; the ``track_*`` methods
    parse values out of that captured log.
    """

    params = [['pytorch'], ['0']]
    param_names = ['backend', 'gpu']
    timeout = 1800

    def __init__(self):
        # Maps "<backend>_<gpu>" -> decoded stdout of the training run.
        # NOTE(review): asv may run each benchmark method in its own process,
        # in which case this per-instance cache does not carry over between
        # track_* methods and the script is re-run for each; consider asv's
        # setup_cache if that is the case — confirm.
        self.std_log = {}

    @staticmethod
    def _key(backend, gpu):
        """Return the ``std_log`` key for one parameter combination."""
        return "{}_{}".format(backend, gpu)

    def setup(self, backend, gpu):
        """Run the training script for ``backend``/``gpu`` and cache its stdout.

        The run is skipped if this instance already holds a log for the same
        parameter combination.

        Raises:
            RuntimeError: if the training script exits with a non-zero status,
                so a failed run is reported as a failure instead of being
                recorded as a NaN time / 0% accuracy.
        """
        key_name = self._key(backend, gpu)
        if key_name in self.std_log:
            return
        run_path = (base_path / "examples/{}/graphsage/train_sampling.py".format(backend)).expanduser()
        # Build argv as a list (no str.split) so paths with spaces survive.
        cmd = [
            "/opt/conda/envs/{}-ci/bin/python".format(backend),
            str(run_path),
            "--num-workers=2",
            "--num-epochs=16",
            "--gpu={}".format(gpu),
        ]
        result = subprocess.run(
            cmd,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            env=dict(os.environ, DGLBACKEND=backend),
        )
        # stderr is captured (the original left it unpiped, so it printed
        # "None"); echo it so script diagnostics still reach the asv log.
        print(result.stderr.decode(errors="replace"))
        if result.returncode != 0:
            raise RuntimeError(
                "training script {} exited with status {}".format(run_path, result.returncode))
        # Store decoded text rather than the repr of the bytes object.
        self.std_log[key_name] = result.stdout.decode(errors="replace")

    def _values_after(self, backend, gpu, prefix, offset):
        """Return a float for every cached log line that starts with ``prefix``.

        The number is parsed from column ``offset`` onward of the stripped
        line. Raises KeyError if ``setup`` has not produced a log for this
        combination, and ValueError if a matching line is not numeric there.
        """
        log = self.std_log[self._key(backend, gpu)]
        return [float(line.strip()[offset:])
                for line in log.splitlines()
                if line.startswith(prefix)]

    def track_sage_time(self, backend, gpu):
        """Return the mean of all "Epoch Time" values in the log, in seconds.

        Returns NaN if the log contains no such line.
        """
        # Presumably the lines look like "Epoch Time(s): 12.3456", so the value
        # starts at column 15 — confirm against train_sampling.py.
        times = self._values_after(backend, gpu, 'Epoch Time', 15)
        if not times:
            return float('nan')
        return np.mean(times)

    def track_sage_accuracy(self, backend, gpu):
        """Return the last "Eval Acc" value in the log, scaled to percent.

        Returns 0.0 if the log contains no such line.
        """
        # Presumably the lines look like "Eval Acc 0.9123"; the value starts
        # at column 9 — confirm against train_sampling.py.
        accs = self._values_after(backend, gpu, 'Eval Acc', 9)
        test_acc = accs[-1] if accs else 0.
        return test_acc * 100


SAGEBenchmark.track_sage_time.unit = 's'
SAGEBenchmark.track_sage_accuracy.unit = '%'