OpenDAS / dgl · Commit 80d16efa (Unverified)

[Graphbolt] Add `cat` optimization for UniformPick (#6030)

Authored Jul 25, 2023 by keli-wen, committed by GitHub on Jul 25, 2023
Parent: 327589c8

Showing 1 changed file with 46 additions and 33 deletions:
graphbolt/src/csc_sampling_graph.cc (+46, -33)
```diff
@@ -141,50 +141,63 @@ c10::intrusive_ptr<SampledSubgraph> CSCSamplingGraph::SampleNeighborsImpl(
   // If true, perform sampling for each edge type of each node, otherwise just
   // sample once for each node with no regard of edge types.
   bool consider_etype = (fanouts.size() > 1);
-  std::vector<torch::Tensor> picked_neighbors_per_node(num_nodes);
+  const int64_t num_threads = torch::get_num_threads();
+  std::vector<torch::Tensor> picked_neighbors_per_thread(num_threads);
   torch::Tensor num_picked_neighbors_per_node =
       torch::zeros({num_nodes + 1}, indptr_.options());
+  // Calculate GrainSize for parallel_for.
+  // Set the default grain size to 64.
+  const int64_t grain_size = 64;
   AT_DISPATCH_INTEGRAL_TYPES(
       indptr_.scalar_type(), "parallel_for", ([&] {
         torch::parallel_for(
-            0, num_nodes, 32, [&](scalar_t b, scalar_t e) {
+            0, num_nodes, grain_size, [&](scalar_t begin, scalar_t end) {
+              const auto indptr_options = indptr_.options();
               const scalar_t* indptr_data = indptr_.data_ptr<scalar_t>();
-              for (scalar_t i = b; i < e; ++i) {
+              // Get current thread id.
+              auto thread_id = torch::get_thread_num();
+              int64_t local_grain_size = end - begin;
+              std::vector<torch::Tensor> picked_neighbors_cur_thread(
+                  local_grain_size);
+              for (scalar_t i = begin; i < end; ++i) {
                 const auto nid = nodes[i].item<int64_t>();
                 TORCH_CHECK(
                     nid >= 0 && nid < NumNodes(),
                     "The seed nodes' IDs should fall within the range of the "
                     "graph's node IDs.");
                 const auto offset = indptr_data[nid];
                 const auto num_neighbors = indptr_data[nid + 1] - offset;
                 if (num_neighbors == 0) {
-                  // To avoid crashing during concatenation in the master thread,
-                  // initializing with empty tensors.
-                  picked_neighbors_per_node[i] =
-                      torch::tensor({}, indptr_.options());
+                  // To avoid crashing during concatenation in the master
+                  // thread, initializing with empty tensors.
+                  picked_neighbors_cur_thread[i - begin] =
+                      torch::tensor({}, indptr_options);
                   continue;
                 }
                 if (consider_etype) {
-                  picked_neighbors_per_node[i] = PickByEtype(
-                      offset, num_neighbors, fanouts, replace, indptr_.options(),
+                  picked_neighbors_cur_thread[i - begin] = PickByEtype(
+                      offset, num_neighbors, fanouts, replace, indptr_options,
                       type_per_edge_.value(), probs_or_mask, args);
                 } else {
-                  picked_neighbors_per_node[i] = Pick(
-                      offset, num_neighbors, fanouts[0], replace, indptr_.options(),
-                      probs_or_mask, args);
+                  picked_neighbors_cur_thread[i - begin] = Pick(
+                      offset, num_neighbors, fanouts[0], replace,
+                      indptr_options, probs_or_mask, args);
                 }
                 num_picked_neighbors_per_node[i + 1] =
-                    picked_neighbors_per_node[i].size(0);
+                    picked_neighbors_cur_thread[i - begin].size(0);
               }
-            });
-      }));
+              picked_neighbors_per_thread[thread_id] =
+                  torch::cat(picked_neighbors_cur_thread);
+            });  // End of the thread.
+      }));  // End of parallel_for.
   torch::Tensor subgraph_indptr =
       torch::cumsum(num_picked_neighbors_per_node, 0);
-  torch::Tensor picked_eids = torch::cat(picked_neighbors_per_node);
+  torch::Tensor picked_eids = torch::cat(picked_neighbors_per_thread);
   torch::Tensor subgraph_indices =
       torch::index_select(indices_, 0, picked_eids);
   torch::optional<torch::Tensor> subgraph_type_per_edge = torch::nullopt;
```
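What the commit changes: previously each node wrote its picks into `picked_neighbors_per_node`, and after the parallel loop the main thread ran a single `torch::cat` over `num_nodes` tensors. Now each worker thread collects its chunk's picks in `picked_neighbors_cur_thread`, concatenates them once inside the thread into `picked_neighbors_per_thread[thread_id]`, and the final `torch::cat` only joins `num_threads` tensors. Per-node counts are still written to `num_picked_neighbors_per_node`, so the `cumsum` that builds `subgraph_indptr` is unchanged. The grain size also moves from a hard-coded 32 to a named constant of 64.

Below is a minimal, self-contained sketch of the same per-thread `cat` pattern on a toy workload. `GatherResults` and the per-item stand-in tensors are hypothetical, not part of the DGL code base, and like the commit it assumes ATen's `parallel_for` hands each thread one contiguous chunk, so each per-thread slot is written at most once.

```cpp
#include <ATen/Parallel.h>
#include <torch/torch.h>

#include <vector>

// Hypothetical toy workload: each item produces a small result tensor
// (standing in for Pick()/PickByEtype()), and all results are gathered.
torch::Tensor GatherResults(int64_t num_items) {
  const int64_t num_threads = torch::get_num_threads();
  // One slot per thread rather than per item. Pre-fill with empty tensors
  // so the final cat() is safe even for threads that receive no chunk.
  std::vector<torch::Tensor> results_per_thread(
      num_threads, torch::empty({0}, torch::kInt64));
  torch::parallel_for(
      0, num_items, /*grain_size=*/64, [&](int64_t begin, int64_t end) {
        const auto thread_id = torch::get_thread_num();
        std::vector<torch::Tensor> local_results(end - begin);
        for (int64_t i = begin; i < end; ++i) {
          // Per-item work; a real sampler would pick neighbors here.
          local_results[i - begin] = torch::full({2}, i, torch::kInt64);
        }
        // Concatenate inside the worker, so the caller only joins
        // num_threads tensors instead of num_items tensors.
        results_per_thread[thread_id] = torch::cat(local_results);
      });
  return torch::cat(results_per_thread);
}
```

Pre-filling the per-thread slots with empty tensors mirrors the diff's empty-tensor trick for zero-neighbor nodes: `torch::cat` throws on an undefined tensor, so slots belonging to threads that never ran must stay concatenable.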