Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
2bc4df22
Unverified
Commit
2bc4df22
authored
Nov 23, 2023
by
Rhett Ying
Committed by
GitHub
Nov 23, 2023
Browse files
[GraphBolt] replace item<> with raw pointer (#6601)
parent
33e80452
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
99 additions
and
83 deletions
+99
-83
graphbolt/src/fused_csc_sampling_graph.cc
graphbolt/src/fused_csc_sampling_graph.cc
+99
-83
No files found.
graphbolt/src/fused_csc_sampling_graph.cc
View file @
2bc4df22
...
...
@@ -330,96 +330,112 @@ FusedCSCSamplingGraph::SampleNeighborsImpl(
torch
::
optional
<
torch
::
Tensor
>
subgraph_type_per_edge
=
torch
::
nullopt
;
AT_DISPATCH_INTEGRAL_TYPES
(
indptr_
.
scalar_type
(),
"SampleNeighborsImpl"
,
([
&
]
{
const
scalar_t
*
indptr_data
=
indptr_
.
data_ptr
<
scalar_t
>
();
auto
num_picked_neighbors_data_ptr
=
num_picked_neighbors_per_node
.
data_ptr
<
scalar_t
>
();
num_picked_neighbors_data_ptr
[
0
]
=
0
;
indptr_
.
scalar_type
(),
"SampleNeighborsImplWrappedWithIndptr"
,
([
&
]
{
using
indptr_t
=
scalar_t
;
AT_DISPATCH_INTEGRAL_TYPES
(
nodes
.
scalar_type
(),
"SampleNeighborsImplWrappedWithNodes"
,
([
&
]
{
using
nodes_t
=
scalar_t
;
const
auto
indptr_data
=
indptr_
.
data_ptr
<
indptr_t
>
();
auto
num_picked_neighbors_data_ptr
=
num_picked_neighbors_per_node
.
data_ptr
<
indptr_t
>
();
num_picked_neighbors_data_ptr
[
0
]
=
0
;
const
auto
nodes_data_ptr
=
nodes
.
data_ptr
<
nodes_t
>
();
// Step 1. Calculate pick number of each node.
torch
::
parallel_for
(
0
,
num_nodes
,
grain_size
,
[
&
](
int64_t
begin
,
int64_t
end
)
{
for
(
int64_t
i
=
begin
;
i
<
end
;
++
i
)
{
const
auto
nid
=
nodes
[
i
].
item
<
int64_t
>
();
TORCH_CHECK
(
nid
>=
0
&&
nid
<
NumNodes
(),
"The seed nodes' IDs should fall within the range of the "
"graph's node IDs."
);
const
auto
offset
=
indptr_data
[
nid
];
const
auto
num_neighbors
=
indptr_data
[
nid
+
1
]
-
offset
;
// Step 1. Calculate pick number of each node.
torch
::
parallel_for
(
0
,
num_nodes
,
grain_size
,
[
&
](
int64_t
begin
,
int64_t
end
)
{
for
(
int64_t
i
=
begin
;
i
<
end
;
++
i
)
{
const
auto
nid
=
nodes_data_ptr
[
i
];
TORCH_CHECK
(
nid
>=
0
&&
nid
<
NumNodes
(),
"The seed nodes' IDs should fall within the range of "
"the "
"graph's node IDs."
);
const
auto
offset
=
indptr_data
[
nid
];
const
auto
num_neighbors
=
indptr_data
[
nid
+
1
]
-
offset
;
num_picked_neighbors_data_ptr
[
i
+
1
]
=
num_neighbors
==
0
?
0
:
num_pick_fn
(
offset
,
num_neighbors
);
}
});
num_picked_neighbors_data_ptr
[
i
+
1
]
=
num_neighbors
==
0
?
0
:
num_pick_fn
(
offset
,
num_neighbors
);
}
});
// Step 2. Calculate prefix sum to get total length and offsets of
each
//
node. It's also the indptr of the generated subgraph.
subgraph_indptr
=
num_picked_neighbors_per_node
.
cumsum
(
0
,
indptr_
.
scalar_type
());
// Step 2. Calculate prefix sum to get total length and offsets of
// each
node. It's also the indptr of the generated subgraph.
subgraph_indptr
=
num_picked_neighbors_per_node
.
cumsum
(
0
,
indptr_
.
scalar_type
());
// Step 3. Allocate the tensor for picked neighbors.
const
auto
total_length
=
subgraph_indptr
.
data_ptr
<
scalar_t
>
()[
num_nodes
];
picked_eids
=
torch
::
empty
({
total_length
},
indptr_options
);
subgraph_indices
=
torch
::
empty
({
total_length
},
indices_
.
options
());
if
(
type_per_edge_
.
has_value
())
{
subgraph_type_per_edge
=
torch
::
empty
({
total_length
},
type_per_edge_
.
value
().
options
());
}
// Step 3. Allocate the tensor for picked neighbors.
const
auto
total_length
=
subgraph_indptr
.
data_ptr
<
indptr_t
>
()[
num_nodes
];
picked_eids
=
torch
::
empty
({
total_length
},
indptr_options
);
subgraph_indices
=
torch
::
empty
({
total_length
},
indices_
.
options
());
if
(
type_per_edge_
.
has_value
())
{
subgraph_type_per_edge
=
torch
::
empty
(
{
total_length
},
type_per_edge_
.
value
().
options
());
}
// Step 4. Pick neighbors for each node.
auto
picked_eids_data_ptr
=
picked_eids
.
data_ptr
<
scalar_t
>
();
auto
subgraph_indptr_data_ptr
=
subgraph_indptr
.
data_ptr
<
scalar_t
>
();
torch
::
parallel_for
(
0
,
num_nodes
,
grain_size
,
[
&
](
int64_t
begin
,
int64_t
end
)
{
for
(
int64_t
i
=
begin
;
i
<
end
;
++
i
)
{
const
auto
nid
=
nodes
[
i
].
item
<
int64_t
>
();
const
auto
offset
=
indptr_data
[
nid
];
const
auto
num_neighbors
=
indptr_data
[
nid
+
1
]
-
offset
;
const
auto
picked_number
=
num_picked_neighbors_data_ptr
[
i
+
1
];
const
auto
picked_offset
=
subgraph_indptr_data_ptr
[
i
];
if
(
picked_number
>
0
)
{
auto
actual_picked_count
=
pick_fn
(
offset
,
num_neighbors
,
picked_eids_data_ptr
+
picked_offset
);
TORCH_CHECK
(
actual_picked_count
==
picked_number
,
"Actual picked count doesn't match the calculated pick "
"number."
);
// Step 4. Pick neighbors for each node.
auto
picked_eids_data_ptr
=
picked_eids
.
data_ptr
<
indptr_t
>
();
auto
subgraph_indptr_data_ptr
=
subgraph_indptr
.
data_ptr
<
indptr_t
>
();
torch
::
parallel_for
(
0
,
num_nodes
,
grain_size
,
[
&
](
int64_t
begin
,
int64_t
end
)
{
for
(
int64_t
i
=
begin
;
i
<
end
;
++
i
)
{
const
auto
nid
=
nodes_data_ptr
[
i
];
const
auto
offset
=
indptr_data
[
nid
];
const
auto
num_neighbors
=
indptr_data
[
nid
+
1
]
-
offset
;
const
auto
picked_number
=
num_picked_neighbors_data_ptr
[
i
+
1
];
const
auto
picked_offset
=
subgraph_indptr_data_ptr
[
i
];
if
(
picked_number
>
0
)
{
auto
actual_picked_count
=
pick_fn
(
offset
,
num_neighbors
,
picked_eids_data_ptr
+
picked_offset
);
TORCH_CHECK
(
actual_picked_count
==
picked_number
,
"Actual picked count doesn't match the calculated "
"pick "
"number."
);
// Step 5. Calculate other attributes and return the subgraph.
AT_DISPATCH_INTEGRAL_TYPES
(
subgraph_indices
.
scalar_type
(),
"IndexSelectSubgraphIndices"
,
([
&
]
{
auto
subgraph_indices_data_ptr
=
subgraph_indices
.
data_ptr
<
scalar_t
>
();
auto
indices_data_ptr
=
indices_
.
data_ptr
<
scalar_t
>
();
for
(
auto
i
=
picked_offset
;
i
<
picked_offset
+
picked_number
;
++
i
)
{
subgraph_indices_data_ptr
[
i
]
=
indices_data_ptr
[
picked_eids_data_ptr
[
i
]];
// Step 5. Calculate other attributes and return the
// subgraph.
AT_DISPATCH_INTEGRAL_TYPES
(
subgraph_indices
.
scalar_type
(),
"IndexSelectSubgraphIndices"
,
([
&
]
{
auto
subgraph_indices_data_ptr
=
subgraph_indices
.
data_ptr
<
scalar_t
>
();
auto
indices_data_ptr
=
indices_
.
data_ptr
<
scalar_t
>
();
for
(
auto
i
=
picked_offset
;
i
<
picked_offset
+
picked_number
;
++
i
)
{
subgraph_indices_data_ptr
[
i
]
=
indices_data_ptr
[
picked_eids_data_ptr
[
i
]];
}
}));
if
(
type_per_edge_
.
has_value
())
{
AT_DISPATCH_INTEGRAL_TYPES
(
subgraph_type_per_edge
.
value
().
scalar_type
(),
"IndexSelectTypePerEdge"
,
([
&
]
{
auto
subgraph_type_per_edge_data_ptr
=
subgraph_type_per_edge
.
value
()
.
data_ptr
<
scalar_t
>
();
auto
type_per_edge_data_ptr
=
type_per_edge_
.
value
().
data_ptr
<
scalar_t
>
();
for
(
auto
i
=
picked_offset
;
i
<
picked_offset
+
picked_number
;
++
i
)
{
subgraph_type_per_edge_data_ptr
[
i
]
=
type_per_edge_data_ptr
[
picked_eids_data_ptr
[
i
]];
}
}));
}
}));
if
(
type_per_edge_
.
has_value
())
{
AT_DISPATCH_INTEGRAL_TYPES
(
subgraph_type_per_edge
.
value
().
scalar_type
(),
"IndexSelectTypePerEdge"
,
([
&
]
{
auto
subgraph_type_per_edge_data_ptr
=
subgraph_type_per_edge
.
value
()
.
data_ptr
<
scalar_t
>
();
auto
type_per_edge_data_ptr
=
type_per_edge_
.
value
().
data_ptr
<
scalar_t
>
();
for
(
auto
i
=
picked_offset
;
i
<
picked_offset
+
picked_number
;
++
i
)
{
subgraph_type_per_edge_data_ptr
[
i
]
=
type_per_edge_data_ptr
[
picked_eids_data_ptr
[
i
]];
}
}));
}
}
}
});
}
}
});
}));
}));
torch
::
optional
<
torch
::
Tensor
>
subgraph_reverse_edge_ids
=
torch
::
nullopt
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment