Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
1e6fa711
Unverified
Commit
1e6fa711
authored
Feb 03, 2024
by
czkkkkkk
Committed by
GitHub
Feb 03, 2024
Browse files
[Graphbolt] Add fast path for tamporal sampling. (#7078)
parent
15695ed0
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
72 additions
and
0 deletions
+72
-0
graphbolt/src/fused_csc_sampling_graph.cc
graphbolt/src/fused_csc_sampling_graph.cc
+72
-0
No files found.
graphbolt/src/fused_csc_sampling_graph.cc
View file @
1e6fa711
...
...
@@ -810,12 +810,71 @@ torch::Tensor TemporalMask(
return
mask
;
}
/**
* @brief Fast path for temporal sampling without probability. It is used when
* the number of neighbors is large. It randomly samples neighbors and checks
* the timestamp of the neighbors. It is successful if the number of sampled
* neighbors in kTriedThreshold trials is equal to the fanout.
*/
std
::
pair
<
bool
,
std
::
vector
<
int64_t
>>
FastTemporalPick
(
torch
::
Tensor
seed_timestamp
,
torch
::
Tensor
csc_indices
,
int64_t
fanout
,
bool
replace
,
const
torch
::
optional
<
torch
::
Tensor
>&
node_timestamp
,
const
torch
::
optional
<
torch
::
Tensor
>&
edge_timestamp
,
int64_t
seed_offset
,
int64_t
offset
,
int64_t
num_neighbors
)
{
constexpr
int64_t
kTriedThreshold
=
1000
;
auto
timestamp
=
utils
::
GetValueByIndex
<
int64_t
>
(
seed_timestamp
,
seed_offset
);
std
::
vector
<
int64_t
>
sampled_edges
;
sampled_edges
.
reserve
(
fanout
);
std
::
set
<
int64_t
>
sampled_edge_set
;
int64_t
sample_count
=
0
;
int64_t
tried
=
0
;
while
(
sample_count
<
fanout
&&
tried
<
kTriedThreshold
)
{
int64_t
edge_id
=
RandomEngine
::
ThreadLocal
()
->
RandInt
(
offset
,
offset
+
num_neighbors
);
++
tried
;
if
(
!
replace
&&
sampled_edge_set
.
count
(
edge_id
)
>
0
)
{
continue
;
}
if
(
node_timestamp
.
has_value
())
{
int64_t
neighbor_id
=
utils
::
GetValueByIndex
<
int64_t
>
(
csc_indices
,
edge_id
);
if
(
utils
::
GetValueByIndex
<
int64_t
>
(
node_timestamp
.
value
(),
neighbor_id
)
>=
timestamp
)
continue
;
}
if
(
edge_timestamp
.
has_value
()
&&
utils
::
GetValueByIndex
<
int64_t
>
(
edge_timestamp
.
value
(),
edge_id
)
>=
timestamp
)
{
continue
;
}
if
(
!
replace
)
{
sampled_edge_set
.
insert
(
edge_id
);
}
sampled_edges
.
push_back
(
edge_id
);
sample_count
++
;
}
if
(
sample_count
<
fanout
)
{
return
{
false
,
{}};
}
return
{
true
,
sampled_edges
};
}
int64_t
TemporalNumPick
(
torch
::
Tensor
seed_timestamp
,
torch
::
Tensor
csc_indics
,
int64_t
fanout
,
bool
replace
,
const
torch
::
optional
<
torch
::
Tensor
>&
probs_or_mask
,
const
torch
::
optional
<
torch
::
Tensor
>&
node_timestamp
,
const
torch
::
optional
<
torch
::
Tensor
>&
edge_timestamp
,
int64_t
seed_offset
,
int64_t
offset
,
int64_t
num_neighbors
)
{
constexpr
int64_t
kFastPathThreshold
=
1000
;
if
(
num_neighbors
>
kFastPathThreshold
&&
!
probs_or_mask
.
has_value
())
{
// TODO: Currently we use the fast path both in TemporalNumPick and
// TemporalPick. We may only sample once in TemporalNumPick and use the
// sampled edges in TemporalPick to avoid sampling twice.
auto
[
success
,
sampled_edges
]
=
FastTemporalPick
(
seed_timestamp
,
csc_indics
,
fanout
,
replace
,
node_timestamp
,
edge_timestamp
,
seed_offset
,
offset
,
num_neighbors
);
if
(
success
)
return
sampled_edges
.
size
();
}
auto
mask
=
TemporalMask
(
utils
::
GetValueByIndex
<
int64_t
>
(
seed_timestamp
,
seed_offset
),
csc_indics
,
probs_or_mask
,
node_timestamp
,
edge_timestamp
,
...
...
@@ -1183,6 +1242,19 @@ int64_t TemporalPick(
const
torch
::
optional
<
torch
::
Tensor
>&
node_timestamp
,
const
torch
::
optional
<
torch
::
Tensor
>&
edge_timestamp
,
SamplerArgs
<
S
>
args
,
PickedType
*
picked_data_ptr
)
{
constexpr
int64_t
kFastPathThreshold
=
1000
;
if
(
S
==
SamplerType
::
NEIGHBOR
&&
num_neighbors
>
kFastPathThreshold
&&
!
probs_or_mask
.
has_value
())
{
auto
[
success
,
sampled_edges
]
=
FastTemporalPick
(
seed_timestamp
,
csc_indices
,
fanout
,
replace
,
node_timestamp
,
edge_timestamp
,
seed_offset
,
offset
,
num_neighbors
);
if
(
success
)
{
for
(
size_t
i
=
0
;
i
<
sampled_edges
.
size
();
++
i
)
{
picked_data_ptr
[
i
]
=
static_cast
<
PickedType
>
(
sampled_edges
[
i
]);
}
return
sampled_edges
.
size
();
}
}
auto
mask
=
TemporalMask
(
utils
::
GetValueByIndex
<
int64_t
>
(
seed_timestamp
,
seed_offset
),
csc_indices
,
probs_or_mask
,
node_timestamp
,
edge_timestamp
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment