# test_sampling_examples.py
import os
import re
import subprocess
import sys
import unittest

# Directory containing the GraphBolt quickstart example scripts, located by
# walking two directories up from this test file and then into
# examples/sampling/graphbolt/quickstart.
# abspath (not relpath) keeps the value valid if the working directory changes
# after import, and avoids relpath's ValueError on Windows when this file and
# the cwd live on different drives.
EXAMPLE_ROOT = os.path.abspath(
    os.path.join(
        os.path.dirname(os.path.abspath(__file__)),
        "..",
        "..",
        "examples",
        "sampling",
        "graphbolt",
        "quickstart",
    )
)


def test_node_classification():
    """Run the node-classification quickstart and check its reported accuracy.

    The example script is executed in a child process. The test fails if the
    script exits non-zero, or if the last decimal number it prints to stdout
    is not above 0.60.
    """
    script = os.path.join(EXAMPLE_ROOT, "node_classification.py")
    # Use the interpreter running the test suite rather than whatever "python"
    # resolves to on PATH, so the child sees the same installed packages.
    out = subprocess.run([sys.executable, script], capture_output=True)
    stdout = out.stdout.decode("utf-8")
    stderr = out.stderr.decode("utf-8")
    assert out.returncode == 0, f"stdout: {stdout}\nstderr: {stderr}"
    # Assumes the script reports its accuracy as the last decimal number it
    # prints (TODO confirm against the example). Parsing the number instead of
    # slicing the last 5 characters avoids misreading e.g. "0.8123\n" as 8123,
    # which would make the threshold check pass vacuously.
    numbers = re.findall(r"[-+]?\d*\.\d+", stdout)
    assert numbers, f"no metric found in stdout: {stdout}"
    accuracy = float(numbers[-1])
    assert accuracy > 0.60, f"accuracy {accuracy} is not above 0.60"


def test_link_prediction():
    """Run the link-prediction quickstart and check its reported metric.

    The example script is executed in a child process. The test fails if the
    script exits non-zero, or if the last decimal number it prints to stdout
    is not above 0.80.
    """
    script = os.path.join(EXAMPLE_ROOT, "link_prediction.py")
    # Use the interpreter running the test suite rather than whatever "python"
    # resolves to on PATH, so the child sees the same installed packages.
    out = subprocess.run([sys.executable, script], capture_output=True)
    stdout = out.stdout.decode("utf-8")
    stderr = out.stderr.decode("utf-8")
    assert out.returncode == 0, f"stdout: {stdout}\nstderr: {stderr}"
    # Assumes the script reports its score as the last decimal number it
    # prints (TODO confirm against the example). Parsing the number instead of
    # slicing the last 5 characters avoids misreading e.g. "0.8765\n" as 8765,
    # which would make the threshold check pass vacuously.
    numbers = re.findall(r"[-+]?\d*\.\d+", stdout)
    assert numbers, f"no metric found in stdout: {stdout}"
    score = float(numbers[-1])
    assert score > 0.80, f"score {score} is not above 0.80"