OpenDAS / dgl / Commits

Commit 58d98e03 (unverified), authored Aug 14, 2023 by Ramon Zhou, committed by GitHub on Aug 14, 2023.

[Graphbolt] Utilize pre-allocation in sampling (#6132)
parent f0d8ca1e
Showing 2 changed files with 214 additions and 205 deletions.
graphbolt/include/graphbolt/csc_sampling_graph.h  (+19, -22)
graphbolt/src/csc_sampling_graph.cc  (+195, -183)

graphbolt/include/graphbolt/csc_sampling_graph.h

@@ -353,28 +353,22 @@ int64_t NumPickByEtype(
  * probabilities associated with each neighboring edge of a node in the original
  * graph. It must be a 1D floating-point tensor with the number of elements
  * equal to the number of edges in the graph.
  *
- * @return A tensor containing the picked neighbors.
+ * @param picked_data_ptr The destination address where the picked neighbors
+ * should be put. Enough memory space should be allocated in advance.
  */
-template <SamplerType S>
-torch::Tensor Pick(
-    int64_t offset, int64_t num_neighbors, int64_t fanout, bool replace,
-    const torch::TensorOptions& options,
-    const torch::optional<torch::Tensor>& probs_or_mask, SamplerArgs<S> args);
-
-template <>
-torch::Tensor Pick<SamplerType::NEIGHBOR>(
+template <typename PickedType>
+void Pick(
     int64_t offset, int64_t num_neighbors, int64_t fanout, bool replace,
     const torch::TensorOptions& options,
     const torch::optional<torch::Tensor>& probs_or_mask,
-    SamplerArgs<SamplerType::NEIGHBOR> args);
+    SamplerArgs<SamplerType::NEIGHBOR> args, PickedType* picked_data_ptr);
 
-template <>
-torch::Tensor Pick<SamplerType::LABOR>(
+template <typename PickedType>
+void Pick(
     int64_t offset, int64_t num_neighbors, int64_t fanout, bool replace,
     const torch::TensorOptions& options,
     const torch::optional<torch::Tensor>& probs_or_mask,
-    SamplerArgs<SamplerType::LABOR> args);
+    SamplerArgs<SamplerType::LABOR> args, PickedType* picked_data_ptr);
 /**
  * @brief Picks a specified number of neighbors for a node per edge type,
@@ -400,22 +394,25 @@ torch::Tensor Pick<SamplerType::LABOR>(
  * probabilities associated with each neighboring edge of a node in the original
  * graph. It must be a 1D floating-point tensor with the number of elements
  * equal to the number of edges in the graph.
  *
- * @return A tensor containing the picked neighbors.
+ * @param picked_data_ptr The destination address where the picked neighbors
+ * should be put. Enough memory space should be allocated in advance.
  */
-template <SamplerType S>
-torch::Tensor PickByEtype(
+template <SamplerType S, typename PickedType>
+void PickByEtype(
     int64_t offset, int64_t num_neighbors, const std::vector<int64_t>& fanouts,
     bool replace, const torch::TensorOptions& options,
     const torch::Tensor& type_per_edge,
-    const torch::optional<torch::Tensor>& probs_or_mask, SamplerArgs<S> args);
+    const torch::optional<torch::Tensor>& probs_or_mask, SamplerArgs<S> args,
+    PickedType* picked_data_ptr);
 
-template <bool NonUniform, bool Replace, typename T = float>
-torch::Tensor LaborPick(
+template <
+    bool NonUniform, bool Replace, typename ProbsType = float,
+    typename PickedType>
+void LaborPick(
     int64_t offset, int64_t num_neighbors, int64_t fanout,
     const torch::TensorOptions& options,
     const torch::optional<torch::Tensor>& probs_or_mask,
-    SamplerArgs<SamplerType::LABOR> args);
+    SamplerArgs<SamplerType::LABOR> args, PickedType* picked_data_ptr);
 }  // namespace sampling
 }  // namespace graphbolt
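
The key API change in this header: Pick no longer returns a freshly allocated tensor per node; the caller sizes the output and passes a destination pointer that the pick routines fill in place. Below is a minimal sketch of the resulting calling convention. SampleOneNode is a hypothetical caller, not part of the commit, and the NumPick argument order is inferred from its use in the .cc diff further down.

#include <torch/torch.h>

#include <graphbolt/csc_sampling_graph.h>

using namespace graphbolt::sampling;

// Hypothetical driver: count first, allocate once, then let the
// void-returning Pick overload write into the buffer in place.
torch::Tensor SampleOneNode(
    int64_t offset, int64_t num_neighbors, int64_t fanout, bool replace,
    const torch::TensorOptions& options,
    const torch::optional<torch::Tensor>& probs_or_mask,
    SamplerArgs<SamplerType::NEIGHBOR> args) {
  // Pass 1: how many neighbors will be picked for this node.
  const int64_t n =
      NumPick(fanout, replace, probs_or_mask, offset, num_neighbors);
  // Allocate exactly n slots up front instead of receiving a new tensor.
  torch::Tensor picked = torch::empty({n}, options);
  AT_DISPATCH_INTEGRAL_TYPES(picked.scalar_type(), "SampleOneNode", ([&] {
    Pick(
        offset, num_neighbors, fanout, replace, options, probs_or_mask, args,
        picked.data_ptr<scalar_t>());
  }));
  return picked;
}

The point of the contract is that the pick routines themselves never allocate, so a caller is free to batch many nodes into one pre-allocated buffer.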

graphbolt/src/csc_sampling_graph.cc

@@ -10,6 +10,7 @@
 #include <cmath>
 #include <limits>
 #include <numeric>
+#include <tuple>
 #include <vector>
@@ -187,10 +188,11 @@ auto GetNumPickFn(
  * equal to the number of edges in the graph.
  * @param args Contains sampling algorithm specific arguments.
  *
- * @return A lambda function: (int64_t offset, int64_t num_neighbors) ->
- * torch::Tensor, which takes offset (the starting edge ID of the given node)
- * and num_neighbors (number of neighbors) as params and returns a tensor of
- * picked neighbors.
+ * @return A lambda function: (int64_t offset, int64_t num_neighbors,
+ * PickedType* picked_data_ptr) -> torch::Tensor, which takes offset (the
+ * starting edge ID of the given node) and num_neighbors (number of neighbors)
+ * as params and puts the picked neighbors at the address specified by
+ * picked_data_ptr.
  */
 template <SamplerType S>
 auto GetPickFn(
@@ -199,17 +201,18 @@ auto GetPickFn(
     const torch::optional<torch::Tensor>& type_per_edge,
     const torch::optional<torch::Tensor>& probs_or_mask, SamplerArgs<S> args) {
   return [&fanouts, replace, &options, &type_per_edge, &probs_or_mask, args](
-             int64_t offset, int64_t num_neighbors) {
-    // If fanouts.size() > 1, perform sampling for each edge type of each node;
-    // otherwise just sample once for each node with no regard of edge types.
+             int64_t offset, int64_t num_neighbors, auto picked_data_ptr) {
+    // If fanouts.size() > 1, perform sampling for each edge type of each
+    // node; otherwise just sample once for each node with no regard of edge
+    // types.
     if (fanouts.size() > 1) {
       return PickByEtype(
           offset, num_neighbors, fanouts, replace, options,
-          type_per_edge.value(), probs_or_mask, args);
+          type_per_edge.value(), probs_or_mask, args, picked_data_ptr);
     } else {
       return Pick(
           offset, num_neighbors, fanouts[0], replace, options, probs_or_mask,
-          args);
+          args, picked_data_ptr);
     }
   };
 }
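
The lambda's new third parameter is declared as `auto picked_data_ptr`, which makes the closure a generic lambda: its operator() is a template, so the same pick_fn object can later be invoked with whichever scalar_t* an AT_DISPATCH_INTEGRAL_TYPES block selects. A standalone sketch of that mechanism follows; the names are illustrative, not from the commit.

#include <cstdint>

// A generic lambda compiles to a closure type with a templated operator(),
// so one object can be called with differently typed output pointers.
auto make_fill_fn(int64_t start) {
  return [start](int64_t count, auto* out_ptr) {
    for (int64_t i = 0; i < count; ++i) out_ptr[i] = start + i;
  };
}

void demo() {
  auto fill_fn = make_fill_fn(100);
  int32_t out32[4];
  int64_t out64[4];
  fill_fn(4, out32);  // instantiates operator() with int32_t*
  fill_fn(4, out64);  // instantiates operator() with int64_t*
}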
@@ -257,17 +260,23 @@ c10::intrusive_ptr<SampledSubgraph> CSCSamplingGraph::SampleNeighborsImpl(
           continue;
         }
-        picked_neighbors_cur_thread[i - begin] = pick_fn(offset, num_neighbors);
-        // This number should be the same as the result of num_pick_fn.
-        num_picked_neighbors_per_node[i + 1] =
-            picked_neighbors_cur_thread[i - begin].size(0);
-        TORCH_CHECK(
-            *num_picked_neighbors_per_node[i + 1].data_ptr<int64_t>() ==
-                num_pick_fn(offset, num_neighbors),
-            "Return value of num_pick_fn doesn't match the actual "
-            "picked number.");
+        // Pre-allocate tensors for each node. Because the pick
+        // functions are modified, this part of code needed refactoring
+        // to adapt to the change of APIs. It's temporary since the
+        // whole process will be rewritten soon.
+        int64_t allocate_size = num_pick_fn(offset, num_neighbors);
+        picked_neighbors_cur_thread[i - begin] =
+            torch::empty({allocate_size}, indptr_options);
+        torch::Tensor& picked_tensor = picked_neighbors_cur_thread[i - begin];
+        AT_DISPATCH_INTEGRAL_TYPES(
+            picked_tensor.scalar_type(), "CallPick", ([&] {
+              pick_fn(
+                  offset, num_neighbors, picked_tensor.data_ptr<scalar_t>());
+            }));
+        num_picked_neighbors_per_node[i + 1] = allocate_size;
       }
       picked_neighbors_per_thread[thread_id] =
           torch::cat(picked_neighbors_cur_thread);
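
The loop above is the heart of the change: num_pick_fn sizes each node's output before pick_fn writes into it. The same count-then-fill idea, stripped of the torch machinery, looks roughly like the sketch below; CountThenFill and its parameters are illustrative, not the library's API.

#include <cstdint>
#include <numeric>
#include <vector>

template <typename CountFn, typename FillFn>
std::vector<int64_t> CountThenFill(
    const std::vector<int64_t>& offsets,  // per-node starting edge IDs
    const std::vector<int64_t>& degrees,  // per-node neighbor counts
    CountFn num_pick_fn, FillFn pick_fn) {
  const size_t n = offsets.size();
  std::vector<int64_t> sizes(n);
  for (size_t i = 0; i < n; ++i)
    sizes[i] = num_pick_fn(offsets[i], degrees[i]);
  // Prefix sums give every node a disjoint sub-range of one shared buffer.
  std::vector<int64_t> out_offsets(n + 1, 0);
  std::partial_sum(sizes.begin(), sizes.end(), out_offsets.begin() + 1);
  std::vector<int64_t> picked(out_offsets[n]);
  for (size_t i = 0; i < n; ++i)
    pick_fn(offsets[i], degrees[i], picked.data() + out_offsets[i]);
  return picked;
}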
@@ -431,106 +440,101 @@ int64_t NumPickByEtype(
  * without replacement. If True, a value can be selected multiple times.
  * Otherwise, each value can be selected only once.
  * @param options Tensor options specifying the desired data type of the result.
  *
- * @return A tensor containing the picked neighbors.
+ * @param picked_data_ptr The destination address where the picked neighbors
+ * should be put. Enough memory space should be allocated in advance.
  */
-inline torch::Tensor UniformPick(
+template <typename PickedType>
+inline void UniformPick(
     int64_t offset, int64_t num_neighbors, int64_t fanout, bool replace,
-    const torch::TensorOptions& options) {
-  torch::Tensor picked_neighbors;
+    const torch::TensorOptions& options, PickedType* picked_data_ptr) {
   if ((fanout == -1) || (num_neighbors <= fanout && !replace)) {
-    picked_neighbors = torch::arange(offset, offset + num_neighbors, options);
+    std::iota(picked_data_ptr, picked_data_ptr + num_neighbors, offset);
   } else if (replace) {
-    picked_neighbors =
-        torch::randint(offset, offset + num_neighbors, {fanout}, options);
+    std::memcpy(
+        picked_data_ptr,
+        torch::randint(offset, offset + num_neighbors, {fanout}, options)
+            .data_ptr<PickedType>(),
+        fanout * sizeof(PickedType));
   } else {
-    picked_neighbors = torch::empty({fanout}, options);
-    AT_DISPATCH_INTEGRAL_TYPES(
-        picked_neighbors.scalar_type(), "UniformPick", ([&] {
-          scalar_t* picked_neighbors_data =
-              picked_neighbors.data_ptr<scalar_t>();
-          // We use different sampling strategies for different sampling cases.
-          if (fanout >= num_neighbors / 10) {
-            // [Algorithm]
-            // This algorithm is conceptually related to the Fisher-Yates
-            // shuffle.
-            //
-            // [Complexity Analysis]
-            // This algorithm's memory complexity is O(num_neighbors), but
-            // it generates fewer random numbers (O(fanout)).
-            //
-            // (Compare) Reservoir algorithm is one of the most classical
-            // sampling algorithms. Both the reservoir algorithm and our
-            // algorithm offer distinct advantages, we need to compare to
-            // illustrate our trade-offs.
-            // The reservoir algorithm is memory-efficient (O(fanout)) but
-            // creates many random numbers (O(num_neighbors)), which is
-            // costly.
-            //
-            // [Practical Consideration]
-            // Use this algorithm when `fanout >= num_neighbors / 10` to
-            // reduce computation.
-            // In the scenario above, memory complexity is not a concern due
-            // to the small size of both `fanout` and `num_neighbors`. And it
-            // is efficient to allocate a small amount of memory. So the
-            // algorithm performance is great in this case.
-            std::vector<scalar_t> seq(num_neighbors);
-            // Assign the seq with [offset, offset + num_neighbors].
-            std::iota(seq.begin(), seq.end(), offset);
-            for (int64_t i = 0; i < fanout; ++i) {
-              auto j = RandomEngine::ThreadLocal()->RandInt(i, num_neighbors);
-              std::swap(seq[i], seq[j]);
-            }
-            // Save the randomly sampled fanout elements to the output tensor.
-            std::copy(
-                seq.begin(), seq.begin() + fanout, picked_neighbors_data);
-          } else if (fanout < 64) {
-            // [Algorithm]
-            // Use linear search to verify uniqueness.
-            //
-            // [Complexity Analysis]
-            // Since the set of numbers is small (up to 64), it is more
-            // cost-effective for the CPU to use this algorithm.
-            auto begin = picked_neighbors_data;
-            auto end = picked_neighbors_data + fanout;
-            while (begin != end) {
-              // Put the new random number in the last position.
-              *begin = RandomEngine::ThreadLocal()->RandInt(
-                  offset, offset + num_neighbors);
-              // Check if the new value doesn't exist in the current
-              // range(picked_neighbors_data, begin). Otherwise draw a new
-              // value until the range of elements is unique.
-              auto it = std::find(picked_neighbors_data, begin, *begin);
-              if (it == begin) ++begin;
-            }
-          } else {
-            // [Algorithm]
-            // Use hash-set to verify uniqueness. In the best scenario, the
-            // time complexity is O(fanout), assuming no conflicts occur.
-            //
-            // [Complexity Analysis]
-            // Let K = (fanout / num_neighbors), the expected number of extra
-            // sampling steps is roughly K^2 / (1-K) * num_neighbors, which
-            // means in the worst case scenario, the time complexity is
-            // O(num_neighbors^2).
-            //
-            // [Practical Consideration]
-            // In practice, we set the threshold K to 1/10. This trade-off is
-            // due to the slower performance of std::unordered_set, which
-            // would otherwise increase the sampling cost. By doing so, we
-            // achieve a balance between theoretical efficiency and practical
-            // performance.
-            std::unordered_set<scalar_t> picked_set;
-            while (static_cast<int64_t>(picked_set.size()) < fanout) {
-              picked_set.insert(RandomEngine::ThreadLocal()->RandInt(
-                  offset, offset + num_neighbors));
-            }
-            std::copy(
-                picked_set.begin(), picked_set.end(), picked_neighbors_data);
-          }
-        }));
+    // We use different sampling strategies for different sampling cases.
+    if (fanout >= num_neighbors / 10) {
+      // [Algorithm]
+      // This algorithm is conceptually related to the Fisher-Yates
+      // shuffle.
+      //
+      // [Complexity Analysis]
+      // This algorithm's memory complexity is O(num_neighbors), but
+      // it generates fewer random numbers (O(fanout)).
+      //
+      // (Compare) Reservoir algorithm is one of the most classical
+      // sampling algorithms. Both the reservoir algorithm and our
+      // algorithm offer distinct advantages, we need to compare to
+      // illustrate our trade-offs.
+      // The reservoir algorithm is memory-efficient (O(fanout)) but
+      // creates many random numbers (O(num_neighbors)), which is
+      // costly.
+      //
+      // [Practical Consideration]
+      // Use this algorithm when `fanout >= num_neighbors / 10` to
+      // reduce computation.
+      // In the scenario above, memory complexity is not a concern due
+      // to the small size of both `fanout` and `num_neighbors`. And it
+      // is efficient to allocate a small amount of memory. So the
+      // algorithm performance is great in this case.
+      std::vector<PickedType> seq(num_neighbors);
+      // Assign the seq with [offset, offset + num_neighbors].
+      std::iota(seq.begin(), seq.end(), offset);
+      for (int64_t i = 0; i < fanout; ++i) {
+        auto j = RandomEngine::ThreadLocal()->RandInt(i, num_neighbors);
+        std::swap(seq[i], seq[j]);
+      }
+      // Save the randomly sampled fanout elements to the output tensor.
+      std::copy(seq.begin(), seq.begin() + fanout, picked_data_ptr);
+    } else if (fanout < 64) {
+      // [Algorithm]
+      // Use linear search to verify uniqueness.
+      //
+      // [Complexity Analysis]
+      // Since the set of numbers is small (up to 64), it is more
+      // cost-effective for the CPU to use this algorithm.
+      auto begin = picked_data_ptr;
+      auto end = picked_data_ptr + fanout;
+      while (begin != end) {
+        // Put the new random number in the last position.
+        *begin = RandomEngine::ThreadLocal()->RandInt(
+            offset, offset + num_neighbors);
+        // Check if the new value doesn't exist in the current
+        // range(picked_data_ptr, begin). Otherwise draw a new
+        // value until the range of elements is unique.
+        auto it = std::find(picked_data_ptr, begin, *begin);
+        if (it == begin) ++begin;
+      }
+    } else {
+      // [Algorithm]
+      // Use hash-set to verify uniqueness. In the best scenario, the
+      // time complexity is O(fanout), assuming no conflicts occur.
+      //
+      // [Complexity Analysis]
+      // Let K = (fanout / num_neighbors), the expected number of extra
+      // sampling steps is roughly K^2 / (1-K) * num_neighbors, which
+      // means in the worst case scenario, the time complexity is
+      // O(num_neighbors^2).
+      //
+      // [Practical Consideration]
+      // In practice, we set the threshold K to 1/10. This trade-off is
+      // due to the slower performance of std::unordered_set, which
+      // would otherwise increase the sampling cost. By doing so, we
+      // achieve a balance between theoretical efficiency and practical
+      // performance.
+      std::unordered_set<PickedType> picked_set;
+      while (static_cast<int64_t>(picked_set.size()) < fanout) {
+        picked_set.insert(RandomEngine::ThreadLocal()->RandInt(
+            offset, offset + num_neighbors));
+      }
+      std::copy(picked_set.begin(), picked_set.end(), picked_data_ptr);
+    }
   }
-  return picked_neighbors;
 }
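
The hash-set branch's comment estimates the extra draws as roughly K^2 / (1-K) * num_neighbors. One way to make that estimate precise, via the standard coupon-collector argument (a sketch, not from the commit): with n = num_neighbors and k = fanout, the expected number of draws needed to obtain the i-th distinct value is n / (n - i), so

\mathbb{E}[\text{draws}] = \sum_{i=0}^{k-1} \frac{n}{n-i}, \qquad
\mathbb{E}[\text{extra draws}]
  = \sum_{i=0}^{k-1} \left(\frac{n}{n-i} - 1\right)
  = \sum_{i=0}^{k-1} \frac{i}{n-i}
  \le \frac{k(k-1)}{2(n-k)}
  \approx \frac{K^2}{2(1-K)}\, n, \quad K = \frac{k}{n}.

The comment's bound has the same shape, conservative by about a factor of two. Both blow up as K approaches 1, which is why the fanout >= num_neighbors / 10 cutoff routes those cases to the shuffle-based branch instead.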
 /**
@@ -565,59 +569,65 @@ inline torch::Tensor UniformPick(
  * probabilities associated with each neighboring edge of a node in the original
  * graph. It must be a 1D floating-point tensor with the number of elements
  * equal to the number of edges in the graph.
  *
- * @return A tensor containing the picked neighbors.
+ * @param picked_data_ptr The destination address where the picked neighbors
+ * should be put. Enough memory space should be allocated in advance.
  */
-inline torch::Tensor NonUniformPick(
+template <typename PickedType>
+inline void NonUniformPick(
     int64_t offset, int64_t num_neighbors, int64_t fanout, bool replace,
     const torch::TensorOptions& options,
-    const torch::optional<torch::Tensor>& probs_or_mask) {
-  torch::Tensor picked_neighbors;
+    const torch::optional<torch::Tensor>& probs_or_mask,
+    PickedType* picked_data_ptr) {
   auto local_probs =
       probs_or_mask.value().slice(0, offset, offset + num_neighbors);
   auto positive_probs_indices = local_probs.nonzero().squeeze(1);
   auto num_positive_probs = positive_probs_indices.size(0);
-  if (num_positive_probs == 0) return torch::tensor({}, options);
+  if (num_positive_probs == 0) return;
   if ((fanout == -1) || (num_positive_probs <= fanout && !replace)) {
-    picked_neighbors = torch::arange(offset, offset + num_neighbors, options);
-    picked_neighbors =
-        torch::index_select(picked_neighbors, 0, positive_probs_indices);
+    std::memcpy(
+        picked_data_ptr,
+        (positive_probs_indices + offset).data_ptr<PickedType>(),
+        num_positive_probs * sizeof(PickedType));
   } else {
     if (!replace) fanout = std::min(fanout, num_positive_probs);
-    picked_neighbors =
-        torch::multinomial(local_probs, fanout, replace) + offset;
+    std::memcpy(
+        picked_data_ptr,
+        (torch::multinomial(local_probs, fanout, replace) + offset)
+            .data_ptr<PickedType>(),
+        fanout * sizeof(PickedType));
   }
-  return picked_neighbors;
 }
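
Both memcpy calls above copy straight out of a temporary tensor, which works because data_ptr<PickedType>() is evaluated within the full expression while the temporary is still alive, and because nonzero() and multinomial() both return int64 tensors: data_ptr<PickedType>() runtime-checks its dtype, so these copies assume PickedType is instantiated as int64_t, which the dispatch at the call site must guarantee. A reduced sketch of that contract, with a hypothetical helper:

#include <torch/torch.h>

#include <cstring>

// Copies the indices of non-zero probabilities into out. nonzero() returns a
// kLong (int64) tensor, so data_ptr<int64_t>() passes its dtype check here;
// requesting any other element type would throw at runtime.
void CopyPositiveIndices(const torch::Tensor& probs, int64_t* out) {
  torch::Tensor idx = probs.nonzero().squeeze(1);  // dtype: torch::kLong
  std::memcpy(out, idx.data_ptr<int64_t>(), idx.size(0) * sizeof(int64_t));
}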
-template <>
-torch::Tensor Pick<SamplerType::NEIGHBOR>(
+template <typename PickedType>
+void Pick(
     int64_t offset, int64_t num_neighbors, int64_t fanout, bool replace,
     const torch::TensorOptions& options,
     const torch::optional<torch::Tensor>& probs_or_mask,
-    SamplerArgs<SamplerType::NEIGHBOR> args) {
+    SamplerArgs<SamplerType::NEIGHBOR> args, PickedType* picked_data_ptr) {
   if (probs_or_mask.has_value()) {
-    return NonUniformPick(
-        offset, num_neighbors, fanout, replace, options, probs_or_mask);
+    NonUniformPick(
+        offset, num_neighbors, fanout, replace, options, probs_or_mask,
+        picked_data_ptr);
   } else {
-    return UniformPick(offset, num_neighbors, fanout, replace, options);
+    UniformPick(
+        offset, num_neighbors, fanout, replace, options, picked_data_ptr);
   }
 }
-template <SamplerType S>
-torch::Tensor PickByEtype(
+template <SamplerType S, typename PickedType>
+void PickByEtype(
     int64_t offset, int64_t num_neighbors, const std::vector<int64_t>& fanouts,
     bool replace, const torch::TensorOptions& options,
     const torch::Tensor& type_per_edge,
-    const torch::optional<torch::Tensor>& probs_or_mask, SamplerArgs<S> args) {
-  std::vector<torch::Tensor> picked_neighbors(
-      fanouts.size(), torch::tensor({}, options));
+    const torch::optional<torch::Tensor>& probs_or_mask, SamplerArgs<S> args,
+    PickedType* picked_data_ptr) {
   int64_t etype_begin = offset;
   int64_t etype_end = offset;
   AT_DISPATCH_INTEGRAL_TYPES(
       type_per_edge.scalar_type(), "PickByEtype", ([&] {
         const scalar_t* type_per_edge_data =
             type_per_edge.data_ptr<scalar_t>();
         const auto end = offset + num_neighbors;
+        int64_t pick_offset = 0;
         while (etype_begin < end) {
           scalar_t etype = type_per_edge_data[etype_begin];
           TORCH_CHECK(
@@ -628,53 +638,58 @@ torch::Tensor PickByEtype(
               type_per_edge_data + etype_begin, type_per_edge_data + end,
               etype);
           etype_end = etype_end_it - type_per_edge_data;
+          int64_t picked_count = NumPick(
+              fanout, replace, probs_or_mask, etype_begin,
+              etype_end - etype_begin);
           // Do sampling for one etype.
           if (fanout != 0) {
-            picked_neighbors[etype] = Pick<S>(
+            Pick(
                 etype_begin, etype_end - etype_begin, fanout, replace, options,
-                probs_or_mask, args);
+                probs_or_mask, args, picked_data_ptr + pick_offset);
+            pick_offset += picked_count;
           }
           etype_begin = etype_end;
         }
       }));
-  return torch::cat(picked_neighbors, 0);
 }
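
PickByEtype relies on type_per_edge being sorted within each node's neighborhood, so every edge type occupies one contiguous run that can be counted (NumPick) and sampled (Pick) independently, with pick_offset advancing through the shared output buffer. A standalone sketch of that run scan follows; std::upper_bound is an illustrative choice here, since the actual scan sits in the collapsed context above.

#include <algorithm>
#include <cstdint>
#include <vector>

// Walks the sorted per-edge type array one contiguous run at a time.
void ForEachTypeRun(const std::vector<int64_t>& type_per_edge) {
  int64_t begin = 0;
  const int64_t end = static_cast<int64_t>(type_per_edge.size());
  while (begin < end) {
    const int64_t etype = type_per_edge[begin];
    // First position whose type differs marks the end of this run.
    const auto it = std::upper_bound(
        type_per_edge.begin() + begin, type_per_edge.end(), etype);
    const int64_t run_end = it - type_per_edge.begin();
    // ... count and sample fanouts[etype] neighbors from [begin, run_end) ...
    begin = run_end;
  }
}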
-template <>
-torch::Tensor Pick<SamplerType::LABOR>(
+template <typename PickedType>
+void Pick(
     int64_t offset, int64_t num_neighbors, int64_t fanout, bool replace,
     const torch::TensorOptions& options,
     const torch::optional<torch::Tensor>& probs_or_mask,
-    SamplerArgs<SamplerType::LABOR> args) {
-  if (fanout == 0) return torch::tensor({}, options);
+    SamplerArgs<SamplerType::LABOR> args, PickedType* picked_data_ptr) {
+  if (fanout == 0) return;
   if (probs_or_mask.has_value()) {
     if (fanout < 0) {
-      return NonUniformPick(
-          offset, num_neighbors, fanout, replace, options, probs_or_mask);
+      NonUniformPick(
+          offset, num_neighbors, fanout, replace, options, probs_or_mask,
+          picked_data_ptr);
     } else {
-      torch::Tensor picked_neighbors;
       AT_DISPATCH_FLOATING_TYPES(
           probs_or_mask.value().scalar_type(), "LaborPickFloatType", ([&] {
             if (replace) {
-              picked_neighbors = LaborPick<true, true, scalar_t>(
-                  offset, num_neighbors, fanout, options, probs_or_mask, args);
+              LaborPick<true, true, scalar_t>(
+                  offset, num_neighbors, fanout, options, probs_or_mask, args,
+                  picked_data_ptr);
             } else {
-              picked_neighbors = LaborPick<true, false, scalar_t>(
-                  offset, num_neighbors, fanout, options, probs_or_mask, args);
+              LaborPick<true, false, scalar_t>(
+                  offset, num_neighbors, fanout, options, probs_or_mask, args,
+                  picked_data_ptr);
             }
           }));
-      return picked_neighbors;
     }
   } else if (fanout < 0) {
-    return UniformPick(offset, num_neighbors, fanout, replace, options);
+    UniformPick(
+        offset, num_neighbors, fanout, replace, options, picked_data_ptr);
   } else if (replace) {
-    return LaborPick<false, true>(
+    LaborPick<false, true>(
         offset, num_neighbors, fanout, options,
-        /* probs_or_mask= */ torch::nullopt, args);
+        /* probs_or_mask= */ torch::nullopt, args, picked_data_ptr);
   } else {  // replace = false
-    return LaborPick<false, false>(
+    LaborPick<false, false>(
         offset, num_neighbors, fanout, options,
-        /* probs_or_mask= */ torch::nullopt, args);
+        /* probs_or_mask= */ torch::nullopt, args, picked_data_ptr);
   }
 }
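
LaborPick's NonUniform and Replace are compile-time template parameters rather than runtime bools, so each of the four call sites above instantiates a variant whose flag checks fold away inside the sampling loop. A minimal sketch of the pattern, with an illustrative function rather than the library's code:

#include <cstdint>

template <bool NonUniform, bool Replace>
void PickVariant(int64_t num_neighbors, const float* probs, int64_t* out) {
  for (int64_t i = 0, n = 0; i < num_neighbors; ++i) {
    if constexpr (NonUniform) {
      if (probs[i] <= 0) continue;  // weighted path: skip zero-probability edges
    }
    // With NonUniform == false, the branch above does not exist in the
    // emitted code, so the uniform instantiation pays nothing for it.
    out[n++] = i;
    if constexpr (!Replace) {
      // ... without-replacement bookkeeping would go here ...
    }
  }
}

// The four instantiations mirror LaborPick<false, true>, <false, false>,
// <true, true, scalar_t>, and <true, false, scalar_t> above.
template void PickVariant<false, true>(int64_t, const float*, int64_t*);
template void PickVariant<true, false>(int64_t, const float*, int64_t*);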
@@ -704,25 +719,28 @@ inline void safe_divide(T& a, U b) {
  * graph. It must be a 1D floating-point tensor with the number of elements
  * equal to the number of edges in the graph.
  * @param args Contains labor specific arguments.
  *
- * @return A tensor containing the picked neighbors.
+ * @param picked_data_ptr The destination address where the picked neighbors
+ * should be put. Enough memory space should be allocated in advance.
  */
-template <bool NonUniform, bool Replace, typename T>
-inline torch::Tensor LaborPick(
+template <
+    bool NonUniform, bool Replace, typename ProbsType, typename PickedType>
+inline void LaborPick(
     int64_t offset, int64_t num_neighbors, int64_t fanout,
     const torch::TensorOptions& options,
     const torch::optional<torch::Tensor>& probs_or_mask,
-    SamplerArgs<SamplerType::LABOR> args) {
+    SamplerArgs<SamplerType::LABOR> args, PickedType* picked_data_ptr) {
   fanout = Replace ? fanout : std::min(fanout, num_neighbors);
   if (!NonUniform && !Replace && fanout >= num_neighbors) {
-    return torch::arange(offset, offset + num_neighbors, options);
+    std::iota(picked_data_ptr, picked_data_ptr + num_neighbors, offset);
+    return;
   }
   torch::Tensor heap_tensor = torch::empty({fanout * 2}, torch::kInt32);
   // Assuming max_degree of a vertex is <= 4 billion.
   auto heap_data = reinterpret_cast<std::pair<float, uint32_t>*>(
       heap_tensor.data_ptr<int32_t>());
-  const T* local_probs_data =
-      NonUniform ? probs_or_mask.value().data_ptr<T>() + offset : nullptr;
+  const ProbsType* local_probs_data =
+      NonUniform ? probs_or_mask.value().data_ptr<ProbsType>() + offset
+                 : nullptr;
   AT_DISPATCH_INTEGRAL_TYPES(
       args.indices.scalar_type(), "LaborPickMain", ([&] {
         const scalar_t* local_indices_data =
@@ -835,21 +853,15 @@ inline torch::Tensor LaborPick(
         }
       }));
   int64_t num_sampled = 0;
-  torch::Tensor picked_neighbors = torch::empty({fanout}, options);
-  AT_DISPATCH_INTEGRAL_TYPES(
-      picked_neighbors.scalar_type(), "LaborPickOutput", ([&] {
-        scalar_t* picked_neighbors_data =
-            picked_neighbors.data_ptr<scalar_t>();
-        for (int64_t i = 0; i < fanout; ++i) {
-          const auto [rnd, j] = heap_data[i];
-          if (!NonUniform || rnd < std::numeric_limits<float>::infinity()) {
-            picked_neighbors_data[num_sampled++] = offset + j;
-          }
-        }
-      }));
+  for (int64_t i = 0; i < fanout; ++i) {
+    const auto [rnd, j] = heap_data[i];
+    if (!NonUniform || rnd < std::numeric_limits<float>::infinity()) {
+      picked_data_ptr[num_sampled++] = offset + j;
+    }
+  }
   TORCH_CHECK(
       !Replace || num_sampled == fanout || num_sampled == 0,
       "Sampling with replacement should sample exactly fanout neighbors or 0!");
-  return picked_neighbors.narrow(0, 0, num_sampled);
 }
 }  // namespace sampling
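
One detail of LaborPick worth spelling out: its heap buffer is a {fanout * 2} int32 tensor reinterpreted as std::pair<float, uint32_t> entries, which is why the code comments that a vertex's max_degree is assumed to be at most 4 billion. Below is a sketch of what that cast assumes; the static_asserts are editorial additions, not part of the commit.

#include <cstdint>
#include <utility>
#include <vector>

using HeapEntry = std::pair<float, uint32_t>;

// Two int32 slots back one (random key, local neighbor index) entry.
static_assert(sizeof(HeapEntry) == 2 * sizeof(int32_t),
              "pair must pack into exactly two int32 slots");
static_assert(alignof(HeapEntry) == alignof(int32_t),
              "int32 storage must satisfy the pair's alignment");

void demo(int64_t fanout) {
  std::vector<int32_t> storage(fanout * 2);  // stands in for the int32 tensor
  auto* heap_data = reinterpret_cast<HeapEntry*>(storage.data());
  heap_data[0] = {0.25f, 7u};  // a key below +inf marks a real pick
}

As in the diff itself, this leans on reinterpret_cast-based type punning that mainstream compilers tolerate; the uint32_t index member is where the 2^32 degree limit comes from.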