Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
eb434893
Unverified
Commit
eb434893
authored
Nov 10, 2023
by
Rhett Ying
Committed by
GitHub
Nov 10, 2023
Browse files
[GraphBolt] fix incorrect indptr of in_subgraph (#6555)
parent
b35757a0
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
37 additions
and
33 deletions
+37
-33
graphbolt/src/fused_csc_sampling_graph.cc
graphbolt/src/fused_csc_sampling_graph.cc
+25
-22
python/dgl/graphbolt/impl/fused_csc_sampling_graph.py
python/dgl/graphbolt/impl/fused_csc_sampling_graph.py
+2
-1
tests/python/pytorch/graphbolt/impl/test_fused_csc_sampling_graph.py
...n/pytorch/graphbolt/impl/test_fused_csc_sampling_graph.py
+10
-10
No files found.
graphbolt/src/fused_csc_sampling_graph.cc
View file @
eb434893
...
...
@@ -186,33 +186,36 @@ c10::intrusive_ptr<FusedSampledSubgraph> FusedCSCSamplingGraph::InSubgraph(
const
torch
::
Tensor
&
nodes
)
const
{
using
namespace
torch
::
indexing
;
const
int32_t
kDefaultGrainSize
=
100
;
torch
::
Tensor
indptr
=
torch
::
zeros_like
(
indptr_
);
const
size_t
num_seeds
=
nodes
.
size
(
0
);
const
auto
num_seeds
=
nodes
.
size
(
0
);
torch
::
Tensor
indptr
=
torch
::
zeros
({
num_seeds
+
1
},
indptr_
.
dtype
(
)
);
std
::
vector
<
torch
::
Tensor
>
indices_arr
(
num_seeds
);
torch
::
Tensor
original_column_node_ids
=
torch
::
zeros
({
num_seeds
},
indptr_
.
dtype
());
std
::
vector
<
torch
::
Tensor
>
edge_ids_arr
(
num_seeds
);
std
::
vector
<
torch
::
Tensor
>
type_per_edge_arr
(
num_seeds
);
torch
::
parallel_for
(
0
,
num_seeds
,
kDefaultGrainSize
,
[
&
](
size_t
start
,
size_t
end
)
{
for
(
size_t
i
=
start
;
i
<
end
;
++
i
)
{
const
int64_t
node_id
=
nodes
[
i
].
item
<
int64_t
>
();
const
int64_t
start_idx
=
indptr_
[
node_id
].
item
<
int64_t
>
();
const
int64_t
end_idx
=
indptr_
[
node_id
+
1
].
item
<
int64_t
>
();
indptr
[
node_id
+
1
]
=
end_idx
-
start_idx
;
indices_arr
[
i
]
=
indices_
.
slice
(
0
,
start_idx
,
end_idx
);
edge_ids_arr
[
i
]
=
torch
::
arange
(
start_idx
,
end_idx
);
if
(
type_per_edge_
)
{
type_per_edge_arr
[
i
]
=
type_per_edge_
.
value
().
slice
(
0
,
start_idx
,
end_idx
);
}
}
});
const
auto
&
nonzero_idx
=
torch
::
nonzero
(
indptr
).
reshape
(
-
1
);
torch
::
Tensor
compact_indptr
=
torch
::
zeros
({
nonzero_idx
.
size
(
0
)
+
1
},
indptr_
.
dtype
());
compact_indptr
.
index_put_
({
Slice
(
1
,
None
)},
indptr
.
index
({
nonzero_idx
}));
AT_DISPATCH_INTEGRAL_TYPES
(
indptr_
.
scalar_type
(),
"InSubgraph"
,
([
&
]
{
torch
::
parallel_for
(
0
,
num_seeds
,
kDefaultGrainSize
,
[
&
](
size_t
start
,
size_t
end
)
{
for
(
size_t
i
=
start
;
i
<
end
;
++
i
)
{
const
auto
node_id
=
nodes
[
i
].
item
<
scalar_t
>
();
const
auto
start_idx
=
indptr_
[
node_id
].
item
<
scalar_t
>
();
const
auto
end_idx
=
indptr_
[
node_id
+
1
].
item
<
scalar_t
>
();
indptr
[
i
+
1
]
=
end_idx
-
start_idx
;
original_column_node_ids
[
i
]
=
node_id
;
indices_arr
[
i
]
=
indices_
.
slice
(
0
,
start_idx
,
end_idx
);
edge_ids_arr
[
i
]
=
torch
::
arange
(
start_idx
,
end_idx
);
if
(
type_per_edge_
)
{
type_per_edge_arr
[
i
]
=
type_per_edge_
.
value
().
slice
(
0
,
start_idx
,
end_idx
);
}
}
});
}));
return
c10
::
make_intrusive
<
FusedSampledSubgraph
>
(
compact_indptr
.
cumsum
(
0
),
torch
::
cat
(
indices_arr
),
nonzero_idx
-
1
,
indptr
.
cumsum
(
0
),
torch
::
cat
(
indices_arr
),
original_column_node_ids
,
torch
::
arange
(
0
,
NumNodes
()),
torch
::
cat
(
edge_ids_arr
),
type_per_edge_
?
torch
::
optional
<
torch
::
Tensor
>
{
torch
::
cat
(
type_per_edge_arr
)}
...
...
python/dgl/graphbolt/impl/fused_csc_sampling_graph.py
View file @
eb434893
...
...
@@ -289,7 +289,8 @@ class FusedCSCSamplingGraph(SamplingGraph):
"""Return the subgraph induced on the inbound edges of the given nodes.
An in subgraph is equivalent to creating a new graph using the incoming
edges of the given nodes.
edges of the given nodes. Subgraph is compacted according to the order
of passed-in `nodes`.
Parameters
----------
...
...
tests/python/pytorch/graphbolt/impl/test_fused_csc_sampling_graph.py
View file @
eb434893
...
...
@@ -497,20 +497,20 @@ def test_in_subgraph_homogeneous():
graph
=
gb
.
from_fused_csc
(
indptr
,
indices
)
# Extract in subgraph.
nodes
=
torch
.
LongTensor
([
1
,
3
,
4
])
nodes
=
torch
.
LongTensor
([
4
,
1
,
3
])
in_subgraph
=
graph
.
in_subgraph
(
nodes
)
# Verify in subgraph.
assert
torch
.
equal
(
in_subgraph
.
node_pairs
[
0
],
torch
.
LongTensor
([
2
,
3
,
1
,
2
,
0
,
3
,
4
])
in_subgraph
.
node_pairs
[
0
],
torch
.
LongTensor
([
0
,
3
,
4
,
2
,
3
,
1
,
2
])
)
assert
torch
.
equal
(
in_subgraph
.
node_pairs
[
1
],
torch
.
LongTensor
([
1
,
1
,
3
,
3
,
4
,
4
,
4
])
in_subgraph
.
node_pairs
[
1
],
torch
.
LongTensor
([
4
,
4
,
4
,
1
,
1
,
3
,
3
])
)
assert
in_subgraph
.
original_column_node_ids
is
None
assert
in_subgraph
.
original_row_node_ids
is
None
assert
torch
.
equal
(
in_subgraph
.
original_edge_ids
,
torch
.
LongTensor
([
3
,
4
,
7
,
8
,
9
,
10
,
11
])
in_subgraph
.
original_edge_ids
,
torch
.
LongTensor
([
9
,
10
,
11
,
3
,
4
,
7
,
8
])
)
...
...
@@ -564,7 +564,7 @@ def test_in_subgraph_heterogeneous():
# Extract in subgraph.
nodes
=
{
"N0"
:
torch
.
LongTensor
([
1
]),
"N1"
:
torch
.
LongTensor
([
1
,
2
]),
"N1"
:
torch
.
LongTensor
([
2
,
1
]),
}
in_subgraph
=
graph
.
in_subgraph
(
nodes
)
...
...
@@ -576,10 +576,10 @@ def test_in_subgraph_heterogeneous():
in_subgraph
.
node_pairs
[
"N0:R0:N0"
][
1
],
torch
.
LongTensor
([])
)
assert
torch
.
equal
(
in_subgraph
.
node_pairs
[
"N0:R1:N1"
][
0
],
torch
.
LongTensor
([
1
,
0
])
in_subgraph
.
node_pairs
[
"N0:R1:N1"
][
0
],
torch
.
LongTensor
([
0
,
1
])
)
assert
torch
.
equal
(
in_subgraph
.
node_pairs
[
"N0:R1:N1"
][
1
],
torch
.
LongTensor
([
1
,
2
])
in_subgraph
.
node_pairs
[
"N0:R1:N1"
][
1
],
torch
.
LongTensor
([
2
,
1
])
)
assert
torch
.
equal
(
in_subgraph
.
node_pairs
[
"N1:R2:N0"
][
0
],
torch
.
LongTensor
([
0
,
1
])
...
...
@@ -588,15 +588,15 @@ def test_in_subgraph_heterogeneous():
in_subgraph
.
node_pairs
[
"N1:R2:N0"
][
1
],
torch
.
LongTensor
([
1
,
1
])
)
assert
torch
.
equal
(
in_subgraph
.
node_pairs
[
"N1:R3:N1"
][
0
],
torch
.
LongTensor
([
0
,
1
,
2
])
in_subgraph
.
node_pairs
[
"N1:R3:N1"
][
0
],
torch
.
LongTensor
([
1
,
2
,
0
])
)
assert
torch
.
equal
(
in_subgraph
.
node_pairs
[
"N1:R3:N1"
][
1
],
torch
.
LongTensor
([
1
,
2
,
2
])
in_subgraph
.
node_pairs
[
"N1:R3:N1"
][
1
],
torch
.
LongTensor
([
2
,
2
,
1
])
)
assert
in_subgraph
.
original_column_node_ids
is
None
assert
in_subgraph
.
original_row_node_ids
is
None
assert
torch
.
equal
(
in_subgraph
.
original_edge_ids
,
torch
.
LongTensor
([
3
,
4
,
7
,
8
,
9
,
10
,
11
])
in_subgraph
.
original_edge_ids
,
torch
.
LongTensor
([
3
,
4
,
9
,
10
,
11
,
7
,
8
])
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment