# test_sampling_examples.py: smoke tests that run the GraphBolt quickstart
# sampling example scripts and check the value each one prints at the end.
import os
import re
import subprocess
import sys
import unittest

# Directory holding the GraphBolt quickstart example scripts, reached by going
# two levels up from this test file's directory. The path is built from the
# absolute location of this file so it does not depend on the current working
# directory (``os.path.relpath`` would also raise ``ValueError`` on Windows if
# the file and the CWD were on different drives).
EXAMPLE_ROOT = os.path.join(
    os.path.dirname(os.path.abspath(__file__)),
    "..",
    "..",
    "examples",
    "sampling",
    "graphbolt",
    "quickstart",
)


17
18
def test_node_classification():
    script = os.path.join(EXAMPLE_ROOT, "node_classification.py")
19
20
21
    out = subprocess.run(["python", str(script)], capture_output=True)
    assert out.returncode == 0
    stdout = out.stdout.decode("utf-8")
22
    assert float(stdout[-5:]) > 0.60
23
24
25
26
27
28
29
30
31


@unittest.skipIf(os.name == "nt", reason="TODO(6575): Fix the test on Windows")
def test_link_prediction():
    script = os.path.join(EXAMPLE_ROOT, "link_prediction.py")
    out = subprocess.run(["python", str(script)], capture_output=True)
    assert out.returncode == 0
    stdout = out.stdout.decode("utf-8")
    assert float(stdout[-5:]) > 0.80