Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
bbc8ff62
Unverified
Commit
bbc8ff62
authored
Aug 31, 2023
by
Ramon Zhou
Committed by
GitHub
Aug 31, 2023
Browse files
[Graphbolt] Rewrite torch::multinomial to improve sampling performance (#6217)
parent
219c9f1a
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
110 additions
and
5 deletions
+110
-5
graphbolt/src/csc_sampling_graph.cc
graphbolt/src/csc_sampling_graph.cc
+91
-5
graphbolt/src/random.h
graphbolt/src/random.h
+19
-0
No files found.
graphbolt/src/csc_sampling_graph.cc
View file @
bbc8ff62
...
...
@@ -403,6 +403,12 @@ c10::intrusive_ptr<SampledSubgraph> CSCSamplingGraph::SampleNeighbors(
probs_or_mask
.
value
().
dtype
()
==
torch
::
kFloat16
)
{
probs_or_mask
=
probs_or_mask
.
value
().
to
(
torch
::
kFloat32
);
}
TORCH_CHECK
(
((
probs_or_mask
.
value
().
max
()
<
INFINITY
)
&
(
probs_or_mask
.
value
().
min
()
>=
0
))
.
item
()
.
to
<
bool
>
(),
"Invalid probs_or_mask (contains either `inf`, `nan` or element < 0)."
);
}
if
(
layer
)
{
...
...
@@ -690,11 +696,91 @@ inline int64_t NonUniformPick(
return
num_positive_probs
;
}
else
{
if
(
!
replace
)
fanout
=
std
::
min
(
fanout
,
num_positive_probs
);
std
::
memcpy
(
picked_data_ptr
,
(
torch
::
multinomial
(
local_probs
,
fanout
,
replace
)
+
offset
)
.
data_ptr
<
PickedType
>
(),
fanout
*
sizeof
(
PickedType
));
if
(
fanout
==
0
)
return
0
;
AT_DISPATCH_FLOATING_TYPES
(
local_probs
.
scalar_type
(),
"MultinomialSampling"
,
([
&
]
{
auto
local_probs_data_ptr
=
local_probs
.
data_ptr
<
scalar_t
>
();
auto
positive_probs_indices_ptr
=
positive_probs_indices
.
data_ptr
<
PickedType
>
();
if
(
!
replace
)
{
// The algorithm is from gumbel softmax.
// s = argmax( logp - log(-log(eps)) ) where eps ~ U(0, 1).
// Here we can apply exp to the formula which will not affect result
// of argmax or topk. Then we have
// s = argmax( p / (-log(eps)) ) where eps ~ U(0, 1).
// We can also simplify the formula above by
// s = argmax( p / q ) where q ~ Exp(1).
if
(
fanout
==
1
)
{
// Return argmax(p / q).
scalar_t
max_prob
=
0
;
PickedType
max_prob_index
=
-
1
;
// We only care about the neighbors with non-zero probability.
for
(
auto
i
=
0
;
i
<
num_positive_probs
;
++
i
)
{
// Calculate (p / q) for the current neighbor.
scalar_t
current_prob
=
local_probs_data_ptr
[
positive_probs_indices_ptr
[
i
]]
/
RandomEngine
::
ThreadLocal
()
->
Exponential
(
1.
);
if
(
current_prob
>
max_prob
)
{
max_prob
=
current_prob
;
max_prob_index
=
positive_probs_indices_ptr
[
i
];
}
}
*
picked_data_ptr
=
max_prob_index
+
offset
;
}
else
{
// Return topk(p / q).
std
::
vector
<
std
::
pair
<
scalar_t
,
PickedType
>>
q
(
num_positive_probs
);
for
(
auto
i
=
0
;
i
<
num_positive_probs
;
++
i
)
{
q
[
i
].
first
=
local_probs_data_ptr
[
positive_probs_indices_ptr
[
i
]]
/
RandomEngine
::
ThreadLocal
()
->
Exponential
(
1.
);
q
[
i
].
second
=
positive_probs_indices_ptr
[
i
];
}
if
(
fanout
<
num_positive_probs
/
64
)
{
// Use partial_sort.
std
::
partial_sort
(
q
.
begin
(),
q
.
begin
()
+
fanout
,
q
.
end
(),
std
::
greater
{});
for
(
auto
i
=
0
;
i
<
fanout
;
++
i
)
{
picked_data_ptr
[
i
]
=
q
[
i
].
second
+
offset
;
}
}
else
{
// Use nth_element.
std
::
nth_element
(
q
.
begin
(),
q
.
begin
()
+
fanout
-
1
,
q
.
end
(),
std
::
greater
{});
for
(
auto
i
=
0
;
i
<
fanout
;
++
i
)
{
picked_data_ptr
[
i
]
=
q
[
i
].
second
+
offset
;
}
}
}
}
else
{
// Calculate cumulative sum of probabilities.
std
::
vector
<
scalar_t
>
prefix_sum_probs
(
num_positive_probs
);
scalar_t
sum_probs
=
0
;
for
(
auto
i
=
0
;
i
<
num_positive_probs
;
++
i
)
{
sum_probs
+=
local_probs_data_ptr
[
positive_probs_indices_ptr
[
i
]];
prefix_sum_probs
[
i
]
=
sum_probs
;
}
// Normalize.
if
((
sum_probs
>
1.00001
)
||
(
sum_probs
<
0.99999
))
{
for
(
auto
i
=
0
;
i
<
num_positive_probs
;
++
i
)
{
prefix_sum_probs
[
i
]
/=
sum_probs
;
}
}
for
(
auto
i
=
0
;
i
<
fanout
;
++
i
)
{
// Sample a probability mass from a uniform distribution.
double
uniform_sample
=
RandomEngine
::
ThreadLocal
()
->
Uniform
(
0.
,
1.
);
// Use a binary search to find the index.
int
sampled_index
=
std
::
lower_bound
(
prefix_sum_probs
.
begin
(),
prefix_sum_probs
.
end
(),
uniform_sample
)
-
prefix_sum_probs
.
begin
();
picked_data_ptr
[
i
]
=
positive_probs_indices_ptr
[
sampled_index
]
+
offset
;
}
}
}));
return
fanout
;
}
}
...
...
graphbolt/src/random.h
View file @
bbc8ff62
...
...
@@ -69,6 +69,25 @@ class RandomEngine {
return
dist
(
rng_
);
}
/**
* @brief Generate a uniform random real number in [low, high).
*/
template
<
typename
T
>
T
Uniform
(
T
lower
,
T
upper
)
{
std
::
uniform_real_distribution
<
T
>
dist
(
lower
,
upper
);
return
dist
(
rng_
);
}
/**
* @brief Generate random non-negative floating-point values according to
* exponential distribution. Probability density function: P(x|λ) = λe^(-λx).
*/
template
<
typename
T
>
T
Exponential
(
T
lambda
)
{
std
::
exponential_distribution
<
T
>
dist
(
lambda
);
return
dist
(
rng_
);
}
private:
pcg32
rng_
;
};
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment