Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
d2eca855
Unverified
Commit
d2eca855
authored
Oct 19, 2023
by
peizhou001
Committed by
GitHub
Oct 19, 2023
Browse files
[Graphbolt] Speed up exclude edges (#6464)
parent
72b3e078
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
160 additions
and
13 deletions
+160
-13
graphbolt/include/graphbolt/isin.h
graphbolt/include/graphbolt/isin.h
+36
-0
graphbolt/src/isin.cc
graphbolt/src/isin.cc
+45
-0
graphbolt/src/python_binding.cc
graphbolt/src/python_binding.cc
+2
-0
python/dgl/graphbolt/base.py
python/dgl/graphbolt/base.py
+24
-0
python/dgl/graphbolt/sampled_subgraph.py
python/dgl/graphbolt/sampled_subgraph.py
+26
-13
tests/python/pytorch/graphbolt/test_base.py
tests/python/pytorch/graphbolt/test_base.py
+27
-0
No files found.
graphbolt/include/graphbolt/isin.h
0 → 100644
View file @
d2eca855
/**
 * Copyright (c) 2023 by Contributors
 *
 * @file graphbolt/isin.h
 * @brief isin op.
 */
#ifndef GRAPHBOLT_ISIN_H_
#define GRAPHBOLT_ISIN_H_

#include <torch/torch.h>

namespace graphbolt {
namespace sampling {

/**
 * @brief Tests if each element of elements is in test_elements. Returns a
 * boolean tensor of the same shape as elements that is True for elements
 * in test_elements and False otherwise. Enhances torch.isin by implementing
 * multi-threaded searching, as detailed in the documentation at
 * https://pytorch.org/docs/stable/generated/torch.isin.html
 *
 * @param elements Input elements.
 * @param test_elements Values against which to test for each input element.
 *
 * @return
 * A boolean tensor of the same shape as elements that is True for elements
 * in test_elements and False otherwise.
 */
torch::Tensor IsIn(
    const torch::Tensor& elements, const torch::Tensor& test_elements);

}  // namespace sampling
}  // namespace graphbolt

#endif  // GRAPHBOLT_ISIN_H_
graphbolt/src/isin.cc
0 → 100644
View file @
d2eca855
/**
 * Copyright (c) 2023 by Contributors
 *
 * @file isin.cc
 * @brief Isin op.
 */
#include <graphbolt/isin.h>

#include <algorithm>
#include <tuple>

namespace {
// Minimum number of elements each parallel_for task processes; keeps the
// per-task scheduling overhead amortized over a meaningful chunk of work.
constexpr int kSearchGrainSize = 4096;
}  // namespace

namespace graphbolt {
namespace sampling {

torch::Tensor IsIn(
    const torch::Tensor& elements, const torch::Tensor& test_elements) {
  // The raw-pointer search below reads both buffers through the same
  // scalar_t, so mismatched dtypes would silently reinterpret memory.
  TORCH_CHECK(
      elements.scalar_type() == test_elements.scalar_type(),
      "elements and test_elements should have the same dtype.");
  // Linear indexing through data_ptr() is only valid on contiguous storage;
  // contiguous() is a no-op when the input already is.
  const torch::Tensor elements_contig = elements.contiguous();
  // Sort the test set once so every lookup is an O(log n) binary search.
  torch::Tensor sorted_test_elements;
  std::tie(sorted_test_elements, std::ignore) = test_elements.sort(
      /*stable=*/false, /*dim=*/0, /*descending=*/false);
  torch::Tensor result = torch::empty_like(elements_contig, torch::kBool);
  // Use numel() rather than size(0): the documented contract is a result of
  // the same shape as `elements`, regardless of its rank.
  const size_t num_test_elements = sorted_test_elements.numel();
  const size_t num_elements = elements_contig.numel();
  AT_DISPATCH_INTEGRAL_TYPES(
      elements_contig.scalar_type(), "IsInOperation", ([&] {
        const scalar_t* elements_ptr = elements_contig.data_ptr<scalar_t>();
        const scalar_t* sorted_test_elements_ptr =
            sorted_test_elements.data_ptr<scalar_t>();
        bool* result_ptr = result.data_ptr<bool>();
        // Search chunks of `elements` concurrently; each task writes a
        // disjoint slice of `result`, so no synchronization is needed.
        torch::parallel_for(
            0, num_elements, kSearchGrainSize, [&](size_t start, size_t end) {
              for (auto i = start; i < end; i++) {
                result_ptr[i] = std::binary_search(
                    sorted_test_elements_ptr,
                    sorted_test_elements_ptr + num_test_elements,
                    elements_ptr[i]);
              }
            });
      }));
  return result;
}

}  // namespace sampling
}  // namespace graphbolt
graphbolt/src/python_binding.cc
View file @
d2eca855
...
@@ -5,6 +5,7 @@
...
@@ -5,6 +5,7 @@
*/
*/
#include <graphbolt/csc_sampling_graph.h>
#include <graphbolt/csc_sampling_graph.h>
#include <graphbolt/isin.h>
#include <graphbolt/serialize.h>
#include <graphbolt/serialize.h>
#include <graphbolt/unique_and_compact.h>
#include <graphbolt/unique_and_compact.h>
...
@@ -56,6 +57,7 @@ TORCH_LIBRARY(graphbolt, m) {
...
@@ -56,6 +57,7 @@ TORCH_LIBRARY(graphbolt, m) {
m
.
def
(
"save_csc_sampling_graph"
,
&
SaveCSCSamplingGraph
);
m
.
def
(
"save_csc_sampling_graph"
,
&
SaveCSCSamplingGraph
);
m
.
def
(
"load_from_shared_memory"
,
&
CSCSamplingGraph
::
LoadFromSharedMemory
);
m
.
def
(
"load_from_shared_memory"
,
&
CSCSamplingGraph
::
LoadFromSharedMemory
);
m
.
def
(
"unique_and_compact"
,
&
UniqueAndCompact
);
m
.
def
(
"unique_and_compact"
,
&
UniqueAndCompact
);
m
.
def
(
"isin"
,
&
IsIn
);
}
}
}
// namespace sampling
}
// namespace sampling
...
...
python/dgl/graphbolt/base.py
View file @
d2eca855
"""Base types and utilities for Graph Bolt."""
"""Base types and utilities for Graph Bolt."""
import
torch
from
torch.utils.data
import
functional_datapipe
from
torch.utils.data
import
functional_datapipe
from
torchdata.datapipes.iter
import
IterDataPipe
from
torchdata.datapipes.iter
import
IterDataPipe
...
@@ -11,12 +12,35 @@ __all__ = [
...
@@ -11,12 +12,35 @@ __all__ = [
"etype_str_to_tuple"
,
"etype_str_to_tuple"
,
"etype_tuple_to_str"
,
"etype_tuple_to_str"
,
"CopyTo"
,
"CopyTo"
,
"isin"
,
]
]
CANONICAL_ETYPE_DELIMITER
=
":"
CANONICAL_ETYPE_DELIMITER
=
":"
ORIGINAL_EDGE_ID
=
"_ORIGINAL_EDGE_ID"
ORIGINAL_EDGE_ID
=
"_ORIGINAL_EDGE_ID"
def isin(elements, test_elements):
    """Tests if each element of elements is in test_elements. Returns a boolean
    tensor of the same shape as elements that is True for elements in
    test_elements and False otherwise.

    Parameters
    ----------
    elements : torch.Tensor
        A 1D tensor represents the input elements.
    test_elements : torch.Tensor
        A 1D tensor represents the values to test against for each input.

    Returns
    -------
    torch.Tensor
        A 1D boolean tensor of the same shape as `elements`.

    Examples
    --------
    >>> isin(torch.tensor([1, 2, 3, 4]), torch.tensor([2, 3]))
    tensor([False,  True,  True, False])
    """
    # Both operands must be 1D; the C++ kernel iterates them linearly.
    assert elements.dim() == 1, "Elements should be 1D tensor."
    assert test_elements.dim() == 1, "Test_elements should be 1D tensor."
    # Dispatch to the multi-threaded graphbolt C++ implementation.
    return torch.ops.graphbolt.isin(elements, test_elements)
def
etype_tuple_to_str
(
c_etype
):
def
etype_tuple_to_str
(
c_etype
):
"""Convert canonical etype from tuple to string.
"""Convert canonical etype from tuple to string.
...
...
python/dgl/graphbolt/sampled_subgraph.py
View file @
d2eca855
...
@@ -4,7 +4,7 @@ from typing import Dict, Tuple, Union
...
@@ -4,7 +4,7 @@ from typing import Dict, Tuple, Union
import
torch
import
torch
from
.base
import
etype_str_to_tuple
from
.base
import
etype_str_to_tuple
,
isin
__all__
=
[
"SampledSubgraph"
]
__all__
=
[
"SampledSubgraph"
]
...
@@ -85,6 +85,7 @@ class SampledSubgraph:
...
@@ -85,6 +85,7 @@ class SampledSubgraph:
Dict
[
str
,
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]],
Dict
[
str
,
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]],
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
],
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
],
],
],
assume_num_node_within_int32
:
bool
=
True
,
):
):
r
"""Exclude edges from the sampled subgraph.
r
"""Exclude edges from the sampled subgraph.
...
@@ -103,6 +104,10 @@ class SampledSubgraph:
...
@@ -103,6 +104,10 @@ class SampledSubgraph:
should be a pair of tensors representing the edges to exclude. If
should be a pair of tensors representing the edges to exclude. If
sampled subgraph is heterogeneous, then `edges` should be a dictionary
sampled subgraph is heterogeneous, then `edges` should be a dictionary
of edge types and the corresponding edges to exclude.
of edge types and the corresponding edges to exclude.
assume_num_node_within_int32: bool
If True, assumes the value of node IDs in the provided `edges` fall
within the int32 range, which can significantly enhance computation
speed. Default: True
Returns
Returns
-------
-------
...
@@ -133,6 +138,10 @@ class SampledSubgraph:
...
@@ -133,6 +138,10 @@ class SampledSubgraph:
>>> print(result.original_edge_ids)
>>> print(result.original_edge_ids)
{"A:relation:B": tensor([19])}
{"A:relation:B": tensor([19])}
"""
"""
# TODO: Add support for value > in32, then remove this line.
assert
(
assume_num_node_within_int32
),
"Values > int32 are not supported yet."
assert
isinstance
(
self
.
node_pairs
,
tuple
)
==
isinstance
(
edges
,
tuple
),
(
assert
isinstance
(
self
.
node_pairs
,
tuple
)
==
isinstance
(
edges
,
tuple
),
(
"The sampled subgraph and the edges to exclude should be both "
"The sampled subgraph and the edges to exclude should be both "
"homogeneous or both heterogeneous."
"homogeneous or both heterogeneous."
...
@@ -150,7 +159,9 @@ class SampledSubgraph:
...
@@ -150,7 +159,9 @@ class SampledSubgraph:
self
.
original_row_node_ids
,
self
.
original_row_node_ids
,
self
.
original_column_node_ids
,
self
.
original_column_node_ids
,
)
)
index
=
_exclude_homo_edges
(
reverse_edges
,
edges
)
index
=
_exclude_homo_edges
(
reverse_edges
,
edges
,
assume_num_node_within_int32
)
return
calling_class
(
*
_slice_subgraph
(
self
,
index
))
return
calling_class
(
*
_slice_subgraph
(
self
,
index
))
else
:
else
:
index
=
{}
index
=
{}
...
@@ -172,7 +183,9 @@ class SampledSubgraph:
...
@@ -172,7 +183,9 @@ class SampledSubgraph:
original_column_node_ids
,
original_column_node_ids
,
)
)
index
[
etype
]
=
_exclude_homo_edges
(
index
[
etype
]
=
_exclude_homo_edges
(
reverse_edges
,
edges
.
get
(
etype
)
reverse_edges
,
edges
.
get
(
etype
),
assume_num_node_within_int32
,
)
)
return
calling_class
(
*
_slice_subgraph
(
self
,
index
))
return
calling_class
(
*
_slice_subgraph
(
self
,
index
))
...
@@ -193,17 +206,17 @@ def _relabel_two_arrays(lhs_array, rhs_array):
...
@@ -193,17 +206,17 @@ def _relabel_two_arrays(lhs_array, rhs_array):
return
mapping
[:
lhs_array
.
numel
()],
mapping
[
lhs_array
.
numel
()
:]
return
mapping
[:
lhs_array
.
numel
()],
mapping
[
lhs_array
.
numel
()
:]
def _exclude_homo_edges(edges, edges_to_exclude, assume_num_node_within_int32):
    """Return the indices of edges that are not in edges_to_exclude."""
    if assume_num_node_within_int32:
        # Pack each (src, dst) pair into a single int64 key — src in the
        # high 32 bits, dst in the low 32 — so membership can be tested
        # with one call on scalar keys instead of pairwise comparison.
        val = edges[0] << 32 | edges[1]
        val_to_exclude = edges_to_exclude[0] << 32 | edges_to_exclude[1]
    else:
        # TODO: Add support for value > int32.
        raise NotImplementedError(
            "Values out of range int32 are not supported yet"
        )
    # Keep exactly the edges whose packed key is absent from the
    # exclusion set; `isin` is the multi-threaded graphbolt kernel.
    mask = ~isin(val, val_to_exclude)
    return torch.nonzero(mask, as_tuple=True)[0]
...
...
tests/python/pytorch/graphbolt/test_base.py
View file @
d2eca855
...
@@ -123,3 +123,30 @@ def test_etype_str_to_tuple():
...
@@ -123,3 +123,30 @@ def test_etype_str_to_tuple():
),
),
):
):
_
=
gb
.
etype_str_to_tuple
(
c_etype_str
)
_
=
gb
.
etype_str_to_tuple
(
c_etype_str
)
def test_isin():
    # A small input with repeats; only 2 and 5 belong to the test set.
    elements = torch.tensor([2, 3, 5, 5, 20, 13, 11])
    test_elements = torch.tensor([2, 5])
    res = gb.isin(elements, test_elements)
    expected = torch.tensor([True, False, True, True, False, False, False])
    assert torch.equal(res, expected)
def test_isin_big_data():
    # Cross-check the multi-threaded kernel against torch.isin on a
    # large random input so the parallel path is actually exercised.
    elements = torch.randint(0, 10000, (10000000,))
    test_elements = torch.randint(0, 10000, (500000,))
    res = gb.isin(elements, test_elements)
    expected = torch.isin(elements, test_elements)
    assert torch.equal(res, expected)
def test_isin_non_1D_dim():
    # A 2D `elements` tensor must be rejected by the 1D-only contract.
    elements = torch.tensor([[2, 3], [5, 5], [20, 13]])
    test_elements = torch.tensor([2, 5])
    with pytest.raises(Exception):
        gb.isin(elements, test_elements)
    # Likewise, a 2D `test_elements` tensor must be rejected.
    elements = torch.tensor([2, 3, 5, 5, 20, 13])
    test_elements = torch.tensor([[2, 5]])
    with pytest.raises(Exception):
        gb.isin(elements, test_elements)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment