Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
f0213d21
Unverified
Commit
f0213d21
authored
Apr 29, 2024
by
Ramon Zhou
Committed by
GitHub
Apr 29, 2024
Browse files
[GraphBolt] Refactor sampling (#7367)
parent
6b140f28
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
62 additions
and
178 deletions
+62
-178
graphbolt/include/graphbolt/fused_csc_sampling_graph.h
graphbolt/include/graphbolt/fused_csc_sampling_graph.h
+2
-6
graphbolt/src/fused_csc_sampling_graph.cc
graphbolt/src/fused_csc_sampling_graph.cc
+60
-172
No files found.
graphbolt/include/graphbolt/fused_csc_sampling_graph.h
View file @
f0213d21
...
@@ -18,6 +18,7 @@ namespace graphbolt {
...
@@ -18,6 +18,7 @@ namespace graphbolt {
namespace
sampling
{
namespace
sampling
{
/**
 * @brief Identifies the neighbor sampling algorithm.
 *
 * Used as the template argument of `SamplerArgs<...>`; `is_labor()` reports
 * true for both `LABOR` and `LABOR_DEPENDENT`.
 */
enum
SamplerType
{
NEIGHBOR
,
LABOR
,
LABOR_DEPENDENT
};
/**
 * @brief Identifies the neighbor sampling algorithm.
 *
 * Used as the template argument of `SamplerArgs<...>`; `is_labor()` reports
 * true for both `LABOR` and `LABOR_DEPENDENT`.
 */
enum
SamplerType
{
NEIGHBOR
,
LABOR
,
LABOR_DEPENDENT
};
/**
 * @brief Compile-time selector between non-temporal and temporal sampling.
 *
 * Passed as the first template argument of
 * `FusedCSCSamplingGraph::SampleNeighborsImpl`, which branches on it with
 * `if constexpr`. `SampleNeighbors` instantiates it with `NOT_TEMPORAL` and
 * `TemporalSampleNeighbors` with `TEMPORAL`.
 */
enum
TemporalOption
{
NOT_TEMPORAL
,
TEMPORAL
};
constexpr
bool
is_labor
(
SamplerType
S
)
{
constexpr
bool
is_labor
(
SamplerType
S
)
{
return
S
==
SamplerType
::
LABOR
||
S
==
SamplerType
::
LABOR_DEPENDENT
;
return
S
==
SamplerType
::
LABOR
||
S
==
SamplerType
::
LABOR_DEPENDENT
;
...
@@ -413,18 +414,13 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder {
...
@@ -413,18 +414,13 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder {
SharedMemoryPtr
tensor_metadata_shm
,
SharedMemoryPtr
tensor_data_shm
);
SharedMemoryPtr
tensor_metadata_shm
,
SharedMemoryPtr
tensor_data_shm
);
private:
private:
template
<
typename
NumPickFn
,
typename
PickFn
>
template
<
TemporalOption
Temporal
,
typename
NumPickFn
,
typename
PickFn
>
c10
::
intrusive_ptr
<
FusedSampledSubgraph
>
SampleNeighborsImpl
(
c10
::
intrusive_ptr
<
FusedSampledSubgraph
>
SampleNeighborsImpl
(
const
torch
::
Tensor
&
seeds
,
const
torch
::
Tensor
&
seeds
,
torch
::
optional
<
std
::
vector
<
int64_t
>>&
seed_offsets
,
torch
::
optional
<
std
::
vector
<
int64_t
>>&
seed_offsets
,
const
std
::
vector
<
int64_t
>&
fanouts
,
bool
return_eids
,
const
std
::
vector
<
int64_t
>&
fanouts
,
bool
return_eids
,
NumPickFn
num_pick_fn
,
PickFn
pick_fn
)
const
;
NumPickFn
num_pick_fn
,
PickFn
pick_fn
)
const
;
template
<
typename
NumPickFn
,
typename
PickFn
>
c10
::
intrusive_ptr
<
FusedSampledSubgraph
>
TemporalSampleNeighborsImpl
(
const
torch
::
Tensor
&
nodes
,
bool
return_eids
,
NumPickFn
num_pick_fn
,
PickFn
pick_fn
)
const
;
/** @brief CSC format index pointer array. */
/** @brief CSC format index pointer array. */
torch
::
Tensor
indptr_
;
torch
::
Tensor
indptr_
;
...
...
graphbolt/src/fused_csc_sampling_graph.cc
View file @
f0213d21
...
@@ -492,7 +492,7 @@ auto GetTemporalPickFn(
...
@@ -492,7 +492,7 @@ auto GetTemporalPickFn(
};
};
}
}
template
<
typename
NumPickFn
,
typename
PickFn
>
template
<
TemporalOption
Temporal
,
typename
NumPickFn
,
typename
PickFn
>
c10
::
intrusive_ptr
<
FusedSampledSubgraph
>
c10
::
intrusive_ptr
<
FusedSampledSubgraph
>
FusedCSCSamplingGraph
::
SampleNeighborsImpl
(
FusedCSCSamplingGraph
::
SampleNeighborsImpl
(
const
torch
::
Tensor
&
seeds
,
const
torch
::
Tensor
&
seeds
,
...
@@ -512,7 +512,8 @@ FusedCSCSamplingGraph::SampleNeighborsImpl(
...
@@ -512,7 +512,8 @@ FusedCSCSamplingGraph::SampleNeighborsImpl(
torch
::
optional
<
torch
::
Tensor
>
edge_offsets
=
torch
::
nullopt
;
torch
::
optional
<
torch
::
Tensor
>
edge_offsets
=
torch
::
nullopt
;
bool
with_seed_offsets
=
seed_offsets
.
has_value
();
bool
with_seed_offsets
=
seed_offsets
.
has_value
();
bool
hetero_with_seed_offsets
=
with_seed_offsets
&&
fanouts
.
size
()
>
1
;
bool
hetero_with_seed_offsets
=
with_seed_offsets
&&
fanouts
.
size
()
>
1
&&
Temporal
==
TemporalOption
::
NOT_TEMPORAL
;
// Get the number of edge types. If it's homo or if the size of fanouts is 1
// Get the number of edge types. If it's homo or if the size of fanouts is 1
// (hetero graph but sampled as a homo graph), set num_etypes as 1.
// (hetero graph but sampled as a homo graph), set num_etypes as 1.
...
@@ -584,24 +585,31 @@ FusedCSCSamplingGraph::SampleNeighborsImpl(
...
@@ -584,24 +585,31 @@ FusedCSCSamplingGraph::SampleNeighborsImpl(
const
auto
offset
=
indptr_data
[
nid
];
const
auto
offset
=
indptr_data
[
nid
];
const
auto
num_neighbors
=
indptr_data
[
nid
+
1
]
-
offset
;
const
auto
num_neighbors
=
indptr_data
[
nid
+
1
]
-
offset
;
const
auto
seed_type_id
=
if
constexpr
(
Temporal
==
TemporalOption
::
TEMPORAL
)
{
(
hetero_with_seed_offsets
)
num_picked_neighbors_data_ptr
[
i
+
1
]
=
?
std
::
upper_bound
(
num_neighbors
==
0
seed_offsets
->
begin
(),
seed_offsets
->
end
(),
?
0
i
)
-
:
num_pick_fn
(
i
,
offset
,
num_neighbors
);
seed_offsets
->
begin
()
-
1
}
else
{
:
0
;
const
auto
seed_type_id
=
// `seed_index` indicates the index of the current
(
hetero_with_seed_offsets
)
// seed within the group of seeds which have the same
?
std
::
upper_bound
(
// node type.
seed_offsets
->
begin
(),
const
auto
seed_index
=
seed_offsets
->
end
(),
i
)
-
(
hetero_with_seed_offsets
)
seed_offsets
->
begin
()
-
1
?
i
-
seed_offsets
->
at
(
seed_type_id
)
:
0
;
:
i
;
// `seed_index` indicates the index of the current
num_pick_fn
(
// seed within the group of seeds which have the same
offset
,
num_neighbors
,
// node type.
num_picked_neighbors_data_ptr
+
1
,
seed_index
,
const
auto
seed_index
=
etype_id_to_num_picked_offset
);
(
hetero_with_seed_offsets
)
?
i
-
seed_offsets
->
at
(
seed_type_id
)
:
i
;
num_pick_fn
(
offset
,
num_neighbors
,
num_picked_neighbors_data_ptr
+
1
,
seed_index
,
etype_id_to_num_picked_offset
);
}
}
}
});
});
...
@@ -684,16 +692,30 @@ FusedCSCSamplingGraph::SampleNeighborsImpl(
...
@@ -684,16 +692,30 @@ FusedCSCSamplingGraph::SampleNeighborsImpl(
:
i
;
:
i
;
// Step 4. Pick neighbors for each node.
// Step 4. Pick neighbors for each node.
picked_number
=
pick_fn
(
if
constexpr
(
Temporal
==
TemporalOption
::
TEMPORAL
)
{
offset
,
num_neighbors
,
picked_eids_data_ptr
,
picked_number
=
num_picked_neighbors_data_ptr
[
i
+
1
];
seed_index
,
subgraph_indptr_data_ptr
,
auto
picked_offset
=
subgraph_indptr_data_ptr
[
i
];
etype_id_to_num_picked_offset
);
if
(
picked_number
>
0
)
{
if
(
!
hetero_with_seed_offsets
)
{
auto
actual_picked_count
=
pick_fn
(
TORCH_CHECK
(
i
,
offset
,
num_neighbors
,
num_picked_neighbors_data_ptr
[
i
+
1
]
==
picked_eids_data_ptr
+
picked_offset
);
picked_number
,
TORCH_CHECK
(
"Actual picked count doesn't match the calculated "
actual_picked_count
==
picked_number
,
"pick number."
);
"Actual picked count doesn't match the calculated"
" pick number."
);
}
}
else
{
picked_number
=
pick_fn
(
offset
,
num_neighbors
,
picked_eids_data_ptr
,
seed_index
,
subgraph_indptr_data_ptr
,
etype_id_to_num_picked_offset
);
if
(
!
hetero_with_seed_offsets
)
{
TORCH_CHECK
(
num_picked_neighbors_data_ptr
[
i
+
1
]
==
picked_number
,
"Actual picked count doesn't match the calculated"
" pick number."
);
}
}
}
// Step 5. Calculate other attributes and return the
// Step 5. Calculate other attributes and return the
...
@@ -779,141 +801,6 @@ FusedCSCSamplingGraph::SampleNeighborsImpl(
...
@@ -779,141 +801,6 @@ FusedCSCSamplingGraph::SampleNeighborsImpl(
subgraph_reverse_edge_ids
,
subgraph_type_per_edge
,
edge_offsets
);
subgraph_reverse_edge_ids
,
subgraph_type_per_edge
,
edge_offsets
);
}
}
/**
 * @brief Shared driver for temporal neighbor sampling over the CSC graph.
 *
 * For every seed in `nodes` the neighbor range is read from `indptr_`, the
 * number of neighbors to keep is obtained from `num_pick_fn`, and the edge ids
 * are then written by `pick_fn` into a buffer sized from the prefix sum of
 * those counts. Picked edge ids are finally mapped through `indices_` (and
 * `type_per_edge_` when it holds a value) to fill the returned subgraph.
 *
 * @param nodes Seed node ids; each must lie in [0, NumNodes()).
 * @param return_eids Whether the picked edge ids are attached to the result.
 * @param num_pick_fn Callable invoked as (seed position, offset,
 * num_neighbors) returning how many neighbors will be picked. It is not
 * called for seeds that have no neighbors.
 * @param pick_fn Callable invoked as (seed position, offset, num_neighbors,
 * output pointer) that writes the picked edge ids and returns how many it
 * wrote. It is only called when the precomputed count is positive.
 *
 * @return A FusedSampledSubgraph built from the subgraph indptr, the subgraph
 * indices, `nodes`, the optional edge ids and the optional per-edge types.
 */
template <typename NumPickFn, typename PickFn>
c10::intrusive_ptr<FusedSampledSubgraph>
FusedCSCSamplingGraph::TemporalSampleNeighborsImpl(
    const torch::Tensor& nodes, bool return_eids, NumPickFn num_pick_fn,
    PickFn pick_fn) const {
  const int64_t seed_count = nodes.size(0);
  const auto indptr_opts = indptr_.options();
  // Slot 0 stays 0 so that the cumulative sum below directly yields an indptr.
  torch::Tensor picks_per_seed = torch::empty({seed_count + 1}, indptr_opts);

  // Grain size handed to torch::parallel_for.
  const int64_t kGrainSize = 64;

  torch::Tensor sampled_eids;
  torch::Tensor out_indptr;
  torch::Tensor out_indices;
  torch::optional<torch::Tensor> out_type_per_edge = torch::nullopt;

  AT_DISPATCH_INDEX_TYPES(
      indptr_.scalar_type(), "SampleNeighborsImplWrappedWithIndptr", ([&] {
        using indptr_t = index_t;
        AT_DISPATCH_INDEX_TYPES(
            nodes.scalar_type(), "SampleNeighborsImplWrappedWithNodes", ([&] {
              using nodes_t = index_t;
              const auto graph_indptr = indptr_.data_ptr<indptr_t>();
              auto picks_ptr = picks_per_seed.data_ptr<indptr_t>();
              picks_ptr[0] = 0;
              const auto seeds_ptr = nodes.data_ptr<nodes_t>();

              // Phase 1: decide how many neighbors each seed will receive.
              torch::parallel_for(
                  0, seed_count, kGrainSize,
                  [&](int64_t begin, int64_t end) {
                    for (int64_t i = begin; i < end; ++i) {
                      const auto seed = seeds_ptr[i];
                      TORCH_CHECK(
                          seed >= 0 && seed < NumNodes(),
                          "The seed nodes' IDs should fall within the range of "
                          "the "
                          "graph's node IDs.");
                      const auto first_edge = graph_indptr[seed];
                      const auto degree = graph_indptr[seed + 1] - first_edge;
                      // Seeds without neighbors never reach num_pick_fn.
                      picks_ptr[i + 1] =
                          degree == 0 ? 0 : num_pick_fn(i, first_edge, degree);
                    }
                  });

              // Phase 2: the prefix sum gives both the total number of picked
              // edges and the indptr of the sampled subgraph.
              out_indptr = picks_per_seed.cumsum(0, indptr_.scalar_type());

              // Phase 3: allocate the output buffers.
              const auto picked_total =
                  out_indptr.data_ptr<indptr_t>()[seed_count];
              sampled_eids = torch::empty({picked_total}, indptr_opts);
              out_indices = torch::empty({picked_total}, indices_.options());
              if (type_per_edge_.has_value()) {
                out_type_per_edge = torch::empty(
                    {picked_total}, type_per_edge_.value().options());
              }

              // Phase 4: pick the neighbors of every seed.
              auto eids_ptr = sampled_eids.data_ptr<indptr_t>();
              auto out_indptr_ptr = out_indptr.data_ptr<indptr_t>();
              torch::parallel_for(
                  0, seed_count, kGrainSize,
                  [&](int64_t begin, int64_t end) {
                    for (int64_t i = begin; i < end; ++i) {
                      const auto seed = seeds_ptr[i];
                      const auto first_edge = graph_indptr[seed];
                      const auto degree = graph_indptr[seed + 1] - first_edge;
                      const auto expected = picks_ptr[i + 1];
                      const auto write_begin = out_indptr_ptr[i];
                      // Nothing to pick for this seed.
                      if (expected <= 0) continue;

                      auto written = pick_fn(
                          i, first_edge, degree, eids_ptr + write_begin);
                      TORCH_CHECK(
                          written == expected,
                          "Actual picked count doesn't match the calculated "
                          "pick "
                          "number.");

                      // Phase 5: translate the picked edge ids into neighbor
                      // ids (and edge types) for this seed's output slice.
                      AT_DISPATCH_INDEX_TYPES(
                          out_indices.scalar_type(),
                          "IndexSelectSubgraphIndices", ([&] {
                            auto dst = out_indices.data_ptr<index_t>();
                            auto src = indices_.data_ptr<index_t>();
                            for (auto e = write_begin;
                                 e < write_begin + expected; ++e) {
                              dst[e] = src[eids_ptr[e]];
                            }
                          }));
                      if (type_per_edge_.has_value()) {
                        AT_DISPATCH_INTEGRAL_TYPES(
                            out_type_per_edge.value().scalar_type(),
                            "IndexSelectTypePerEdge", ([&] {
                              auto dst = out_type_per_edge.value()
                                             .data_ptr<scalar_t>();
                              auto src =
                                  type_per_edge_.value().data_ptr<scalar_t>();
                              for (auto e = write_begin;
                                   e < write_begin + expected; ++e) {
                                dst[e] = src[eids_ptr[e]];
                              }
                            }));
                      }
                    }
                  });
            }));
      }));

  // The picked edge ids are only handed to the result on request.
  torch::optional<torch::Tensor> reverse_edge_ids = torch::nullopt;
  if (return_eids) reverse_edge_ids = std::move(sampled_eids);

  return c10::make_intrusive<FusedSampledSubgraph>(
      out_indptr, out_indices, nodes, torch::nullopt, reverse_edge_ids,
      out_type_per_edge);
}
c10
::
intrusive_ptr
<
FusedSampledSubgraph
>
FusedCSCSamplingGraph
::
SampleNeighbors
(
c10
::
intrusive_ptr
<
FusedSampledSubgraph
>
FusedCSCSamplingGraph
::
SampleNeighbors
(
torch
::
optional
<
torch
::
Tensor
>
seeds
,
torch
::
optional
<
torch
::
Tensor
>
seeds
,
torch
::
optional
<
std
::
vector
<
int64_t
>>
seed_offsets
,
torch
::
optional
<
std
::
vector
<
int64_t
>>
seed_offsets
,
...
@@ -969,7 +856,7 @@ c10::intrusive_ptr<FusedSampledSubgraph> FusedCSCSamplingGraph::SampleNeighbors(
...
@@ -969,7 +856,7 @@ c10::intrusive_ptr<FusedSampledSubgraph> FusedCSCSamplingGraph::SampleNeighbors(
indices_
,
indices_
,
{
random_seed
.
value
(),
static_cast
<
float
>
(
seed2_contribution
)},
{
random_seed
.
value
(),
static_cast
<
float
>
(
seed2_contribution
)},
NumNodes
()};
NumNodes
()};
return
SampleNeighborsImpl
(
return
SampleNeighborsImpl
<
TemporalOption
::
NOT_TEMPORAL
>
(
seeds
.
value
(),
seed_offsets
,
fanouts
,
return_eids
,
seeds
.
value
(),
seed_offsets
,
fanouts
,
return_eids
,
GetNumPickFn
(
GetNumPickFn
(
fanouts
,
replace
,
type_per_edge_
,
probs_or_mask
,
fanouts
,
replace
,
type_per_edge_
,
probs_or_mask
,
...
@@ -990,7 +877,7 @@ c10::intrusive_ptr<FusedSampledSubgraph> FusedCSCSamplingGraph::SampleNeighbors(
...
@@ -990,7 +877,7 @@ c10::intrusive_ptr<FusedSampledSubgraph> FusedCSCSamplingGraph::SampleNeighbors(
NumNodes
()};
NumNodes
()};
}
}
}();
}();
return
SampleNeighborsImpl
(
return
SampleNeighborsImpl
<
TemporalOption
::
NOT_TEMPORAL
>
(
seeds
.
value
(),
seed_offsets
,
fanouts
,
return_eids
,
seeds
.
value
(),
seed_offsets
,
fanouts
,
return_eids
,
GetNumPickFn
(
GetNumPickFn
(
fanouts
,
replace
,
type_per_edge_
,
probs_or_mask
,
fanouts
,
replace
,
type_per_edge_
,
probs_or_mask
,
...
@@ -1001,7 +888,7 @@ c10::intrusive_ptr<FusedSampledSubgraph> FusedCSCSamplingGraph::SampleNeighbors(
...
@@ -1001,7 +888,7 @@ c10::intrusive_ptr<FusedSampledSubgraph> FusedCSCSamplingGraph::SampleNeighbors(
}
}
}
else
{
}
else
{
SamplerArgs
<
SamplerType
::
NEIGHBOR
>
args
;
SamplerArgs
<
SamplerType
::
NEIGHBOR
>
args
;
return
SampleNeighborsImpl
(
return
SampleNeighborsImpl
<
TemporalOption
::
NOT_TEMPORAL
>
(
seeds
.
value
(),
seed_offsets
,
fanouts
,
return_eids
,
seeds
.
value
(),
seed_offsets
,
fanouts
,
return_eids
,
GetNumPickFn
(
GetNumPickFn
(
fanouts
,
replace
,
type_per_edge_
,
probs_or_mask
,
with_seed_offsets
),
fanouts
,
replace
,
type_per_edge_
,
probs_or_mask
,
with_seed_offsets
),
...
@@ -1019,6 +906,7 @@ FusedCSCSamplingGraph::TemporalSampleNeighbors(
...
@@ -1019,6 +906,7 @@ FusedCSCSamplingGraph::TemporalSampleNeighbors(
bool
return_eids
,
torch
::
optional
<
std
::
string
>
probs_name
,
bool
return_eids
,
torch
::
optional
<
std
::
string
>
probs_name
,
torch
::
optional
<
std
::
string
>
node_timestamp_attr_name
,
torch
::
optional
<
std
::
string
>
node_timestamp_attr_name
,
torch
::
optional
<
std
::
string
>
edge_timestamp_attr_name
)
const
{
torch
::
optional
<
std
::
string
>
edge_timestamp_attr_name
)
const
{
torch
::
optional
<
std
::
vector
<
int64_t
>>
seed_offsets
=
torch
::
nullopt
;
// 1. Get probs_or_mask.
// 1. Get probs_or_mask.
auto
probs_or_mask
=
this
->
EdgeAttribute
(
probs_name
);
auto
probs_or_mask
=
this
->
EdgeAttribute
(
probs_name
);
if
(
probs_name
.
has_value
())
{
if
(
probs_name
.
has_value
())
{
...
@@ -1039,8 +927,8 @@ FusedCSCSamplingGraph::TemporalSampleNeighbors(
...
@@ -1039,8 +927,8 @@ FusedCSCSamplingGraph::TemporalSampleNeighbors(
const
int64_t
random_seed
=
RandomEngine
::
ThreadLocal
()
->
RandInt
(
const
int64_t
random_seed
=
RandomEngine
::
ThreadLocal
()
->
RandInt
(
static_cast
<
int64_t
>
(
0
),
std
::
numeric_limits
<
int64_t
>::
max
());
static_cast
<
int64_t
>
(
0
),
std
::
numeric_limits
<
int64_t
>::
max
());
SamplerArgs
<
SamplerType
::
LABOR
>
args
{
indices_
,
random_seed
,
NumNodes
()};
SamplerArgs
<
SamplerType
::
LABOR
>
args
{
indices_
,
random_seed
,
NumNodes
()};
return
Temporal
SampleNeighborsImpl
(
return
SampleNeighborsImpl
<
TemporalOption
::
TEMPORAL
>
(
input_nodes
,
return_eids
,
input_nodes
,
seed_offsets
,
fanouts
,
return_eids
,
GetTemporalNumPickFn
(
GetTemporalNumPickFn
(
input_nodes_timestamp
,
this
->
indices_
,
fanouts
,
replace
,
input_nodes_timestamp
,
this
->
indices_
,
fanouts
,
replace
,
type_per_edge_
,
probs_or_mask
,
node_timestamp
,
edge_timestamp
),
type_per_edge_
,
probs_or_mask
,
node_timestamp
,
edge_timestamp
),
...
@@ -1050,8 +938,8 @@ FusedCSCSamplingGraph::TemporalSampleNeighbors(
...
@@ -1050,8 +938,8 @@ FusedCSCSamplingGraph::TemporalSampleNeighbors(
edge_timestamp
,
args
));
edge_timestamp
,
args
));
}
else
{
}
else
{
SamplerArgs
<
SamplerType
::
NEIGHBOR
>
args
;
SamplerArgs
<
SamplerType
::
NEIGHBOR
>
args
;
return
Temporal
SampleNeighborsImpl
(
return
SampleNeighborsImpl
<
TemporalOption
::
TEMPORAL
>
(
input_nodes
,
return_eids
,
input_nodes
,
seed_offsets
,
fanouts
,
return_eids
,
GetTemporalNumPickFn
(
GetTemporalNumPickFn
(
input_nodes_timestamp
,
this
->
indices_
,
fanouts
,
replace
,
input_nodes_timestamp
,
this
->
indices_
,
fanouts
,
replace
,
type_per_edge_
,
probs_or_mask
,
node_timestamp
,
edge_timestamp
),
type_per_edge_
,
probs_or_mask
,
node_timestamp
,
edge_timestamp
),
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment