__init__.py 1.27 KB
Newer Older
1
2
3
"""Graphbolt."""
import os
import sys
Rhett Ying's avatar
Rhett Ying committed
4

5
6
7
8
import torch

from .._ffi import libinfo
from .graph_storage import *
Rhett Ying's avatar
Rhett Ying committed
9
from .itemset import *
10
from .minibatch_sampler import *
11
from .feature_store import *
12
from .feature_fetcher import *
13
from .copy_to import *
14
from .dataset import *
15
from .subgraph_sampler import *
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44


def load_graphbolt():
    """Load the Graphbolt C++ library into PyTorch.

    The library file name depends on the platform and on the installed
    PyTorch version (with anything after a ``+`` in ``torch.__version__``
    stripped). It is looked up in a ``graphbolt`` subdirectory of the
    directory containing the first path returned by
    ``libinfo.find_lib_path()``, and loaded with
    ``torch.classes.load_library``.

    Raises
    ------
    NotImplementedError
        If ``sys.platform`` is not Linux, macOS (darwin) or Windows.
    FileNotFoundError
        If the expected library file does not exist.
    ImportError
        If ``torch.classes.load_library`` fails; the original exception is
        chained as ``__cause__``.
    """
    # Library files are built per PyTorch version; keep only the part before
    # "+" so e.g. "2.0.1+cu118" selects the "2.0.1" build.
    version = torch.__version__.split("+", maxsplit=1)[0]

    if sys.platform.startswith("linux"):
        basename = f"libgraphbolt_pytorch_{version}.so"
    elif sys.platform.startswith("darwin"):
        basename = f"libgraphbolt_pytorch_{version}.dylib"
    elif sys.platform.startswith("win"):
        basename = f"graphbolt_pytorch_{version}.dll"
    else:
        raise NotImplementedError(f"Unsupported system: {sys.platform}")

    # The Graphbolt library sits in a "graphbolt" subdirectory next to the
    # main DGL library.
    dirname = os.path.dirname(libinfo.find_lib_path()[0])
    path = os.path.join(dirname, "graphbolt", basename)
    if not os.path.exists(path):
        raise FileNotFoundError(
            f"Cannot find DGL C++ graphbolt library at {path}"
        )

    try:
        torch.classes.load_library(path)
    except Exception as err:  # pylint: disable=W0703
        # Chain the original error and report it, so the underlying loader
        # failure is not hidden behind a generic message.
        raise ImportError(
            f"Cannot load Graphbolt C++ library at {path}: {err}"
        ) from err


# Load the Graphbolt native library at import time. The SampledSubgraph alias
# further down is looked up in ``torch.classes.graphbolt``, which presumably
# only exists once this library is loaded, so this call must come first.
load_graphbolt()
45
46

# Python-side handle to the C++ ``SampledSubgraph`` class exposed under
# ``torch.classes.graphbolt``; only resolvable after load_graphbolt() has
# loaded the Graphbolt library.
SampledSubgraph = torch.classes.graphbolt.SampledSubgraph