Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
2c03fe99
"git@developer.sourcefind.cn:OpenDAS/fairscale.git" did not exist on "99b30a049ecfd98e37244672b1d8db0774ecb9b4"
Unverified
Commit
2c03fe99
authored
Jun 28, 2023
by
peizhou001
Committed by
GitHub
Jun 28, 2023
Browse files
[Graphbolt] Dispatch edge ids in neighbor sampling (#5889)
parent
2489f579
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
34 additions
and
29 deletions
+34
-29
graphbolt/src/csc_sampling_graph.cc
graphbolt/src/csc_sampling_graph.cc
+34
-29
No files found.
graphbolt/src/csc_sampling_graph.cc
View file @
2c03fe99
...
...
@@ -141,37 +141,42 @@ c10::intrusive_ptr<SampledSubgraph> CSCSamplingGraph::SampleNeighbors(
torch
::
Tensor
num_picked_neighbors_per_node
=
torch
::
zeros
({
num_nodes
+
1
},
indptr_
.
options
());
torch
::
parallel_for
(
0
,
num_nodes
,
32
,
[
&
](
size_t
b
,
size_t
e
)
{
for
(
size_t
i
=
b
;
i
<
e
;
++
i
)
{
const
auto
nid
=
nodes
[
i
].
item
<
int64_t
>
();
TORCH_CHECK
(
nid
>=
0
&&
nid
<
NumNodes
(),
"The seed nodes' IDs should fall within the range of the graph's "
"node IDs."
);
const
auto
offset
=
indptr_
[
nid
].
item
<
int64_t
>
();
const
auto
num_neighbors
=
indptr_
[
nid
+
1
].
item
<
int64_t
>
()
-
offset
;
AT_DISPATCH_INTEGRAL_TYPES
(
indptr_
.
scalar_type
(),
"parallel_for"
,
([
&
]
{
torch
::
parallel_for
(
0
,
num_nodes
,
32
,
[
&
](
scalar_t
b
,
scalar_t
e
)
{
const
scalar_t
*
indptr_data
=
indptr_
.
data_ptr
<
scalar_t
>
();
for
(
scalar_t
i
=
b
;
i
<
e
;
++
i
)
{
const
auto
nid
=
nodes
[
i
].
item
<
int64_t
>
();
TORCH_CHECK
(
nid
>=
0
&&
nid
<
NumNodes
(),
"The seed nodes' IDs should fall within the range of the "
"graph's node IDs."
);
const
auto
offset
=
indptr_data
[
nid
];
const
auto
num_neighbors
=
indptr_data
[
nid
+
1
]
-
offset
;
if
(
num_neighbors
==
0
)
{
// Initialization is performed here because all tensors will be
// concatenated in the master thread, and having an undefined tensor
// during concatenation can result in a crash.
picked_neighbors_per_node
[
i
]
=
torch
::
tensor
({},
indptr_
.
options
());
continue
;
}
if
(
num_neighbors
==
0
)
{
// Initialization is performed here because all tensors will be
// concatenated in the master thread, and having an undefined
// tensor during concatenation can result in a crash.
picked_neighbors_per_node
[
i
]
=
torch
::
tensor
({},
indptr_
.
options
());
continue
;
}
if
(
consider_etype
)
{
picked_neighbors_per_node
[
i
]
=
PickByEtype
(
offset
,
num_neighbors
,
fanouts
,
replace
,
indptr_
.
options
(),
type_per_edge_
.
value
(),
probs_or_mask
);
}
else
{
picked_neighbors_per_node
[
i
]
=
Pick
(
offset
,
num_neighbors
,
fanouts
[
0
],
replace
,
indptr_
.
options
(),
probs_or_mask
);
}
num_picked_neighbors_per_node
[
i
+
1
]
=
picked_neighbors_per_node
[
i
].
size
(
0
);
}
});
// End of the thread.
if
(
consider_etype
)
{
picked_neighbors_per_node
[
i
]
=
PickByEtype
(
offset
,
num_neighbors
,
fanouts
,
replace
,
indptr_
.
options
(),
type_per_edge_
.
value
(),
probs_or_mask
);
}
else
{
picked_neighbors_per_node
[
i
]
=
Pick
(
offset
,
num_neighbors
,
fanouts
[
0
],
replace
,
indptr_
.
options
(),
probs_or_mask
);
}
num_picked_neighbors_per_node
[
i
+
1
]
=
picked_neighbors_per_node
[
i
].
size
(
0
);
}
});
// End of the thread.
}));
torch
::
Tensor
subgraph_indptr
=
torch
::
cumsum
(
num_picked_neighbors_per_node
,
0
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment