OpenDAS / dgl · Commit b1153db9 (unverified)
Authored Aug 23, 2023 by Ramon Zhou; committed via GitHub on Aug 23, 2023
[Graphbolt] Rewrite sampling process to eliminate torch::cat (#6152)
parent 64df37f7

Showing 2 changed files with 132 additions and 84 deletions:

  graphbolt/include/graphbolt/csc_sampling_graph.h  (+4 / -4)
  graphbolt/src/csc_sampling_graph.cc               (+128 / -80)
graphbolt/include/graphbolt/csc_sampling_graph.h

@@ -357,14 +357,14 @@ int64_t NumPickByEtype(
  * should be put. Enough memory space should be allocated in advance.
  */
 template <typename PickedType>
-void Pick(
+int64_t Pick(
     int64_t offset, int64_t num_neighbors, int64_t fanout, bool replace,
     const torch::TensorOptions& options,
     const torch::optional<torch::Tensor>& probs_or_mask,
     SamplerArgs<SamplerType::NEIGHBOR> args, PickedType* picked_data_ptr);
 
 template <typename PickedType>
-void Pick(
+int64_t Pick(
     int64_t offset, int64_t num_neighbors, int64_t fanout, bool replace,
     const torch::TensorOptions& options,
     const torch::optional<torch::Tensor>& probs_or_mask,

@@ -398,7 +398,7 @@ void Pick(
  * should be put. Enough memory space should be allocated in advance.
  */
 template <SamplerType S, typename PickedType>
-void PickByEtype(
+int64_t PickByEtype(
     int64_t offset, int64_t num_neighbors, const std::vector<int64_t>& fanouts,
     bool replace, const torch::TensorOptions& options,
     const torch::Tensor& type_per_edge,

@@ -408,7 +408,7 @@ void PickByEtype(
 template <
     bool NonUniform, bool Replace, typename ProbsType = float,
     typename PickedType>
-void LaborPick(
+int64_t LaborPick(
     int64_t offset, int64_t num_neighbors, int64_t fanout,
     const torch::TensorOptions& options,
     const torch::optional<torch::Tensor>& probs_or_mask,
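The only header change is the return type of the pick routines: each now reports how many neighbors it actually wrote into the caller-provided buffer, so the caller can check that against its precomputed per-node count instead of measuring concatenated tensors. A minimal standalone sketch of that contract follows; the pick_into and expected names are illustrative stand-ins, not graphbolt APIs.

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <numeric>
#include <vector>

// Illustrative stand-in for the new pick contract: write up to `fanout`
// neighbor ids of the range [offset, offset + num_neighbors) into `out`
// and return how many were actually written.
int64_t pick_into(int64_t offset, int64_t num_neighbors, int64_t fanout,
                  int64_t* out) {
  const int64_t n = std::min(fanout, num_neighbors);
  std::iota(out, out + n, offset);  // trivially "pick" the first n neighbors
  return n;
}

int main() {
  std::vector<int64_t> buffer(4);
  // The caller precomputes the expected count (as NumPick does) and checks
  // it against the value the pick routine returns.
  const int64_t expected = std::min<int64_t>(4, 10);
  const int64_t actual =
      pick_into(/*offset=*/100, /*num_neighbors=*/10, /*fanout=*/4,
                buffer.data());
  assert(actual == expected);
}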
graphbolt/src/csc_sampling_graph.cc

@@ -222,28 +222,30 @@ c10::intrusive_ptr<SampledSubgraph> CSCSamplingGraph::SampleNeighborsImpl(
     const torch::Tensor& nodes, bool return_eids, NumPickFn num_pick_fn,
     PickFn pick_fn) const {
   const int64_t num_nodes = nodes.size(0);
-  const int64_t num_threads = torch::get_num_threads();
-  std::vector<torch::Tensor> picked_neighbors_per_thread(num_threads);
+  const auto indptr_options = indptr_.options();
   torch::Tensor num_picked_neighbors_per_node =
-      torch::zeros({num_nodes + 1}, indptr_.options());
+      torch::empty({num_nodes + 1}, indptr_options);
   // Calculate GrainSize for parallel_for.
   // Set the default grain size to 64.
   const int64_t grain_size = 64;
+  torch::Tensor picked_eids;
+  torch::Tensor subgraph_indptr;
+  torch::Tensor subgraph_indices;
+  torch::optional<torch::Tensor> subgraph_type_per_edge = torch::nullopt;
   AT_DISPATCH_INTEGRAL_TYPES(
-      indptr_.scalar_type(), "parallel_for", ([&] {
+      indptr_.scalar_type(), "SampleNeighborsImpl", ([&] {
+        const scalar_t* indptr_data = indptr_.data_ptr<scalar_t>();
+        auto num_picked_neighbors_data_ptr =
+            num_picked_neighbors_per_node.data_ptr<scalar_t>();
+        num_picked_neighbors_data_ptr[0] = 0;
+        const auto nodes_data_ptr = nodes.data_ptr<int64_t>();
+        // Step 1. Calculate pick number of each node.
         torch::parallel_for(
-            0, num_nodes, grain_size, [&](scalar_t begin, scalar_t end) {
-              const auto indptr_options = indptr_.options();
-              const scalar_t* indptr_data = indptr_.data_ptr<scalar_t>();
-              // Get current thread id.
-              auto thread_id = torch::get_thread_num();
-              int64_t local_grain_size = end - begin;
-              std::vector<torch::Tensor> picked_neighbors_cur_thread(
-                  local_grain_size);
-              const auto nodes_data_ptr = nodes.data_ptr<int64_t>();
-              for (scalar_t i = begin; i < end; ++i) {
+            0, num_nodes, grain_size, [&](int64_t begin, int64_t end) {
+              for (int64_t i = begin; i < end; ++i) {
                 const auto nid = nodes_data_ptr[i];
                 TORCH_CHECK(
                     nid >= 0 && nid < NumNodes(),
@@ -252,49 +254,82 @@ c10::intrusive_ptr<SampledSubgraph> CSCSamplingGraph::SampleNeighborsImpl(
                 const auto offset = indptr_data[nid];
                 const auto num_neighbors = indptr_data[nid + 1] - offset;
-                if (num_neighbors == 0) {
-                  // To avoid crashing during concatenation in the master
-                  // thread, initializing with empty tensors.
-                  picked_neighbors_cur_thread[i - begin] =
-                      torch::tensor({}, indptr_options);
-                  continue;
-                }
-                // Pre-allocate tensors for each node. Because the pick
-                // functions are modified, this part of code needed refactoring
-                // to adapt to the change of APIs. It's temporary since the
-                // whole process will be rewritten soon.
-                int64_t allocate_size = num_pick_fn(offset, num_neighbors);
-                picked_neighbors_cur_thread[i - begin] =
-                    torch::empty({allocate_size}, indptr_options);
-                torch::Tensor& picked_tensor =
-                    picked_neighbors_cur_thread[i - begin];
-                AT_DISPATCH_INTEGRAL_TYPES(
-                    picked_tensor.scalar_type(), "CallPick", ([&] {
-                      pick_fn(
-                          offset, num_neighbors,
-                          picked_tensor.data_ptr<scalar_t>());
-                    }));
-                num_picked_neighbors_per_node[i + 1] = allocate_size;
+                num_picked_neighbors_data_ptr[i + 1] =
+                    num_neighbors == 0 ? 0
+                                       : num_pick_fn(offset, num_neighbors);
               }
-              picked_neighbors_per_thread[thread_id] =
-                  torch::cat(picked_neighbors_cur_thread);
-            });  // End of parallel_for.
+            });
+        // Step 2. Calculate prefix sum to get total length and offsets of each
+        // node. It's also the indptr of the generated subgraph.
+        subgraph_indptr = torch::cumsum(num_picked_neighbors_per_node, 0);
+        // Step 3. Allocate the tensor for picked neighbors.
+        const auto total_length =
+            subgraph_indptr.data_ptr<scalar_t>()[num_nodes];
+        picked_eids = torch::empty({total_length}, indptr_options);
+        subgraph_indices = torch::empty({total_length}, indices_.options());
+        if (type_per_edge_.has_value()) {
+          subgraph_type_per_edge = torch::empty(
+              {total_length}, type_per_edge_.value().options());
+        }
+        // Step 4. Pick neighbors for each node.
+        auto picked_eids_data_ptr = picked_eids.data_ptr<scalar_t>();
+        auto subgraph_indptr_data_ptr = subgraph_indptr.data_ptr<scalar_t>();
+        torch::parallel_for(
+            0, num_nodes, grain_size, [&](int64_t begin, int64_t end) {
+              for (int64_t i = begin; i < end; ++i) {
+                const auto nid = nodes_data_ptr[i];
+                const auto offset = indptr_data[nid];
+                const auto num_neighbors = indptr_data[nid + 1] - offset;
+                const auto picked_number =
+                    num_picked_neighbors_data_ptr[i + 1];
+                const auto picked_offset = subgraph_indptr_data_ptr[i];
+                if (picked_number > 0) {
+                  auto actual_picked_count = pick_fn(
+                      offset, num_neighbors,
+                      picked_eids_data_ptr + picked_offset);
+                  TORCH_CHECK(
+                      actual_picked_count == picked_number,
+                      "Actual picked count doesn't match the calculated pick "
+                      "number.");
+                  // Step 5. Calculate other attributes and return the subgraph.
+                  AT_DISPATCH_INTEGRAL_TYPES(
+                      subgraph_indices.scalar_type(),
+                      "IndexSelectSubgraphIndices", ([&] {
+                        auto subgraph_indices_data_ptr =
+                            subgraph_indices.data_ptr<scalar_t>();
+                        auto indices_data_ptr = indices_.data_ptr<scalar_t>();
+                        for (auto i = picked_offset;
+                             i < picked_offset + picked_number; ++i) {
+                          subgraph_indices_data_ptr[i] =
+                              indices_data_ptr[picked_eids_data_ptr[i]];
+                        }
+                      }));
+                  if (type_per_edge_.has_value()) {
+                    AT_DISPATCH_INTEGRAL_TYPES(
+                        subgraph_type_per_edge.value().scalar_type(),
+                        "IndexSelectTypePerEdge", ([&] {
+                          auto subgraph_type_per_edge_data_ptr =
+                              subgraph_type_per_edge.value()
+                                  .data_ptr<scalar_t>();
+                          auto type_per_edge_data_ptr =
+                              type_per_edge_.value().data_ptr<scalar_t>();
+                          for (auto i = picked_offset;
+                               i < picked_offset + picked_number; ++i) {
+                            subgraph_type_per_edge_data_ptr[i] =
+                                type_per_edge_data_ptr
+                                    [picked_eids_data_ptr[i]];
+                          }
+                        }));
+                  }
+                }
+              }
+            });
       }));
-  torch::Tensor subgraph_indptr =
-      torch::cumsum(num_picked_neighbors_per_node, 0);
-  torch::Tensor picked_eids = torch::cat(picked_neighbors_per_thread);
-  torch::Tensor subgraph_indices =
-      torch::index_select(indices_, 0, picked_eids);
-  torch::optional<torch::Tensor> subgraph_type_per_edge = torch::nullopt;
-  if (type_per_edge_.has_value()) {
-    subgraph_type_per_edge =
-        torch::index_select(type_per_edge_.value(), 0, picked_eids);
-  }
   torch::optional<torch::Tensor> subgraph_reverse_edge_ids = torch::nullopt;
   if (return_eids) subgraph_reverse_edge_ids = std::move(picked_eids);
   return c10::make_intrusive<SampledSubgraph>(
       subgraph_indptr, subgraph_indices, nodes, torch::nullopt,
       subgraph_reverse_edge_ids, subgraph_type_per_edge);
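The heart of the change is visible above: instead of each thread materializing per-node tensors and concatenating them twice with torch::cat (once per thread, once across threads), the rewrite counts first, prefix-sums the counts into the subgraph indptr, preallocates the flat output once, and lets every node write into its own disjoint slice. A minimal standalone sketch of that count / prefix-sum / preallocate / fill pattern, with plain vectors standing in for tensors:

#include <cstdint>
#include <numeric>
#include <vector>

// Sketch of the pattern that replaces torch::cat. `degree` plays the role
// of num_pick_fn's per-node result.
std::vector<int64_t> gather_without_cat(const std::vector<int64_t>& degree) {
  const int64_t num_nodes = static_cast<int64_t>(degree.size());
  // Step 1: per-node pick counts (parallelizable; writes are disjoint).
  std::vector<int64_t> counts(num_nodes + 1, 0);
  for (int64_t i = 0; i < num_nodes; ++i) counts[i + 1] = degree[i];
  // Step 2: prefix sum -> per-node output offsets (the subgraph indptr).
  std::vector<int64_t> indptr(num_nodes + 1);
  std::partial_sum(counts.begin(), counts.end(), indptr.begin());
  // Step 3: allocate the flat output exactly once.
  std::vector<int64_t> picked(indptr[num_nodes]);
  // Step 4: each node fills its own slice [indptr[i], indptr[i + 1]);
  // no concatenation is needed, and this loop is also parallelizable.
  for (int64_t i = 0; i < num_nodes; ++i)
    for (int64_t j = indptr[i]; j < indptr[i + 1]; ++j)
      picked[j] = j - indptr[i];  // placeholder payload
  return picked;
}

int main() { return gather_without_cat({3, 0, 2}).size() == 5 ? 0 : 1; }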
@@ -383,12 +418,17 @@ int64_t NumPick(
     int64_t fanout, bool replace,
     const torch::optional<torch::Tensor>& probs_or_mask, int64_t offset,
     int64_t num_neighbors) {
-  int64_t num_valid_neighbors =
-      probs_or_mask.has_value()
-          ? *torch::count_nonzero(
-                 probs_or_mask.value().slice(
-                     0, offset, offset + num_neighbors))
-                 .data_ptr<int64_t>()
-          : num_neighbors;
+  int64_t num_valid_neighbors = num_neighbors;
+  if (probs_or_mask.has_value()) {
+    // Subtract the count of zeros in probs_or_mask.
+    AT_DISPATCH_ALL_TYPES(
+        probs_or_mask.value().scalar_type(), "CountZero", ([&] {
+          scalar_t* probs_data_ptr =
+              probs_or_mask.value().data_ptr<scalar_t>();
+          num_valid_neighbors -= std::count(
+              probs_data_ptr + offset,
+              probs_data_ptr + offset + num_neighbors, 0);
+        }));
+  }
   if (num_valid_neighbors == 0 || fanout == -1) return num_valid_neighbors;
   return replace ? fanout : std::min(fanout, num_valid_neighbors);
 }
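The rewritten NumPick also avoids a tensor allocation per call: rather than torch::count_nonzero on a slice (which materializes a result tensor), it subtracts the number of zeros found by std::count over the raw probability buffer. A standalone sketch of the same computation:

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// Count the neighbors of one node whose probability/mask entry is nonzero,
// the way the rewritten NumPick does: std::count over the raw buffer.
int64_t num_valid_neighbors(const std::vector<float>& probs, int64_t offset,
                            int64_t num_neighbors) {
  return num_neighbors - std::count(probs.begin() + offset,
                                    probs.begin() + offset + num_neighbors,
                                    0.0f);
}

int main() {
  // Mask for a node whose edges occupy [2, 7) in the flat edge array.
  std::vector<float> probs = {1, 1, 0.5f, 0, 0, 0.25f, 1, 1};
  assert(num_valid_neighbors(probs, /*offset=*/2, /*num_neighbors=*/5) == 3);
}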
@@ -444,17 +484,19 @@ int64_t NumPickByEtype(
  * should be put. Enough memory space should be allocated in advance.
  */
 template <typename PickedType>
-inline void UniformPick(
+inline int64_t UniformPick(
     int64_t offset, int64_t num_neighbors, int64_t fanout, bool replace,
     const torch::TensorOptions& options, PickedType* picked_data_ptr) {
   if ((fanout == -1) || (num_neighbors <= fanout && !replace)) {
     std::iota(picked_data_ptr, picked_data_ptr + num_neighbors, offset);
+    return num_neighbors;
   } else if (replace) {
     std::memcpy(
         picked_data_ptr,
         torch::randint(offset, offset + num_neighbors, {fanout}, options)
             .data_ptr<PickedType>(),
         fanout * sizeof(PickedType));
+    return fanout;
   } else {
     // We use different sampling strategies for different sampling case.
     if (fanout >= num_neighbors / 10) {

@@ -490,6 +532,7 @@ inline void UniformPick(
       }
       // Save the randomly sampled fanout elements to the output tensor.
       std::copy(seq.begin(), seq.begin() + fanout, picked_data_ptr);
+      return fanout;
     } else if (fanout < 64) {
       // [Algorithm]
       // Use linear search to verify uniqueness.

@@ -510,6 +553,7 @@ inline void UniformPick(
         auto it = std::find(picked_data_ptr, begin, *begin);
         if (it == begin) ++begin;
       }
+      return fanout;
     } else {
       // [Algorithm]
       // Use hash-set to verify uniqueness. In the best scenario, the

@@ -533,6 +577,7 @@ inline void UniformPick(
             offset, offset + num_neighbors));
       }
       std::copy(picked_set.begin(), picked_set.end(), picked_data_ptr);
+      return picked_set.size();
     }
   }
 }
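Each exit of UniformPick now reports its count. For context, the fanout < 64 branch whose body is elided above keeps drawing random neighbors and uses a linear scan over what has already been written to reject duplicates, which beats a hash set at small fanouts. A rough standalone sketch of that idea, not the elided graphbolt code itself:

#include <algorithm>
#include <cstdint>
#include <random>
#include <vector>

// Sample `fanout` distinct values from [offset, offset + num_neighbors)
// using linear search for uniqueness; cheap when fanout is small.
// Assumes fanout <= num_neighbors so the loop terminates.
int64_t sample_unique_linear(int64_t offset, int64_t num_neighbors,
                             int64_t fanout, int64_t* out) {
  std::mt19937_64 rng{std::random_device{}()};
  std::uniform_int_distribution<int64_t> dist(offset,
                                              offset + num_neighbors - 1);
  int64_t* end = out;
  while (end - out < fanout) {
    *end = dist(rng);
    // Keep the draw only if a linear scan finds no earlier copy.
    if (std::find(out, end, *end) == end) ++end;
  }
  return fanout;
}

int main() {
  std::vector<int64_t> picked(8);
  return sample_unique_linear(100, 50, 8, picked.data()) == 8 ? 0 : 1;
}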
@@ -573,7 +618,7 @@ inline void UniformPick(
  * should be put. Enough memory space should be allocated in advance.
  */
 template <typename PickedType>
-inline void NonUniformPick(
+inline int64_t NonUniformPick(
     int64_t offset, int64_t num_neighbors, int64_t fanout, bool replace,
     const torch::TensorOptions& options,
     const torch::optional<torch::Tensor>& probs_or_mask,

@@ -582,12 +627,13 @@ inline void NonUniformPick(
       probs_or_mask.value().slice(0, offset, offset + num_neighbors);
   auto positive_probs_indices = local_probs.nonzero().squeeze(1);
   auto num_positive_probs = positive_probs_indices.size(0);
-  if (num_positive_probs == 0) return;
+  if (num_positive_probs == 0) return 0;
   if ((fanout == -1) || (num_positive_probs <= fanout && !replace)) {
     std::memcpy(
         picked_data_ptr,
         (positive_probs_indices + offset).data_ptr<PickedType>(),
         num_positive_probs * sizeof(PickedType));
+    return num_positive_probs;
   } else {
     if (!replace) fanout = std::min(fanout, num_positive_probs);
     std::memcpy(

@@ -595,27 +641,28 @@ inline void NonUniformPick(
         (torch::multinomial(local_probs, fanout, replace) + offset)
             .data_ptr<PickedType>(),
         fanout * sizeof(PickedType));
+    return fanout;
   }
 }
 
 template <typename PickedType>
-void Pick(
+int64_t Pick(
     int64_t offset, int64_t num_neighbors, int64_t fanout, bool replace,
     const torch::TensorOptions& options,
     const torch::optional<torch::Tensor>& probs_or_mask,
     SamplerArgs<SamplerType::NEIGHBOR> args, PickedType* picked_data_ptr) {
   if (probs_or_mask.has_value()) {
-    NonUniformPick(
+    return NonUniformPick(
         offset, num_neighbors, fanout, replace, options, probs_or_mask,
         picked_data_ptr);
   } else {
-    UniformPick(
+    return UniformPick(
         offset, num_neighbors, fanout, replace, options, picked_data_ptr);
   }
 }
 
 template <SamplerType S, typename PickedType>
-void PickByEtype(
+int64_t PickByEtype(
     int64_t offset, int64_t num_neighbors, const std::vector<int64_t>& fanouts,
     bool replace, const torch::TensorOptions& options,
     const torch::Tensor& type_per_edge,
@@ -623,11 +670,11 @@ void PickByEtype(
     PickedType* picked_data_ptr) {
   int64_t etype_begin = offset;
   int64_t etype_end = offset;
+  int64_t pick_offset = 0;
   AT_DISPATCH_INTEGRAL_TYPES(
       type_per_edge.scalar_type(), "PickByEtype", ([&] {
         const scalar_t* type_per_edge_data =
             type_per_edge.data_ptr<scalar_t>();
         const auto end = offset + num_neighbors;
-        int64_t pick_offset = 0;
         while (etype_begin < end) {
           scalar_t etype = type_per_edge_data[etype_begin];
           TORCH_CHECK(

@@ -638,12 +685,9 @@ void PickByEtype(
               type_per_edge_data + etype_begin, type_per_edge_data + end,
               etype);
           etype_end = etype_end_it - type_per_edge_data;
-          int64_t picked_count = NumPick(
-              fanout, replace, probs_or_mask, etype_begin,
-              etype_end - etype_begin);
           // Do sampling for one etype.
           if (fanout != 0) {
-            Pick(
+            int64_t picked_count = Pick(
                 etype_begin, etype_end - etype_begin, fanout, replace,
                 options, probs_or_mask, args, picked_data_ptr + pick_offset);
             pick_offset += picked_count;
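PickByEtype relies on each node's edges being sorted by edge type, so every type occupies a contiguous run: the loop finds the end of the current run, samples into picked_data_ptr + pick_offset, and advances the offset by the count that Pick now returns. A standalone sketch of that run segmentation, with std::upper_bound standing in for the elided etype_end_it search and a hypothetical sample_run callback in place of the per-etype Pick call:

#include <algorithm>
#include <cstdint>
#include <vector>

// Walk the contiguous runs of a type-sorted edge array, the way PickByEtype
// segments one node's neighborhood. `sample_run` is a hypothetical stand-in
// for the per-etype Pick call; it returns how many entries it wrote.
template <typename SampleRun>
int64_t for_each_etype_run(const std::vector<int64_t>& type_per_edge,
                           int64_t offset, int64_t num_neighbors,
                           SampleRun sample_run) {
  int64_t etype_begin = offset;
  const int64_t end = offset + num_neighbors;
  int64_t pick_offset = 0;
  while (etype_begin < end) {
    const int64_t etype = type_per_edge[etype_begin];
    // The run ends at the first element greater than `etype`.
    const int64_t etype_end =
        std::upper_bound(type_per_edge.begin() + etype_begin,
                         type_per_edge.begin() + end, etype) -
        type_per_edge.begin();
    pick_offset += sample_run(etype, etype_begin, etype_end - etype_begin,
                              pick_offset);
    etype_begin = etype_end;
  }
  return pick_offset;  // total number of picked neighbors
}

int main() {
  std::vector<int64_t> types = {0, 0, 1, 1, 1, 2};  // sorted within the node
  // Count run lengths as a trivial "sampler".
  auto count_all = [](int64_t, int64_t, int64_t run_len, int64_t) {
    return run_len;
  };
  return for_each_etype_run(types, 0, 6, count_all) == 6 ? 0 : 1;
}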
@@ -651,43 +695,46 @@ void PickByEtype(
           etype_begin = etype_end;
         }
       }));
+  return pick_offset;
 }
 
 template <typename PickedType>
-void Pick(
+int64_t Pick(
     int64_t offset, int64_t num_neighbors, int64_t fanout, bool replace,
     const torch::TensorOptions& options,
     const torch::optional<torch::Tensor>& probs_or_mask,
     SamplerArgs<SamplerType::LABOR> args, PickedType* picked_data_ptr) {
-  if (fanout == 0) return;
+  if (fanout == 0) return 0;
   if (probs_or_mask.has_value()) {
     if (fanout < 0) {
-      NonUniformPick(
+      return NonUniformPick(
           offset, num_neighbors, fanout, replace, options, probs_or_mask,
           picked_data_ptr);
     } else {
+      int64_t picked_count;
       AT_DISPATCH_FLOATING_TYPES(
           probs_or_mask.value().scalar_type(), "LaborPickFloatType", ([&] {
             if (replace) {
-              LaborPick<true, true, scalar_t>(
+              picked_count = LaborPick<true, true, scalar_t>(
                   offset, num_neighbors, fanout, options, probs_or_mask,
                   args, picked_data_ptr);
             } else {
-              LaborPick<true, false, scalar_t>(
+              picked_count = LaborPick<true, false, scalar_t>(
                   offset, num_neighbors, fanout, options, probs_or_mask,
                   args, picked_data_ptr);
             }
           }));
+      return picked_count;
     }
   } else if (fanout < 0) {
-    UniformPick(
+    return UniformPick(
        offset, num_neighbors, fanout, replace, options, picked_data_ptr);
   } else if (replace) {
-    LaborPick<false, true>(
+    return LaborPick<false, true>(
        offset, num_neighbors, fanout, options,
        /* probs_or_mask= */ torch::nullopt, args, picked_data_ptr);
   } else {  // replace = false
-    LaborPick<false, false>(
+    return LaborPick<false, false>(
        offset, num_neighbors, fanout, options,
        /* probs_or_mask= */ torch::nullopt, args, picked_data_ptr);
   }
@@ -724,7 +771,7 @@ inline void safe_divide(T& a, U b) {
  */
 template <
     bool NonUniform, bool Replace, typename ProbsType,
     typename PickedType>
-inline void LaborPick(
+inline int64_t LaborPick(
     int64_t offset, int64_t num_neighbors, int64_t fanout,
     const torch::TensorOptions& options,
     const torch::optional<torch::Tensor>& probs_or_mask,

@@ -732,7 +779,7 @@ inline void LaborPick(
   fanout = Replace ? fanout : std::min(fanout, num_neighbors);
   if (!NonUniform && !Replace && fanout >= num_neighbors) {
     std::iota(picked_data_ptr, picked_data_ptr + num_neighbors, offset);
-    return;
+    return num_neighbors;
   }
   torch::Tensor heap_tensor = torch::empty({fanout * 2}, torch::kInt32);
   // Assuming max_degree of a vertex is <= 4 billion.

@@ -862,6 +909,7 @@ inline void LaborPick(
   TORCH_CHECK(
       !Replace || num_sampled == fanout || num_sampled == 0,
       "Sampling with replacement should sample exactly fanout neighbors or "
      "0!");
+  return num_sampled;
 }
 }  // namespace sampling