Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
2bc4df22
Unverified
Commit
2bc4df22
authored
Nov 23, 2023
by
Rhett Ying
Committed by
GitHub
Nov 23, 2023
Browse files
[GraphBolt] replace item<> with raw pointer (#6601)
parent
33e80452
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
99 additions
and
83 deletions
+99
-83
graphbolt/src/fused_csc_sampling_graph.cc
graphbolt/src/fused_csc_sampling_graph.cc
+99
-83
No files found.
graphbolt/src/fused_csc_sampling_graph.cc
View file @
2bc4df22
...
@@ -330,96 +330,112 @@ FusedCSCSamplingGraph::SampleNeighborsImpl(
...
@@ -330,96 +330,112 @@ FusedCSCSamplingGraph::SampleNeighborsImpl(
torch
::
optional
<
torch
::
Tensor
>
subgraph_type_per_edge
=
torch
::
nullopt
;
torch
::
optional
<
torch
::
Tensor
>
subgraph_type_per_edge
=
torch
::
nullopt
;
AT_DISPATCH_INTEGRAL_TYPES
(
AT_DISPATCH_INTEGRAL_TYPES
(
indptr_
.
scalar_type
(),
"SampleNeighborsImpl"
,
([
&
]
{
indptr_
.
scalar_type
(),
"SampleNeighborsImplWrappedWithIndptr"
,
([
&
]
{
const
scalar_t
*
indptr_data
=
indptr_
.
data_ptr
<
scalar_t
>
();
using
indptr_t
=
scalar_t
;
auto
num_picked_neighbors_data_ptr
=
AT_DISPATCH_INTEGRAL_TYPES
(
num_picked_neighbors_per_node
.
data_ptr
<
scalar_t
>
();
nodes
.
scalar_type
(),
"SampleNeighborsImplWrappedWithNodes"
,
([
&
]
{
num_picked_neighbors_data_ptr
[
0
]
=
0
;
using
nodes_t
=
scalar_t
;
const
auto
indptr_data
=
indptr_
.
data_ptr
<
indptr_t
>
();
auto
num_picked_neighbors_data_ptr
=
num_picked_neighbors_per_node
.
data_ptr
<
indptr_t
>
();
num_picked_neighbors_data_ptr
[
0
]
=
0
;
const
auto
nodes_data_ptr
=
nodes
.
data_ptr
<
nodes_t
>
();
// Step 1. Calculate pick number of each node.
// Step 1. Calculate pick number of each node.
torch
::
parallel_for
(
torch
::
parallel_for
(
0
,
num_nodes
,
grain_size
,
[
&
](
int64_t
begin
,
int64_t
end
)
{
0
,
num_nodes
,
grain_size
,
[
&
](
int64_t
begin
,
int64_t
end
)
{
for
(
int64_t
i
=
begin
;
i
<
end
;
++
i
)
{
for
(
int64_t
i
=
begin
;
i
<
end
;
++
i
)
{
const
auto
nid
=
nodes
[
i
].
item
<
int64_t
>
();
const
auto
nid
=
nodes_data_ptr
[
i
];
TORCH_CHECK
(
TORCH_CHECK
(
nid
>=
0
&&
nid
<
NumNodes
(),
nid
>=
0
&&
nid
<
NumNodes
(),
"The seed nodes' IDs should fall within the range of the "
"The seed nodes' IDs should fall within the range of "
"graph's node IDs."
);
"the "
const
auto
offset
=
indptr_data
[
nid
];
"graph's node IDs."
);
const
auto
num_neighbors
=
indptr_data
[
nid
+
1
]
-
offset
;
const
auto
offset
=
indptr_data
[
nid
];
const
auto
num_neighbors
=
indptr_data
[
nid
+
1
]
-
offset
;
num_picked_neighbors_data_ptr
[
i
+
1
]
=
num_picked_neighbors_data_ptr
[
i
+
1
]
=
num_neighbors
==
0
?
0
:
num_pick_fn
(
offset
,
num_neighbors
);
num_neighbors
==
0
}
?
0
});
:
num_pick_fn
(
offset
,
num_neighbors
);
}
});
// Step 2. Calculate prefix sum to get total length and offsets of
each
// Step 2. Calculate prefix sum to get total length and offsets of
//
node. It's also the indptr of the generated subgraph.
// each
node. It's also the indptr of the generated subgraph.
subgraph_indptr
=
subgraph_indptr
=
num_picked_neighbors_per_node
.
cumsum
(
num_picked_neighbors_per_node
.
cumsum
(
0
,
indptr_
.
scalar_type
());
0
,
indptr_
.
scalar_type
());
// Step 3. Allocate the tensor for picked neighbors.
// Step 3. Allocate the tensor for picked neighbors.
const
auto
total_length
=
const
auto
total_length
=
subgraph_indptr
.
data_ptr
<
scalar_t
>
()[
num_nodes
];
subgraph_indptr
.
data_ptr
<
indptr_t
>
()[
num_nodes
];
picked_eids
=
torch
::
empty
({
total_length
},
indptr_options
);
picked_eids
=
torch
::
empty
({
total_length
},
indptr_options
);
subgraph_indices
=
torch
::
empty
({
total_length
},
indices_
.
options
());
subgraph_indices
=
if
(
type_per_edge_
.
has_value
())
{
torch
::
empty
({
total_length
},
indices_
.
options
());
subgraph_type_per_edge
=
if
(
type_per_edge_
.
has_value
())
{
torch
::
empty
({
total_length
},
type_per_edge_
.
value
().
options
());
subgraph_type_per_edge
=
torch
::
empty
(
}
{
total_length
},
type_per_edge_
.
value
().
options
());
}
// Step 4. Pick neighbors for each node.
// Step 4. Pick neighbors for each node.
auto
picked_eids_data_ptr
=
picked_eids
.
data_ptr
<
scalar_t
>
();
auto
picked_eids_data_ptr
=
picked_eids
.
data_ptr
<
indptr_t
>
();
auto
subgraph_indptr_data_ptr
=
subgraph_indptr
.
data_ptr
<
scalar_t
>
();
auto
subgraph_indptr_data_ptr
=
torch
::
parallel_for
(
subgraph_indptr
.
data_ptr
<
indptr_t
>
();
0
,
num_nodes
,
grain_size
,
[
&
](
int64_t
begin
,
int64_t
end
)
{
torch
::
parallel_for
(
for
(
int64_t
i
=
begin
;
i
<
end
;
++
i
)
{
0
,
num_nodes
,
grain_size
,
[
&
](
int64_t
begin
,
int64_t
end
)
{
const
auto
nid
=
nodes
[
i
].
item
<
int64_t
>
();
for
(
int64_t
i
=
begin
;
i
<
end
;
++
i
)
{
const
auto
offset
=
indptr_data
[
nid
];
const
auto
nid
=
nodes_data_ptr
[
i
];
const
auto
num_neighbors
=
indptr_data
[
nid
+
1
]
-
offset
;
const
auto
offset
=
indptr_data
[
nid
];
const
auto
picked_number
=
num_picked_neighbors_data_ptr
[
i
+
1
];
const
auto
num_neighbors
=
indptr_data
[
nid
+
1
]
-
offset
;
const
auto
picked_offset
=
subgraph_indptr_data_ptr
[
i
];
const
auto
picked_number
=
if
(
picked_number
>
0
)
{
num_picked_neighbors_data_ptr
[
i
+
1
];
auto
actual_picked_count
=
pick_fn
(
const
auto
picked_offset
=
subgraph_indptr_data_ptr
[
i
];
offset
,
num_neighbors
,
if
(
picked_number
>
0
)
{
picked_eids_data_ptr
+
picked_offset
);
auto
actual_picked_count
=
pick_fn
(
TORCH_CHECK
(
offset
,
num_neighbors
,
actual_picked_count
==
picked_number
,
picked_eids_data_ptr
+
picked_offset
);
"Actual picked count doesn't match the calculated pick "
TORCH_CHECK
(
"number."
);
actual_picked_count
==
picked_number
,
"Actual picked count doesn't match the calculated "
"pick "
"number."
);
// Step 5. Calculate other attributes and return the subgraph.
// Step 5. Calculate other attributes and return the
AT_DISPATCH_INTEGRAL_TYPES
(
// subgraph.
subgraph_indices
.
scalar_type
(),
AT_DISPATCH_INTEGRAL_TYPES
(
"IndexSelectSubgraphIndices"
,
([
&
]
{
subgraph_indices
.
scalar_type
(),
auto
subgraph_indices_data_ptr
=
"IndexSelectSubgraphIndices"
,
([
&
]
{
subgraph_indices
.
data_ptr
<
scalar_t
>
();
auto
subgraph_indices_data_ptr
=
auto
indices_data_ptr
=
indices_
.
data_ptr
<
scalar_t
>
();
subgraph_indices
.
data_ptr
<
scalar_t
>
();
for
(
auto
i
=
picked_offset
;
auto
indices_data_ptr
=
i
<
picked_offset
+
picked_number
;
++
i
)
{
indices_
.
data_ptr
<
scalar_t
>
();
subgraph_indices_data_ptr
[
i
]
=
for
(
auto
i
=
picked_offset
;
indices_data_ptr
[
picked_eids_data_ptr
[
i
]];
i
<
picked_offset
+
picked_number
;
++
i
)
{
subgraph_indices_data_ptr
[
i
]
=
indices_data_ptr
[
picked_eids_data_ptr
[
i
]];
}
}));
if
(
type_per_edge_
.
has_value
())
{
AT_DISPATCH_INTEGRAL_TYPES
(
subgraph_type_per_edge
.
value
().
scalar_type
(),
"IndexSelectTypePerEdge"
,
([
&
]
{
auto
subgraph_type_per_edge_data_ptr
=
subgraph_type_per_edge
.
value
()
.
data_ptr
<
scalar_t
>
();
auto
type_per_edge_data_ptr
=
type_per_edge_
.
value
().
data_ptr
<
scalar_t
>
();
for
(
auto
i
=
picked_offset
;
i
<
picked_offset
+
picked_number
;
++
i
)
{
subgraph_type_per_edge_data_ptr
[
i
]
=
type_per_edge_data_ptr
[
picked_eids_data_ptr
[
i
]];
}
}));
}
}
}));
}
if
(
type_per_edge_
.
has_value
())
{
}
AT_DISPATCH_INTEGRAL_TYPES
(
});
subgraph_type_per_edge
.
value
().
scalar_type
(),
}));
"IndexSelectTypePerEdge"
,
([
&
]
{
auto
subgraph_type_per_edge_data_ptr
=
subgraph_type_per_edge
.
value
()
.
data_ptr
<
scalar_t
>
();
auto
type_per_edge_data_ptr
=
type_per_edge_
.
value
().
data_ptr
<
scalar_t
>
();
for
(
auto
i
=
picked_offset
;
i
<
picked_offset
+
picked_number
;
++
i
)
{
subgraph_type_per_edge_data_ptr
[
i
]
=
type_per_edge_data_ptr
[
picked_eids_data_ptr
[
i
]];
}
}));
}
}
}
});
}));
}));
torch
::
optional
<
torch
::
Tensor
>
subgraph_reverse_edge_ids
=
torch
::
nullopt
;
torch
::
optional
<
torch
::
Tensor
>
subgraph_reverse_edge_ids
=
torch
::
nullopt
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment