Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
11bdd6e8
Unverified
Commit
11bdd6e8
authored
Dec 19, 2023
by
czkkkkkk
Committed by
GitHub
Dec 19, 2023
Browse files
[Graphbolt] Refactor the nonuniform pick function to make it reusable. (#6772)
parent
3d657dbf
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
103 additions
and
97 deletions
+103
-97
graphbolt/src/fused_csc_sampling_graph.cc
graphbolt/src/fused_csc_sampling_graph.cc
+103
-97
No files found.
graphbolt/src/fused_csc_sampling_graph.cc
View file @
11bdd6e8
...
...
@@ -818,6 +818,103 @@ inline int64_t UniformPick(
}
}
/**
 * @brief An operator to perform non-uniform sampling.
 *
 * Candidates are the entries of `probs` that `nonzero()` reports; zero
 * entries are never picked. The returned indices are positions inside
 * `probs` (no offset is applied here).
 *
 * @param probs Tensor of unnormalized sampling weights with a floating-point
 * dtype. NOTE(review): the `squeeze(1)` on the `nonzero()` result presumably
 * expects a 1-D tensor, and the raw pointer indexing presumably expects it to
 * be contiguous -- confirm against callers.
 * @param fanout Number of entries to pick. `-1` selects every non-zero entry.
 * @param replace Whether to sample with replacement.
 *
 * @return A 1-D int64 tensor of picked indices:
 *  - empty if there is no non-zero entry or `fanout == 0`;
 *  - all non-zero indices if `fanout == -1`, or if sampling without
 *    replacement and there are at most `fanout` non-zero entries;
 *  - otherwise exactly `fanout` sampled indices.
 */
static torch::Tensor NonUniformPickOp(
    torch::Tensor probs, int64_t fanout, bool replace) {
  auto positive_probs_indices = probs.nonzero().squeeze(1);
  const int64_t num_positive_probs = positive_probs_indices.size(0);
  if (num_positive_probs == 0) return torch::empty({0}, torch::kLong);
  if ((fanout == -1) || (num_positive_probs <= fanout && !replace)) {
    return positive_probs_indices;
  }
  // From here on, sampling without replacement implies
  // fanout < num_positive_probs (the opposite case returned above), so no
  // further clamping of fanout is required.
  if (fanout == 0) return torch::empty({0}, torch::kLong);

  auto ret_tensor = torch::empty({fanout}, torch::kLong);
  auto ret_ptr = ret_tensor.data_ptr<int64_t>();
  AT_DISPATCH_FLOATING_TYPES(
      probs.scalar_type(), "MultinomialSampling", ([&] {
        auto probs_data_ptr = probs.data_ptr<scalar_t>();
        auto positive_probs_indices_ptr =
            positive_probs_indices.data_ptr<int64_t>();

        if (!replace) {
          // The algorithm is from gumbel softmax.
          // s = argmax( logp - log(-log(eps)) ) where eps ~ U(0, 1).
          // Here we can apply exp to the formula which will not affect result
          // of argmax or topk. Then we have
          // s = argmax( p / (-log(eps)) ) where eps ~ U(0, 1).
          // We can also simplify the formula above by
          // s = argmax( p / q ) where q ~ Exp(1).
          if (fanout == 1) {
            // Return argmax(p / q).
            scalar_t max_prob = 0;
            // Seed with the first candidate so that a valid index is returned
            // even when no key compares strictly greater than zero (e.g.
            // p / q underflows to 0); otherwise -1 would leak to the caller.
            int64_t max_prob_index = positive_probs_indices_ptr[0];
            // We only care about the neighbors with non-zero probability.
            for (int64_t i = 0; i < num_positive_probs; ++i) {
              // Calculate (p / q) for the current neighbor.
              const scalar_t current_prob =
                  probs_data_ptr[positive_probs_indices_ptr[i]] /
                  RandomEngine::ThreadLocal()->Exponential(1.);
              if (current_prob > max_prob) {
                max_prob = current_prob;
                max_prob_index = positive_probs_indices_ptr[i];
              }
            }
            ret_ptr[0] = max_prob_index;
          } else {
            // Return topk(p / q).
            std::vector<std::pair<scalar_t, int64_t>> q(num_positive_probs);
            for (int64_t i = 0; i < num_positive_probs; ++i) {
              q[i].first = probs_data_ptr[positive_probs_indices_ptr[i]] /
                           RandomEngine::ThreadLocal()->Exponential(1.);
              q[i].second = positive_probs_indices_ptr[i];
            }
            if (fanout < num_positive_probs / 64) {
              // Few picks relative to the candidates: use partial_sort.
              std::partial_sort(
                  q.begin(), q.begin() + fanout, q.end(), std::greater<>{});
            } else {
              // Otherwise use nth_element; the order inside the top-k does
              // not matter, only the membership does.
              std::nth_element(
                  q.begin(), q.begin() + fanout - 1, q.end(),
                  std::greater<>{});
            }
            // Either way the first `fanout` entries hold the largest keys.
            for (int64_t i = 0; i < fanout; ++i) {
              ret_ptr[i] = q[i].second;
            }
          }
        } else {
          // Calculate cumulative sum of probabilities.
          std::vector<scalar_t> prefix_sum_probs(num_positive_probs);
          scalar_t sum_probs = 0;
          for (int64_t i = 0; i < num_positive_probs; ++i) {
            sum_probs += probs_data_ptr[positive_probs_indices_ptr[i]];
            prefix_sum_probs[i] = sum_probs;
          }
          // Normalize, unless the weights already sum to (approximately) one.
          if ((sum_probs > 1.00001) || (sum_probs < 0.99999)) {
            for (int64_t i = 0; i < num_positive_probs; ++i) {
              prefix_sum_probs[i] /= sum_probs;
            }
          }
          for (int64_t i = 0; i < fanout; ++i) {
            // Sample a probability mass from a uniform distribution.
            const double uniform_sample =
                RandomEngine::ThreadLocal()->Uniform(0., 1.);
            // Use a binary search to find the index. The last prefix sum may
            // be slightly below the drawn sample (normalization is skipped
            // for sums within the tolerance, and rounding can shave the final
            // value), in which case lower_bound returns end(); clamp to the
            // last candidate to avoid reading past the index buffer.
            const int64_t sampled_index = std::min<int64_t>(
                std::lower_bound(
                    prefix_sum_probs.begin(), prefix_sum_probs.end(),
                    uniform_sample) -
                    prefix_sum_probs.begin(),
                num_positive_probs - 1);
            ret_ptr[i] = positive_probs_indices_ptr[sampled_index];
          }
        }
      }));
  return ret_tensor;
}
/**
* @brief Perform non-uniform sampling of elements based on probabilities and
* return the sampled indices.
...
...
@@ -861,104 +958,13 @@ inline int64_t NonUniformPick(
PickedType
*
picked_data_ptr
)
{
auto
local_probs
=
probs_or_mask
.
value
().
slice
(
0
,
offset
,
offset
+
num_neighbors
);
auto
positive_probs_indices
=
local_probs
.
nonzero
().
squeeze
(
1
);
auto
num_positive_probs
=
positive_probs_indices
.
size
(
0
);
if
(
num_positive_probs
==
0
)
return
0
;
if
((
fanout
==
-
1
)
||
(
num_positive_probs
<=
fanout
&&
!
replace
))
{
std
::
memcpy
(
picked_data_ptr
,
(
positive_probs_indices
+
offset
).
data_ptr
<
PickedType
>
(),
num_positive_probs
*
sizeof
(
PickedType
));
return
num_positive_probs
;
}
else
{
if
(
!
replace
)
fanout
=
std
::
min
(
fanout
,
num_positive_probs
);
if
(
fanout
==
0
)
return
0
;
AT_DISPATCH_FLOATING_TYPES
(
local_probs
.
scalar_type
(),
"MultinomialSampling"
,
([
&
]
{
auto
local_probs_data_ptr
=
local_probs
.
data_ptr
<
scalar_t
>
();
auto
positive_probs_indices_ptr
=
positive_probs_indices
.
data_ptr
<
PickedType
>
();
if
(
!
replace
)
{
// The algorithm is from gumbel softmax.
// s = argmax( logp - log(-log(eps)) ) where eps ~ U(0, 1).
// Here we can apply exp to the formula which will not affect result
// of argmax or topk. Then we have
// s = argmax( p / (-log(eps)) ) where eps ~ U(0, 1).
// We can also simplify the formula above by
// s = argmax( p / q ) where q ~ Exp(1).
if
(
fanout
==
1
)
{
// Return argmax(p / q).
scalar_t
max_prob
=
0
;
PickedType
max_prob_index
=
-
1
;
// We only care about the neighbors with non-zero probability.
for
(
auto
i
=
0
;
i
<
num_positive_probs
;
++
i
)
{
// Calculate (p / q) for the current neighbor.
scalar_t
current_prob
=
local_probs_data_ptr
[
positive_probs_indices_ptr
[
i
]]
/
RandomEngine
::
ThreadLocal
()
->
Exponential
(
1.
);
if
(
current_prob
>
max_prob
)
{
max_prob
=
current_prob
;
max_prob_index
=
positive_probs_indices_ptr
[
i
];
}
}
*
picked_data_ptr
=
max_prob_index
+
offset
;
}
else
{
// Return topk(p / q).
std
::
vector
<
std
::
pair
<
scalar_t
,
PickedType
>>
q
(
num_positive_probs
);
for
(
auto
i
=
0
;
i
<
num_positive_probs
;
++
i
)
{
q
[
i
].
first
=
local_probs_data_ptr
[
positive_probs_indices_ptr
[
i
]]
/
RandomEngine
::
ThreadLocal
()
->
Exponential
(
1.
);
q
[
i
].
second
=
positive_probs_indices_ptr
[
i
];
}
if
(
fanout
<
num_positive_probs
/
64
)
{
// Use partial_sort.
std
::
partial_sort
(
q
.
begin
(),
q
.
begin
()
+
fanout
,
q
.
end
(),
std
::
greater
{});
for
(
auto
i
=
0
;
i
<
fanout
;
++
i
)
{
picked_data_ptr
[
i
]
=
q
[
i
].
second
+
offset
;
}
}
else
{
// Use nth_element.
std
::
nth_element
(
q
.
begin
(),
q
.
begin
()
+
fanout
-
1
,
q
.
end
(),
std
::
greater
{});
for
(
auto
i
=
0
;
i
<
fanout
;
++
i
)
{
picked_data_ptr
[
i
]
=
q
[
i
].
second
+
offset
;
}
}
}
}
else
{
// Calculate cumulative sum of probabilities.
std
::
vector
<
scalar_t
>
prefix_sum_probs
(
num_positive_probs
);
scalar_t
sum_probs
=
0
;
for
(
auto
i
=
0
;
i
<
num_positive_probs
;
++
i
)
{
sum_probs
+=
local_probs_data_ptr
[
positive_probs_indices_ptr
[
i
]];
prefix_sum_probs
[
i
]
=
sum_probs
;
}
// Normalize.
if
((
sum_probs
>
1.00001
)
||
(
sum_probs
<
0.99999
))
{
for
(
auto
i
=
0
;
i
<
num_positive_probs
;
++
i
)
{
prefix_sum_probs
[
i
]
/=
sum_probs
;
}
}
for
(
auto
i
=
0
;
i
<
fanout
;
++
i
)
{
// Sample a probability mass from a uniform distribution.
double
uniform_sample
=
RandomEngine
::
ThreadLocal
()
->
Uniform
(
0.
,
1.
);
// Use a binary search to find the index.
int
sampled_index
=
std
::
lower_bound
(
prefix_sum_probs
.
begin
(),
prefix_sum_probs
.
end
(),
uniform_sample
)
-
prefix_sum_probs
.
begin
();
picked_data_ptr
[
i
]
=
positive_probs_indices_ptr
[
sampled_index
]
+
offset
;
}
}
}));
return
fanout
;
auto
picked_indices
=
NonUniformPickOp
(
local_probs
,
fanout
,
replace
);
auto
picked_indices_ptr
=
picked_indices
.
data_ptr
<
int64_t
>
();
for
(
int
i
=
0
;
i
<
picked_indices
.
numel
();
++
i
)
{
picked_data_ptr
[
i
]
=
static_cast
<
PickedType
>
(
picked_indices_ptr
[
i
])
+
offset
;
}
return
picked_indices
.
numel
();
}
template
<
typename
PickedType
>
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment