Unverified Commit bc978736 authored by Muhammed Fatih BALIN, committed by GitHub

[GraphBolt] `torch.compile()` support for `gb.expand_indptr`. (#7188)

parent 8b266f50
@@ -5,6 +5,7 @@
  * @brief ExpandIndptr operators.
  */
 #include <graphbolt/cuda_ops.h>
+#include <torch/autograd.h>
 
 #include "./macro.h"
 #include "./utils.h"
@@ -29,5 +30,19 @@ torch::Tensor ExpandIndptr(
       indptr.diff(), 0, output_size);
 }
 
+TORCH_LIBRARY_IMPL(graphbolt, CPU, m) {
+  m.impl("expand_indptr", &ExpandIndptr);
+}
+
+#ifdef GRAPHBOLT_USE_CUDA
+TORCH_LIBRARY_IMPL(graphbolt, CUDA, m) {
+  m.impl("expand_indptr", &ExpandIndptrImpl);
+}
+#endif
+
+TORCH_LIBRARY_IMPL(graphbolt, Autograd, m) {
+  m.impl("expand_indptr", torch::autograd::autogradNotImplementedFallback());
+}
+
 }  // namespace ops
 }  // namespace graphbolt
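The three `TORCH_LIBRARY_IMPL` blocks above register the CPU, CUDA, and Autograd kernels for `graphbolt::expand_indptr` with the PyTorch dispatcher, so the device of the input tensors decides which implementation runs. A minimal sketch of calling the dispatched op from Python, assuming a DGL build with GraphBolt is installed (values are illustrative):

```python
import torch
import dgl.graphbolt  # noqa: F401  loads the GraphBolt C++ library and its ops

# indptr [0, 2, 5] describes two "rows" with 2 and 3 entries; expanding it
# yields the COO row indices. CPU inputs route to the ExpandIndptr kernel,
# CUDA inputs (on a GRAPHBOLT_USE_CUDA build) route to ExpandIndptrImpl.
indptr = torch.tensor([0, 2, 5])
rows = torch.ops.graphbolt.expand_indptr(indptr, torch.int64, None, None)
print(rows)  # expected: tensor([0, 0, 1, 1, 1])
```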
@@ -88,11 +88,21 @@ TORCH_LIBRARY(graphbolt, m) {
   m.def("isin", &IsIn);
   m.def("index_select", &ops::IndexSelect);
   m.def("index_select_csc", &ops::IndexSelectCSC);
-  m.def("expand_indptr", &ops::ExpandIndptr);
   m.def("set_seed", &RandomEngine::SetManualSeed);
 #ifdef GRAPHBOLT_USE_CUDA
   m.def("set_max_uva_threads", &cuda::set_max_uva_threads);
 #endif
+#ifdef HAS_IMPL_ABSTRACT_PYSTUB
+  m.impl_abstract_pystub("dgl.graphbolt.base", "//dgl.graphbolt.base");
+#endif
+  m.def(
+      "expand_indptr(Tensor indptr, ScalarType dtype, Tensor? node_ids, "
+      "SymInt? output_size) -> Tensor"
+#ifdef HAS_PT2_COMPLIANT_TAG
+      ,
+      {at::Tag::pt2_compliant_tag}
+#endif
+  );
 }
 
 }  // namespace sampling
...
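Registering the op with an explicit schema string (instead of a bare C++ function pointer) is what gives `torch.compile()` full signature information, and the optional `pt2_compliant_tag` / `impl_abstract_pystub` entries point PT2 at the Python module that supplies the abstract implementation. As a rough, hypothetical parallel in pure Python (the `mylib` namespace is made up for illustration):

```python
import torch

# Define the same style of schema for a made-up library. "Tensor?" marks
# node_ids as optional, and "SymInt?" lets the output size stay symbolic
# when traced by torch.compile.
torch.library.define(
    "mylib::expand_indptr",
    "(Tensor indptr, ScalarType dtype, Tensor? node_ids, SymInt? output_size)"
    " -> Tensor",
)
print(torch.ops.mylib.expand_indptr)  # the op now exists, with no kernels yet
```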
@@ -5,25 +5,6 @@ import sys
 import torch
 
 from .._ffi import libinfo
-from .base import *
-from .minibatch import *
-from .dataloader import *
-from .dataset import *
-from .feature_fetcher import *
-from .feature_store import *
-from .impl import *
-from .itemset import *
-from .item_sampler import *
-from .minibatch_transformer import *
-from .negative_sampler import *
-from .sampled_subgraph import *
-from .subgraph_sampler import *
-from .internal import (
-    compact_csc_format,
-    unique_and_compact,
-    unique_and_compact_csc_formats,
-)
-from .utils import add_reverse_edges, add_reverse_edges_2, exclude_seed_edges
 
 
 def load_graphbolt():
@@ -53,3 +34,24 @@ def load_graphbolt():
 
 
 load_graphbolt()
+
+# pylint: disable=wrong-import-position
+from .base import *
+from .minibatch import *
+from .dataloader import *
+from .dataset import *
+from .feature_fetcher import *
+from .feature_store import *
+from .impl import *
+from .itemset import *
+from .item_sampler import *
+from .minibatch_transformer import *
+from .negative_sampler import *
+from .sampled_subgraph import *
+from .subgraph_sampler import *
+from .internal import (
+    compact_csc_format,
+    unique_and_compact,
+    unique_and_compact_csc_formats,
+)
+from .utils import add_reverse_edges, add_reverse_edges_2, exclude_seed_edges
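Moving the submodule imports below `load_graphbolt()` (hence the `pylint: disable=wrong-import-position` marker) keeps the compiled GraphBolt library, which defines the `graphbolt::` op schemas, loaded before modules such as `base.py` run their registration code against those names. A small illustrative check of the resulting state, assuming a GraphBolt-enabled DGL build:

```python
import torch
import dgl.graphbolt  # noqa: F401  __init__.py loads the C++ library first

# The C++-defined schema is now reachable from Python; this is the name that
# base.py's @torch.library.impl_abstract("graphbolt::expand_indptr") binds to.
print(torch.ops.graphbolt.expand_indptr)
```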
@@ -4,6 +4,7 @@ from collections import deque
 from dataclasses import dataclass
 
 import torch
+from torch.torch_version import TorchVersion
 from torch.utils.data import functional_datapipe
 from torchdata.datapipes.iter import IterDataPipe
@@ -63,6 +64,18 @@ def isin(elements, test_elements):
     return torch.ops.graphbolt.isin(elements, test_elements)
 
 
+if TorchVersion(torch.__version__) >= TorchVersion("2.2.0a0"):
+
+    @torch.library.impl_abstract("graphbolt::expand_indptr")
+    def expand_indptr_abstract(indptr, dtype, node_ids, output_size):
+        """Abstract implementation of expand_indptr for torch.compile() support."""
+        if output_size is None:
+            output_size = torch.library.get_ctx().new_dynamic_size()
+        if dtype is None:
+            dtype = node_ids.dtype
+        return indptr.new_empty(output_size, dtype=dtype)
+
+
 def expand_indptr(indptr, dtype=None, node_ids=None, output_size=None):
     """Converts a given indptr offset tensor to a COO format tensor. If
     node_ids is not given, it is assumed to be equal to
...
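The guarded `impl_abstract` registration above is what lets FakeTensor propagation, and therefore `torch.compile()`, compute the output shape and dtype of `expand_indptr` without running the real kernel. A minimal usage sketch, assuming PyTorch >= 2.2 and a GraphBolt-enabled DGL build (values are illustrative):

```python
import torch
import dgl.graphbolt as gb

def indptr_to_coo(indptr, node_ids, output_size):
    # Repeats node_ids[i] exactly (indptr[i + 1] - indptr[i]) times.
    return gb.expand_indptr(indptr, node_ids.dtype, node_ids, output_size)

compiled = torch.compile(indptr_to_coo)

indptr = torch.tensor([0, 2, 2, 5])
node_ids = torch.tensor([10, 11, 12])
print(compiled(indptr, node_ids, 5))  # expected: tensor([10, 10, 12, 12, 12])
```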
@@ -7,6 +7,7 @@ import backend as F
 import dgl.graphbolt as gb
 import pytest
 import torch
+from torch.torch_version import TorchVersion
 
 from . import gb_test_utils
@@ -296,6 +297,32 @@ def test_expand_indptr(nodes, dtype):
     gb_result = gb.expand_indptr(indptr, dtype, nodes, indptr[-1].item())
     assert torch.equal(torch_result, gb_result)
+
+    if TorchVersion(torch.__version__) >= TorchVersion("2.2.0a0"):
+        import torch._dynamo as dynamo
+        from torch.testing._internal.optests import opcheck
+
+        # Tests torch.compile compatibility
+        for output_size in [None, indptr[-1].item()]:
+            kwargs = {"node_ids": nodes, "output_size": output_size}
+            opcheck(
+                torch.ops.graphbolt.expand_indptr,
+                (indptr, dtype),
+                kwargs,
+                test_utils=[
+                    "test_schema",
+                    "test_autograd_registration",
+                    "test_faketensor",
+                    "test_aot_dispatch_dynamic",
+                ],
+                raise_exception=True,
+            )
+            explanation = dynamo.explain(gb.expand_indptr)(
+                indptr, dtype, nodes, output_size
+            )
+            expected_breaks = -1 if output_size is None else 0
+            assert explanation.graph_break_count == expected_breaks
 
 
 def test_csc_format_base_representation():
     csc_format_base = gb.CSCFormatBase(
...