Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
6b140f28
"...source/git@developer.sourcefind.cn:OpenDAS/torchaudio.git" did not exist on "91b05e2ec78e44856d90f4258f91d56807227bac"
Unverified
Commit
6b140f28
authored
Apr 29, 2024
by
Muhammed Fatih BALIN
Committed by
GitHub
Apr 29, 2024
Browse files
[GraphBolt] Hetero CPU sampling bug fix. (#7369)
parent
0d9a09df
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
20 additions
and
36 deletions
+20
-36
graphbolt/src/fused_csc_sampling_graph.cc
graphbolt/src/fused_csc_sampling_graph.cc
+20
-36
No files found.
graphbolt/src/fused_csc_sampling_graph.cc
View file @
6b140f28
...
@@ -557,7 +557,8 @@ FusedCSCSamplingGraph::SampleNeighborsImpl(
...
@@ -557,7 +557,8 @@ FusedCSCSamplingGraph::SampleNeighborsImpl(
// it equals to `num_seeds`.
// it equals to `num_seeds`.
const
int64_t
num_rows
=
etype_id_to_num_picked_offset
[
num_etypes
];
const
int64_t
num_rows
=
etype_id_to_num_picked_offset
[
num_etypes
];
torch
::
Tensor
num_picked_neighbors_per_node
=
torch
::
Tensor
num_picked_neighbors_per_node
=
torch
::
empty
({
num_rows
},
indptr_options
);
// Need to use zeros because all nodes don't have all etypes.
torch
::
zeros
({
num_rows
},
indptr_options
);
AT_DISPATCH_INDEX_TYPES
(
AT_DISPATCH_INDEX_TYPES
(
indptr_
.
scalar_type
(),
"SampleNeighborsImplWrappedWithIndptr"
,
([
&
]
{
indptr_
.
scalar_type
(),
"SampleNeighborsImplWrappedWithIndptr"
,
([
&
]
{
...
@@ -571,14 +572,6 @@ FusedCSCSamplingGraph::SampleNeighborsImpl(
...
@@ -571,14 +572,6 @@ FusedCSCSamplingGraph::SampleNeighborsImpl(
num_picked_neighbors_data_ptr
[
0
]
=
0
;
num_picked_neighbors_data_ptr
[
0
]
=
0
;
const
auto
seeds_data_ptr
=
seeds
.
data_ptr
<
seeds_t
>
();
const
auto
seeds_data_ptr
=
seeds
.
data_ptr
<
seeds_t
>
();
// Initialize the empty spots in `num_picked_neighbors_per_node`.
if
(
hetero_with_seed_offsets
)
{
for
(
auto
i
=
0
;
i
<
num_etypes
;
++
i
)
{
num_picked_neighbors_data_ptr
[
etype_id_to_num_picked_offset
[
i
]]
=
0
;
}
}
// Step 1. Calculate pick number of each node.
// Step 1. Calculate pick number of each node.
torch
::
parallel_for
(
torch
::
parallel_for
(
0
,
num_seeds
,
grain_size
,
[
&
](
int64_t
begin
,
int64_t
end
)
{
0
,
num_seeds
,
grain_size
,
[
&
](
int64_t
begin
,
int64_t
end
)
{
...
@@ -612,40 +605,36 @@ FusedCSCSamplingGraph::SampleNeighborsImpl(
...
@@ -612,40 +605,36 @@ FusedCSCSamplingGraph::SampleNeighborsImpl(
}
}
});
});
// Step 2. Calculate prefix sum to get total length and offsets of
// each node. It's also the indptr of the generated subgraph.
subgraph_indptr
=
num_picked_neighbors_per_node
.
cumsum
(
0
,
indptr_
.
scalar_type
());
auto
subgraph_indptr_data_ptr
=
subgraph_indptr
.
data_ptr
<
indptr_t
>
();
if
(
hetero_with_seed_offsets
)
{
if
(
hetero_with_seed_offsets
)
{
torch
::
Tensor
num_picked_offset_tensor
=
torch
::
Tensor
num_picked_offset_tensor
=
torch
::
zeros
({
num_etypes
+
1
},
indptr_options
);
torch
::
empty
({
num_etypes
+
1
},
indptr_options
);
const
auto
num_picked_offset_data_ptr
=
num_picked_offset_tensor
.
data_ptr
<
indptr_t
>
();
std
::
copy
(
etype_id_to_num_picked_offset
.
begin
(),
etype_id_to_num_picked_offset
.
end
(),
num_picked_offset_data_ptr
);
torch
::
Tensor
substract_offset
=
torch
::
Tensor
substract_offset
=
torch
::
zeros
({
num_etypes
},
indptr_options
);
torch
::
empty
({
num_etypes
},
indptr_options
);
const
auto
substract_offset_data_ptr
=
const
auto
substract_offset_data_ptr
=
substract_offset
.
data_ptr
<
indptr_t
>
();
substract_offset
.
data_ptr
<
indptr_t
>
();
const
auto
num_picked_offset_data_ptr
=
num_picked_offset_tensor
.
data_ptr
<
indptr_t
>
();
for
(
auto
i
=
0
;
i
<
num_etypes
;
++
i
)
{
for
(
auto
i
=
0
;
i
<
num_etypes
;
++
i
)
{
num_picked_offset_data_ptr
[
i
+
1
]
=
// Collect the total pick number subtract offsets.
etype_id_to_num_picked_offset
[
i
+
1
];
substract_offset_data_ptr
[
i
]
=
subgraph_indptr_data_ptr
// Collect the total pick number for each edge type.
[
etype_id_to_num_picked_offset
[
i
]];
if
(
i
+
1
<
num_etypes
)
substract_offset_data_ptr
[
i
+
1
]
=
num_picked_neighbors_data_ptr
[
etype_id_to_num_picked_offset
[
i
]];
num_picked_neighbors_data_ptr
[
etype_id_to_num_picked_offset
[
i
]]
=
0
;
}
}
substract_offset
=
substract_offset
.
cumsum
(
0
,
indptr_
.
scalar_type
());
subgraph_indptr_substract
=
ops
::
ExpandIndptr
(
subgraph_indptr_substract
=
ops
::
ExpandIndptr
(
num_picked_offset_tensor
,
indptr_
.
scalar_type
(),
num_picked_offset_tensor
,
indptr_
.
scalar_type
(),
substract_offset
);
substract_offset
);
}
}
// Step 2. Calculate prefix sum to get total length and offsets of
// each node. It's also the indptr of the generated subgraph.
subgraph_indptr
=
num_picked_neighbors_per_node
.
cumsum
(
0
,
indptr_
.
scalar_type
());
auto
subgraph_indptr_data_ptr
=
subgraph_indptr
.
data_ptr
<
indptr_t
>
();
// When doing non-temporal hetero sampling, we generate an
// When doing non-temporal hetero sampling, we generate an
// edge_offsets tensor.
// edge_offsets tensor.
if
(
hetero_with_seed_offsets
)
{
if
(
hetero_with_seed_offsets
)
{
...
@@ -1277,11 +1266,6 @@ void NumPickByEtype(
...
@@ -1277,11 +1266,6 @@ void NumPickByEtype(
NumPick
(
NumPick
(
fanouts
[
etype
],
replace
,
probs_or_mask
,
etype_begin
,
fanouts
[
etype
],
replace
,
probs_or_mask
,
etype_begin
,
etype_end
-
etype_begin
,
num_picked_ptr
+
offset
);
etype_end
-
etype_begin
,
num_picked_ptr
+
offset
);
// Use the skipped position of each edge type in the
// num_picked_tensor to sum up the total pick number for each edge
// type.
num_picked_ptr
[
etype_id_to_num_picked_offset
[
etype
]
-
1
]
+=
num_picked_ptr
[
offset
];
}
else
{
}
else
{
PickedNumType
picked_count
=
0
;
PickedNumType
picked_count
=
0
;
NumPick
(
NumPick
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment