Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
bc978736
Unverified
Commit
bc978736
authored
Mar 05, 2024
by
Muhammed Fatih BALIN
Committed by
GitHub
Mar 05, 2024
Browse files
[GraphBolt] `torch.compile()` support for `gb.expand_indptr`. (#7188)
parent
8b266f50
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
87 additions
and
20 deletions
+87
-20
graphbolt/src/expand_indptr.cc
graphbolt/src/expand_indptr.cc
+15
-0
graphbolt/src/python_binding.cc
graphbolt/src/python_binding.cc
+11
-1
python/dgl/graphbolt/__init__.py
python/dgl/graphbolt/__init__.py
+21
-19
python/dgl/graphbolt/base.py
python/dgl/graphbolt/base.py
+13
-0
tests/python/pytorch/graphbolt/test_base.py
tests/python/pytorch/graphbolt/test_base.py
+27
-0
No files found.
graphbolt/src/expand_indptr.cc
View file @
bc978736
...
@@ -5,6 +5,7 @@
...
@@ -5,6 +5,7 @@
* @brief ExpandIndptr operators.
* @brief ExpandIndptr operators.
*/
*/
#include <graphbolt/cuda_ops.h>
#include <graphbolt/cuda_ops.h>
#include <torch/autograd.h>
#include "./macro.h"
#include "./macro.h"
#include "./utils.h"
#include "./utils.h"
...
@@ -29,5 +30,19 @@ torch::Tensor ExpandIndptr(
...
@@ -29,5 +30,19 @@ torch::Tensor ExpandIndptr(
indptr
.
diff
(),
0
,
output_size
);
indptr
.
diff
(),
0
,
output_size
);
}
}
TORCH_LIBRARY_IMPL
(
graphbolt
,
CPU
,
m
)
{
m
.
impl
(
"expand_indptr"
,
&
ExpandIndptr
);
}
#ifdef GRAPHBOLT_USE_CUDA
TORCH_LIBRARY_IMPL
(
graphbolt
,
CUDA
,
m
)
{
m
.
impl
(
"expand_indptr"
,
&
ExpandIndptrImpl
);
}
#endif
TORCH_LIBRARY_IMPL
(
graphbolt
,
Autograd
,
m
)
{
m
.
impl
(
"expand_indptr"
,
torch
::
autograd
::
autogradNotImplementedFallback
());
}
}
// namespace ops
}
// namespace ops
}
// namespace graphbolt
}
// namespace graphbolt
graphbolt/src/python_binding.cc
View file @
bc978736
...
@@ -88,11 +88,21 @@ TORCH_LIBRARY(graphbolt, m) {
...
@@ -88,11 +88,21 @@ TORCH_LIBRARY(graphbolt, m) {
m
.
def
(
"isin"
,
&
IsIn
);
m
.
def
(
"isin"
,
&
IsIn
);
m
.
def
(
"index_select"
,
&
ops
::
IndexSelect
);
m
.
def
(
"index_select"
,
&
ops
::
IndexSelect
);
m
.
def
(
"index_select_csc"
,
&
ops
::
IndexSelectCSC
);
m
.
def
(
"index_select_csc"
,
&
ops
::
IndexSelectCSC
);
m
.
def
(
"expand_indptr"
,
&
ops
::
ExpandIndptr
);
m
.
def
(
"set_seed"
,
&
RandomEngine
::
SetManualSeed
);
m
.
def
(
"set_seed"
,
&
RandomEngine
::
SetManualSeed
);
#ifdef GRAPHBOLT_USE_CUDA
#ifdef GRAPHBOLT_USE_CUDA
m
.
def
(
"set_max_uva_threads"
,
&
cuda
::
set_max_uva_threads
);
m
.
def
(
"set_max_uva_threads"
,
&
cuda
::
set_max_uva_threads
);
#endif
#endif
#ifdef HAS_IMPL_ABSTRACT_PYSTUB
m
.
impl_abstract_pystub
(
"dgl.graphbolt.base"
,
"//dgl.graphbolt.base"
);
#endif
m
.
def
(
"expand_indptr(Tensor indptr, ScalarType dtype, Tensor? node_ids, "
"SymInt? output_size) -> Tensor"
#ifdef HAS_PT2_COMPLIANT_TAG
,
{
at
::
Tag
::
pt2_compliant_tag
}
#endif
);
}
}
}
// namespace sampling
}
// namespace sampling
...
...
python/dgl/graphbolt/__init__.py
View file @
bc978736
...
@@ -5,25 +5,6 @@ import sys
...
@@ -5,25 +5,6 @@ import sys
import
torch
import
torch
from
.._ffi
import
libinfo
from
.._ffi
import
libinfo
from
.base
import
*
from
.minibatch
import
*
from
.dataloader
import
*
from
.dataset
import
*
from
.feature_fetcher
import
*
from
.feature_store
import
*
from
.impl
import
*
from
.itemset
import
*
from
.item_sampler
import
*
from
.minibatch_transformer
import
*
from
.negative_sampler
import
*
from
.sampled_subgraph
import
*
from
.subgraph_sampler
import
*
from
.internal
import
(
compact_csc_format
,
unique_and_compact
,
unique_and_compact_csc_formats
,
)
from
.utils
import
add_reverse_edges
,
add_reverse_edges_2
,
exclude_seed_edges
def
load_graphbolt
():
def
load_graphbolt
():
...
@@ -53,3 +34,24 @@ def load_graphbolt():
...
@@ -53,3 +34,24 @@ def load_graphbolt():
load_graphbolt
()
load_graphbolt
()
# pylint: disable=wrong-import-position
from
.base
import
*
from
.minibatch
import
*
from
.dataloader
import
*
from
.dataset
import
*
from
.feature_fetcher
import
*
from
.feature_store
import
*
from
.impl
import
*
from
.itemset
import
*
from
.item_sampler
import
*
from
.minibatch_transformer
import
*
from
.negative_sampler
import
*
from
.sampled_subgraph
import
*
from
.subgraph_sampler
import
*
from
.internal
import
(
compact_csc_format
,
unique_and_compact
,
unique_and_compact_csc_formats
,
)
from
.utils
import
add_reverse_edges
,
add_reverse_edges_2
,
exclude_seed_edges
python/dgl/graphbolt/base.py
View file @
bc978736
...
@@ -4,6 +4,7 @@ from collections import deque
...
@@ -4,6 +4,7 @@ from collections import deque
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
import
torch
import
torch
from
torch.torch_version
import
TorchVersion
from
torch.utils.data
import
functional_datapipe
from
torch.utils.data
import
functional_datapipe
from
torchdata.datapipes.iter
import
IterDataPipe
from
torchdata.datapipes.iter
import
IterDataPipe
...
@@ -63,6 +64,18 @@ def isin(elements, test_elements):
...
@@ -63,6 +64,18 @@ def isin(elements, test_elements):
return
torch
.
ops
.
graphbolt
.
isin
(
elements
,
test_elements
)
return
torch
.
ops
.
graphbolt
.
isin
(
elements
,
test_elements
)
if
TorchVersion
(
torch
.
__version__
)
>=
TorchVersion
(
"2.2.0a0"
):
@
torch
.
library
.
impl_abstract
(
"graphbolt::expand_indptr"
)
def
expand_indptr_abstract
(
indptr
,
dtype
,
node_ids
,
output_size
):
"""Abstract implementation of expand_indptr for torch.compile() support."""
if
output_size
is
None
:
output_size
=
torch
.
library
.
get_ctx
().
new_dynamic_size
()
if
dtype
is
None
:
dtype
=
node_ids
.
dtype
return
indptr
.
new_empty
(
output_size
,
dtype
=
dtype
)
def
expand_indptr
(
indptr
,
dtype
=
None
,
node_ids
=
None
,
output_size
=
None
):
def
expand_indptr
(
indptr
,
dtype
=
None
,
node_ids
=
None
,
output_size
=
None
):
"""Converts a given indptr offset tensor to a COO format tensor. If
"""Converts a given indptr offset tensor to a COO format tensor. If
node_ids is not given, it is assumed to be equal to
node_ids is not given, it is assumed to be equal to
...
...
tests/python/pytorch/graphbolt/test_base.py
View file @
bc978736
...
@@ -7,6 +7,7 @@ import backend as F
...
@@ -7,6 +7,7 @@ import backend as F
import
dgl.graphbolt
as
gb
import
dgl.graphbolt
as
gb
import
pytest
import
pytest
import
torch
import
torch
from
torch.torch_version
import
TorchVersion
from
.
import
gb_test_utils
from
.
import
gb_test_utils
...
@@ -296,6 +297,32 @@ def test_expand_indptr(nodes, dtype):
...
@@ -296,6 +297,32 @@ def test_expand_indptr(nodes, dtype):
gb_result
=
gb
.
expand_indptr
(
indptr
,
dtype
,
nodes
,
indptr
[
-
1
].
item
())
gb_result
=
gb
.
expand_indptr
(
indptr
,
dtype
,
nodes
,
indptr
[
-
1
].
item
())
assert
torch
.
equal
(
torch_result
,
gb_result
)
assert
torch
.
equal
(
torch_result
,
gb_result
)
if
TorchVersion
(
torch
.
__version__
)
>=
TorchVersion
(
"2.2.0a0"
):
import
torch._dynamo
as
dynamo
from
torch.testing._internal.optests
import
opcheck
# Tests torch.compile compatibility
for
output_size
in
[
None
,
indptr
[
-
1
].
item
()]:
kwargs
=
{
"node_ids"
:
nodes
,
"output_size"
:
output_size
}
opcheck
(
torch
.
ops
.
graphbolt
.
expand_indptr
,
(
indptr
,
dtype
),
kwargs
,
test_utils
=
[
"test_schema"
,
"test_autograd_registration"
,
"test_faketensor"
,
"test_aot_dispatch_dynamic"
,
],
raise_exception
=
True
,
)
explanation
=
dynamo
.
explain
(
gb
.
expand_indptr
)(
indptr
,
dtype
,
nodes
,
output_size
)
expected_breaks
=
-
1
if
output_size
is
None
else
0
assert
explanation
.
graph_break_count
==
expected_breaks
def
test_csc_format_base_representation
():
def
test_csc_format_base_representation
():
csc_format_base
=
gb
.
CSCFormatBase
(
csc_format_base
=
gb
.
CSCFormatBase
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment