Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
f0213d21
You need to sign in or sign up before continuing.
Unverified
Commit
f0213d21
authored
Apr 29, 2024
by
Ramon Zhou
Committed by
GitHub
Apr 29, 2024
Browse files
[GraphBolt] Refactor sampling (#7367)
parent
6b140f28
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
62 additions
and
178 deletions
+62
-178
graphbolt/include/graphbolt/fused_csc_sampling_graph.h
graphbolt/include/graphbolt/fused_csc_sampling_graph.h
+2
-6
graphbolt/src/fused_csc_sampling_graph.cc
graphbolt/src/fused_csc_sampling_graph.cc
+60
-172
No files found.
graphbolt/include/graphbolt/fused_csc_sampling_graph.h
View file @
f0213d21
...
@@ -18,6 +18,7 @@ namespace graphbolt {
...
@@ -18,6 +18,7 @@ namespace graphbolt {
namespace
sampling
{
namespace
sampling
{
enum
SamplerType
{
NEIGHBOR
,
LABOR
,
LABOR_DEPENDENT
};
enum
SamplerType
{
NEIGHBOR
,
LABOR
,
LABOR_DEPENDENT
};
enum
TemporalOption
{
NOT_TEMPORAL
,
TEMPORAL
};
constexpr
bool
is_labor
(
SamplerType
S
)
{
constexpr
bool
is_labor
(
SamplerType
S
)
{
return
S
==
SamplerType
::
LABOR
||
S
==
SamplerType
::
LABOR_DEPENDENT
;
return
S
==
SamplerType
::
LABOR
||
S
==
SamplerType
::
LABOR_DEPENDENT
;
...
@@ -413,18 +414,13 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder {
...
@@ -413,18 +414,13 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder {
SharedMemoryPtr
tensor_metadata_shm
,
SharedMemoryPtr
tensor_data_shm
);
SharedMemoryPtr
tensor_metadata_shm
,
SharedMemoryPtr
tensor_data_shm
);
private:
private:
template
<
typename
NumPickFn
,
typename
PickFn
>
template
<
TemporalOption
Temporal
,
typename
NumPickFn
,
typename
PickFn
>
c10
::
intrusive_ptr
<
FusedSampledSubgraph
>
SampleNeighborsImpl
(
c10
::
intrusive_ptr
<
FusedSampledSubgraph
>
SampleNeighborsImpl
(
const
torch
::
Tensor
&
seeds
,
const
torch
::
Tensor
&
seeds
,
torch
::
optional
<
std
::
vector
<
int64_t
>>&
seed_offsets
,
torch
::
optional
<
std
::
vector
<
int64_t
>>&
seed_offsets
,
const
std
::
vector
<
int64_t
>&
fanouts
,
bool
return_eids
,
const
std
::
vector
<
int64_t
>&
fanouts
,
bool
return_eids
,
NumPickFn
num_pick_fn
,
PickFn
pick_fn
)
const
;
NumPickFn
num_pick_fn
,
PickFn
pick_fn
)
const
;
template
<
typename
NumPickFn
,
typename
PickFn
>
c10
::
intrusive_ptr
<
FusedSampledSubgraph
>
TemporalSampleNeighborsImpl
(
const
torch
::
Tensor
&
nodes
,
bool
return_eids
,
NumPickFn
num_pick_fn
,
PickFn
pick_fn
)
const
;
/** @brief CSC format index pointer array. */
/** @brief CSC format index pointer array. */
torch
::
Tensor
indptr_
;
torch
::
Tensor
indptr_
;
...
...
graphbolt/src/fused_csc_sampling_graph.cc
View file @
f0213d21
...
@@ -492,7 +492,7 @@ auto GetTemporalPickFn(
...
@@ -492,7 +492,7 @@ auto GetTemporalPickFn(
};
};
}
}
template
<
typename
NumPickFn
,
typename
PickFn
>
template
<
TemporalOption
Temporal
,
typename
NumPickFn
,
typename
PickFn
>
c10
::
intrusive_ptr
<
FusedSampledSubgraph
>
c10
::
intrusive_ptr
<
FusedSampledSubgraph
>
FusedCSCSamplingGraph
::
SampleNeighborsImpl
(
FusedCSCSamplingGraph
::
SampleNeighborsImpl
(
const
torch
::
Tensor
&
seeds
,
const
torch
::
Tensor
&
seeds
,
...
@@ -512,7 +512,8 @@ FusedCSCSamplingGraph::SampleNeighborsImpl(
...
@@ -512,7 +512,8 @@ FusedCSCSamplingGraph::SampleNeighborsImpl(
torch
::
optional
<
torch
::
Tensor
>
edge_offsets
=
torch
::
nullopt
;
torch
::
optional
<
torch
::
Tensor
>
edge_offsets
=
torch
::
nullopt
;
bool
with_seed_offsets
=
seed_offsets
.
has_value
();
bool
with_seed_offsets
=
seed_offsets
.
has_value
();
bool
hetero_with_seed_offsets
=
with_seed_offsets
&&
fanouts
.
size
()
>
1
;
bool
hetero_with_seed_offsets
=
with_seed_offsets
&&
fanouts
.
size
()
>
1
&&
Temporal
==
TemporalOption
::
NOT_TEMPORAL
;
// Get the number of edge types. If it's homo or if the size of fanouts is 1
// Get the number of edge types. If it's homo or if the size of fanouts is 1
// (hetero graph but sampled as a homo graph), set num_etypes as 1.
// (hetero graph but sampled as a homo graph), set num_etypes as 1.
...
@@ -584,24 +585,31 @@ FusedCSCSamplingGraph::SampleNeighborsImpl(
...
@@ -584,24 +585,31 @@ FusedCSCSamplingGraph::SampleNeighborsImpl(
const
auto
offset
=
indptr_data
[
nid
];
const
auto
offset
=
indptr_data
[
nid
];
const
auto
num_neighbors
=
indptr_data
[
nid
+
1
]
-
offset
;
const
auto
num_neighbors
=
indptr_data
[
nid
+
1
]
-
offset
;
const
auto
seed_type_id
=
if
constexpr
(
Temporal
==
TemporalOption
::
TEMPORAL
)
{
(
hetero_with_seed_offsets
)
num_picked_neighbors_data_ptr
[
i
+
1
]
=
?
std
::
upper_bound
(
num_neighbors
==
0
seed_offsets
->
begin
(),
seed_offsets
->
end
(),
?
0
i
)
-
:
num_pick_fn
(
i
,
offset
,
num_neighbors
);
seed_offsets
->
begin
()
-
1
}
else
{
:
0
;
const
auto
seed_type_id
=
// `seed_index` indicates the index of the current
(
hetero_with_seed_offsets
)
// seed within the group of seeds which have the same
?
std
::
upper_bound
(
// node type.
seed_offsets
->
begin
(),
const
auto
seed_index
=
seed_offsets
->
end
(),
i
)
-
(
hetero_with_seed_offsets
)
seed_offsets
->
begin
()
-
1
?
i
-
seed_offsets
->
at
(
seed_type_id
)
:
0
;
:
i
;
// `seed_index` indicates the index of the current
num_pick_fn
(
// seed within the group of seeds which have the same
offset
,
num_neighbors
,
// node type.
num_picked_neighbors_data_ptr
+
1
,
seed_index
,
const
auto
seed_index
=
etype_id_to_num_picked_offset
);
(
hetero_with_seed_offsets
)
?
i
-
seed_offsets
->
at
(
seed_type_id
)
:
i
;
num_pick_fn
(
offset
,
num_neighbors
,
num_picked_neighbors_data_ptr
+
1
,
seed_index
,
etype_id_to_num_picked_offset
);
}
}
}
});
});
...
@@ -684,16 +692,30 @@ FusedCSCSamplingGraph::SampleNeighborsImpl(
...
@@ -684,16 +692,30 @@ FusedCSCSamplingGraph::SampleNeighborsImpl(
:
i
;
:
i
;
// Step 4. Pick neighbors for each node.
// Step 4. Pick neighbors for each node.
picked_number
=
pick_fn
(
if
constexpr
(
Temporal
==
TemporalOption
::
TEMPORAL
)
{
offset
,
num_neighbors
,
picked_eids_data_ptr
,
picked_number
=
num_picked_neighbors_data_ptr
[
i
+
1
];
seed_index
,
subgraph_indptr_data_ptr
,
auto
picked_offset
=
subgraph_indptr_data_ptr
[
i
];
etype_id_to_num_picked_offset
);
if
(
picked_number
>
0
)
{
if
(
!
hetero_with_seed_offsets
)
{
auto
actual_picked_count
=
pick_fn
(
TORCH_CHECK
(
i
,
offset
,
num_neighbors
,
num_picked_neighbors_data_ptr
[
i
+
1
]
==
picked_eids_data_ptr
+
picked_offset
);
picked_number
,
TORCH_CHECK
(
"Actual picked count doesn't match the calculated "
actual_picked_count
==
picked_number
,
"pick number."
);
"Actual picked count doesn't match the calculated"
" pick number."
);
}
}
else
{
picked_number
=
pick_fn
(
offset
,
num_neighbors
,
picked_eids_data_ptr
,
seed_index
,
subgraph_indptr_data_ptr
,
etype_id_to_num_picked_offset
);
if
(
!
hetero_with_seed_offsets
)
{
TORCH_CHECK
(
num_picked_neighbors_data_ptr
[
i
+
1
]
==
picked_number
,
"Actual picked count doesn't match the calculated"
" pick number."
);
}
}
}
// Step 5. Calculate other attributes and return the
// Step 5. Calculate other attributes and return the
...
@@ -779,141 +801,6 @@ FusedCSCSamplingGraph::SampleNeighborsImpl(
...
@@ -779,141 +801,6 @@ FusedCSCSamplingGraph::SampleNeighborsImpl(
subgraph_reverse_edge_ids
,
subgraph_type_per_edge
,
edge_offsets
);
subgraph_reverse_edge_ids
,
subgraph_type_per_edge
,
edge_offsets
);
}
}
template
<
typename
NumPickFn
,
typename
PickFn
>
c10
::
intrusive_ptr
<
FusedSampledSubgraph
>
FusedCSCSamplingGraph
::
TemporalSampleNeighborsImpl
(
const
torch
::
Tensor
&
nodes
,
bool
return_eids
,
NumPickFn
num_pick_fn
,
PickFn
pick_fn
)
const
{
const
int64_t
num_nodes
=
nodes
.
size
(
0
);
const
auto
indptr_options
=
indptr_
.
options
();
torch
::
Tensor
num_picked_neighbors_per_node
=
torch
::
empty
({
num_nodes
+
1
},
indptr_options
);
// Calculate GrainSize for parallel_for.
// Set the default grain size to 64.
const
int64_t
grain_size
=
64
;
torch
::
Tensor
picked_eids
;
torch
::
Tensor
subgraph_indptr
;
torch
::
Tensor
subgraph_indices
;
torch
::
optional
<
torch
::
Tensor
>
subgraph_type_per_edge
=
torch
::
nullopt
;
AT_DISPATCH_INDEX_TYPES
(
indptr_
.
scalar_type
(),
"SampleNeighborsImplWrappedWithIndptr"
,
([
&
]
{
using
indptr_t
=
index_t
;
AT_DISPATCH_INDEX_TYPES
(
nodes
.
scalar_type
(),
"SampleNeighborsImplWrappedWithNodes"
,
([
&
]
{
using
nodes_t
=
index_t
;
const
auto
indptr_data
=
indptr_
.
data_ptr
<
indptr_t
>
();
auto
num_picked_neighbors_data_ptr
=
num_picked_neighbors_per_node
.
data_ptr
<
indptr_t
>
();
num_picked_neighbors_data_ptr
[
0
]
=
0
;
const
auto
nodes_data_ptr
=
nodes
.
data_ptr
<
nodes_t
>
();
// Step 1. Calculate pick number of each node.
torch
::
parallel_for
(
0
,
num_nodes
,
grain_size
,
[
&
](
int64_t
begin
,
int64_t
end
)
{
for
(
int64_t
i
=
begin
;
i
<
end
;
++
i
)
{
const
auto
nid
=
nodes_data_ptr
[
i
];
TORCH_CHECK
(
nid
>=
0
&&
nid
<
NumNodes
(),
"The seed nodes' IDs should fall within the range of "
"the "
"graph's node IDs."
);
const
auto
offset
=
indptr_data
[
nid
];
const
auto
num_neighbors
=
indptr_data
[
nid
+
1
]
-
offset
;
num_picked_neighbors_data_ptr
[
i
+
1
]
=
num_neighbors
==
0
?
0
:
num_pick_fn
(
i
,
offset
,
num_neighbors
);
}
});
// Step 2. Calculate prefix sum to get total length and offsets of
// each node. It's also the indptr of the generated subgraph.
subgraph_indptr
=
num_picked_neighbors_per_node
.
cumsum
(
0
,
indptr_
.
scalar_type
());
// Step 3. Allocate the tensor for picked neighbors.
const
auto
total_length
=
subgraph_indptr
.
data_ptr
<
indptr_t
>
()[
num_nodes
];
picked_eids
=
torch
::
empty
({
total_length
},
indptr_options
);
subgraph_indices
=
torch
::
empty
({
total_length
},
indices_
.
options
());
if
(
type_per_edge_
.
has_value
())
{
subgraph_type_per_edge
=
torch
::
empty
(
{
total_length
},
type_per_edge_
.
value
().
options
());
}
// Step 4. Pick neighbors for each node.
auto
picked_eids_data_ptr
=
picked_eids
.
data_ptr
<
indptr_t
>
();
auto
subgraph_indptr_data_ptr
=
subgraph_indptr
.
data_ptr
<
indptr_t
>
();
torch
::
parallel_for
(
0
,
num_nodes
,
grain_size
,
[
&
](
int64_t
begin
,
int64_t
end
)
{
for
(
int64_t
i
=
begin
;
i
<
end
;
++
i
)
{
const
auto
nid
=
nodes_data_ptr
[
i
];
const
auto
offset
=
indptr_data
[
nid
];
const
auto
num_neighbors
=
indptr_data
[
nid
+
1
]
-
offset
;
const
auto
picked_number
=
num_picked_neighbors_data_ptr
[
i
+
1
];
const
auto
picked_offset
=
subgraph_indptr_data_ptr
[
i
];
if
(
picked_number
>
0
)
{
auto
actual_picked_count
=
pick_fn
(
i
,
offset
,
num_neighbors
,
picked_eids_data_ptr
+
picked_offset
);
TORCH_CHECK
(
actual_picked_count
==
picked_number
,
"Actual picked count doesn't match the calculated "
"pick "
"number."
);
// Step 5. Calculate other attributes and return the
// subgraph.
AT_DISPATCH_INDEX_TYPES
(
subgraph_indices
.
scalar_type
(),
"IndexSelectSubgraphIndices"
,
([
&
]
{
auto
subgraph_indices_data_ptr
=
subgraph_indices
.
data_ptr
<
index_t
>
();
auto
indices_data_ptr
=
indices_
.
data_ptr
<
index_t
>
();
for
(
auto
i
=
picked_offset
;
i
<
picked_offset
+
picked_number
;
++
i
)
{
subgraph_indices_data_ptr
[
i
]
=
indices_data_ptr
[
picked_eids_data_ptr
[
i
]];
}
}));
if
(
type_per_edge_
.
has_value
())
{
AT_DISPATCH_INTEGRAL_TYPES
(
subgraph_type_per_edge
.
value
().
scalar_type
(),
"IndexSelectTypePerEdge"
,
([
&
]
{
auto
subgraph_type_per_edge_data_ptr
=
subgraph_type_per_edge
.
value
()
.
data_ptr
<
scalar_t
>
();
auto
type_per_edge_data_ptr
=
type_per_edge_
.
value
().
data_ptr
<
scalar_t
>
();
for
(
auto
i
=
picked_offset
;
i
<
picked_offset
+
picked_number
;
++
i
)
{
subgraph_type_per_edge_data_ptr
[
i
]
=
type_per_edge_data_ptr
[
picked_eids_data_ptr
[
i
]];
}
}));
}
}
}
});
}));
}));
torch
::
optional
<
torch
::
Tensor
>
subgraph_reverse_edge_ids
=
torch
::
nullopt
;
if
(
return_eids
)
subgraph_reverse_edge_ids
=
std
::
move
(
picked_eids
);
return
c10
::
make_intrusive
<
FusedSampledSubgraph
>
(
subgraph_indptr
,
subgraph_indices
,
nodes
,
torch
::
nullopt
,
subgraph_reverse_edge_ids
,
subgraph_type_per_edge
);
}
c10
::
intrusive_ptr
<
FusedSampledSubgraph
>
FusedCSCSamplingGraph
::
SampleNeighbors
(
c10
::
intrusive_ptr
<
FusedSampledSubgraph
>
FusedCSCSamplingGraph
::
SampleNeighbors
(
torch
::
optional
<
torch
::
Tensor
>
seeds
,
torch
::
optional
<
torch
::
Tensor
>
seeds
,
torch
::
optional
<
std
::
vector
<
int64_t
>>
seed_offsets
,
torch
::
optional
<
std
::
vector
<
int64_t
>>
seed_offsets
,
...
@@ -969,7 +856,7 @@ c10::intrusive_ptr<FusedSampledSubgraph> FusedCSCSamplingGraph::SampleNeighbors(
...
@@ -969,7 +856,7 @@ c10::intrusive_ptr<FusedSampledSubgraph> FusedCSCSamplingGraph::SampleNeighbors(
indices_
,
indices_
,
{
random_seed
.
value
(),
static_cast
<
float
>
(
seed2_contribution
)},
{
random_seed
.
value
(),
static_cast
<
float
>
(
seed2_contribution
)},
NumNodes
()};
NumNodes
()};
return
SampleNeighborsImpl
(
return
SampleNeighborsImpl
<
TemporalOption
::
NOT_TEMPORAL
>
(
seeds
.
value
(),
seed_offsets
,
fanouts
,
return_eids
,
seeds
.
value
(),
seed_offsets
,
fanouts
,
return_eids
,
GetNumPickFn
(
GetNumPickFn
(
fanouts
,
replace
,
type_per_edge_
,
probs_or_mask
,
fanouts
,
replace
,
type_per_edge_
,
probs_or_mask
,
...
@@ -990,7 +877,7 @@ c10::intrusive_ptr<FusedSampledSubgraph> FusedCSCSamplingGraph::SampleNeighbors(
...
@@ -990,7 +877,7 @@ c10::intrusive_ptr<FusedSampledSubgraph> FusedCSCSamplingGraph::SampleNeighbors(
NumNodes
()};
NumNodes
()};
}
}
}();
}();
return
SampleNeighborsImpl
(
return
SampleNeighborsImpl
<
TemporalOption
::
NOT_TEMPORAL
>
(
seeds
.
value
(),
seed_offsets
,
fanouts
,
return_eids
,
seeds
.
value
(),
seed_offsets
,
fanouts
,
return_eids
,
GetNumPickFn
(
GetNumPickFn
(
fanouts
,
replace
,
type_per_edge_
,
probs_or_mask
,
fanouts
,
replace
,
type_per_edge_
,
probs_or_mask
,
...
@@ -1001,7 +888,7 @@ c10::intrusive_ptr<FusedSampledSubgraph> FusedCSCSamplingGraph::SampleNeighbors(
...
@@ -1001,7 +888,7 @@ c10::intrusive_ptr<FusedSampledSubgraph> FusedCSCSamplingGraph::SampleNeighbors(
}
}
}
else
{
}
else
{
SamplerArgs
<
SamplerType
::
NEIGHBOR
>
args
;
SamplerArgs
<
SamplerType
::
NEIGHBOR
>
args
;
return
SampleNeighborsImpl
(
return
SampleNeighborsImpl
<
TemporalOption
::
NOT_TEMPORAL
>
(
seeds
.
value
(),
seed_offsets
,
fanouts
,
return_eids
,
seeds
.
value
(),
seed_offsets
,
fanouts
,
return_eids
,
GetNumPickFn
(
GetNumPickFn
(
fanouts
,
replace
,
type_per_edge_
,
probs_or_mask
,
with_seed_offsets
),
fanouts
,
replace
,
type_per_edge_
,
probs_or_mask
,
with_seed_offsets
),
...
@@ -1019,6 +906,7 @@ FusedCSCSamplingGraph::TemporalSampleNeighbors(
...
@@ -1019,6 +906,7 @@ FusedCSCSamplingGraph::TemporalSampleNeighbors(
bool
return_eids
,
torch
::
optional
<
std
::
string
>
probs_name
,
bool
return_eids
,
torch
::
optional
<
std
::
string
>
probs_name
,
torch
::
optional
<
std
::
string
>
node_timestamp_attr_name
,
torch
::
optional
<
std
::
string
>
node_timestamp_attr_name
,
torch
::
optional
<
std
::
string
>
edge_timestamp_attr_name
)
const
{
torch
::
optional
<
std
::
string
>
edge_timestamp_attr_name
)
const
{
torch
::
optional
<
std
::
vector
<
int64_t
>>
seed_offsets
=
torch
::
nullopt
;
// 1. Get probs_or_mask.
// 1. Get probs_or_mask.
auto
probs_or_mask
=
this
->
EdgeAttribute
(
probs_name
);
auto
probs_or_mask
=
this
->
EdgeAttribute
(
probs_name
);
if
(
probs_name
.
has_value
())
{
if
(
probs_name
.
has_value
())
{
...
@@ -1039,8 +927,8 @@ FusedCSCSamplingGraph::TemporalSampleNeighbors(
...
@@ -1039,8 +927,8 @@ FusedCSCSamplingGraph::TemporalSampleNeighbors(
const
int64_t
random_seed
=
RandomEngine
::
ThreadLocal
()
->
RandInt
(
const
int64_t
random_seed
=
RandomEngine
::
ThreadLocal
()
->
RandInt
(
static_cast
<
int64_t
>
(
0
),
std
::
numeric_limits
<
int64_t
>::
max
());
static_cast
<
int64_t
>
(
0
),
std
::
numeric_limits
<
int64_t
>::
max
());
SamplerArgs
<
SamplerType
::
LABOR
>
args
{
indices_
,
random_seed
,
NumNodes
()};
SamplerArgs
<
SamplerType
::
LABOR
>
args
{
indices_
,
random_seed
,
NumNodes
()};
return
Temporal
SampleNeighborsImpl
(
return
SampleNeighborsImpl
<
TemporalOption
::
TEMPORAL
>
(
input_nodes
,
return_eids
,
input_nodes
,
seed_offsets
,
fanouts
,
return_eids
,
GetTemporalNumPickFn
(
GetTemporalNumPickFn
(
input_nodes_timestamp
,
this
->
indices_
,
fanouts
,
replace
,
input_nodes_timestamp
,
this
->
indices_
,
fanouts
,
replace
,
type_per_edge_
,
probs_or_mask
,
node_timestamp
,
edge_timestamp
),
type_per_edge_
,
probs_or_mask
,
node_timestamp
,
edge_timestamp
),
...
@@ -1050,8 +938,8 @@ FusedCSCSamplingGraph::TemporalSampleNeighbors(
...
@@ -1050,8 +938,8 @@ FusedCSCSamplingGraph::TemporalSampleNeighbors(
edge_timestamp
,
args
));
edge_timestamp
,
args
));
}
else
{
}
else
{
SamplerArgs
<
SamplerType
::
NEIGHBOR
>
args
;
SamplerArgs
<
SamplerType
::
NEIGHBOR
>
args
;
return
Temporal
SampleNeighborsImpl
(
return
SampleNeighborsImpl
<
TemporalOption
::
TEMPORAL
>
(
input_nodes
,
return_eids
,
input_nodes
,
seed_offsets
,
fanouts
,
return_eids
,
GetTemporalNumPickFn
(
GetTemporalNumPickFn
(
input_nodes_timestamp
,
this
->
indices_
,
fanouts
,
replace
,
input_nodes_timestamp
,
this
->
indices_
,
fanouts
,
replace
,
type_per_edge_
,
probs_or_mask
,
node_timestamp
,
edge_timestamp
),
type_per_edge_
,
probs_or_mask
,
node_timestamp
,
edge_timestamp
),
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment