Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
61504ec5
"...transforms/git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "72dcc1705d290a2b6084d0fed97e6f52c670ae79"
Unverified
Commit
61504ec5
authored
Jan 08, 2024
by
Muhammed Fatih BALIN
Committed by
GitHub
Jan 08, 2024
Browse files
[GraphBolt] Extend temporal sampling to labor (#6816)
parent
333ce36c
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
87 additions
and
54 deletions
+87
-54
graphbolt/include/graphbolt/fused_csc_sampling_graph.h
graphbolt/include/graphbolt/fused_csc_sampling_graph.h
+5
-2
graphbolt/src/fused_csc_sampling_graph.cc
graphbolt/src/fused_csc_sampling_graph.cc
+81
-52
python/dgl/graphbolt/impl/fused_csc_sampling_graph.py
python/dgl/graphbolt/impl/fused_csc_sampling_graph.py
+1
-0
No files found.
graphbolt/include/graphbolt/fused_csc_sampling_graph.h
View file @
61504ec5
...
@@ -335,6 +335,9 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder {
...
@@ -335,6 +335,9 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder {
* @param replace Boolean indicating whether the sample is performed with or
* @param replace Boolean indicating whether the sample is performed with or
* without replacement. If True, a value can be selected multiple times.
* without replacement. If True, a value can be selected multiple times.
* Otherwise, each value can be selected only once.
* Otherwise, each value can be selected only once.
* @param layer Boolean indicating whether neighbors should be sampled in a
* layer sampling fashion. Uses the LABOR-0 algorithm to increase overlap of
* sampled edges, see arXiv:2210.13339.
* @param return_eids Boolean indicating whether edge IDs need to be returned,
* @param return_eids Boolean indicating whether edge IDs need to be returned,
* typically used when edge features are required.
* typically used when edge features are required.
* @param probs_name An optional string specifying the name of an edge
* @param probs_name An optional string specifying the name of an edge
...
@@ -351,8 +354,8 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder {
...
@@ -351,8 +354,8 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder {
c10
::
intrusive_ptr
<
FusedSampledSubgraph
>
TemporalSampleNeighbors
(
c10
::
intrusive_ptr
<
FusedSampledSubgraph
>
TemporalSampleNeighbors
(
const
torch
::
Tensor
&
input_nodes
,
const
torch
::
Tensor
&
input_nodes
,
const
torch
::
Tensor
&
input_nodes_timestamp
,
const
torch
::
Tensor
&
input_nodes_timestamp
,
const
std
::
vector
<
int64_t
>&
fanouts
,
bool
replace
,
bool
return_eids
,
const
std
::
vector
<
int64_t
>&
fanouts
,
bool
replace
,
bool
layer
,
torch
::
optional
<
std
::
string
>
probs_name
,
bool
return_eids
,
torch
::
optional
<
std
::
string
>
probs_name
,
torch
::
optional
<
std
::
string
>
node_timestamp_attr_name
,
torch
::
optional
<
std
::
string
>
node_timestamp_attr_name
,
torch
::
optional
<
std
::
string
>
edge_timestamp_attr_name
)
const
;
torch
::
optional
<
std
::
string
>
edge_timestamp_attr_name
)
const
;
...
...
graphbolt/src/fused_csc_sampling_graph.cc
View file @
61504ec5
...
@@ -437,6 +437,7 @@ auto GetPickFn(
...
@@ -437,6 +437,7 @@ auto GetPickFn(
};
};
}
}
template
<
SamplerType
S
>
auto
GetTemporalPickFn
(
auto
GetTemporalPickFn
(
torch
::
Tensor
seed_timestamp
,
torch
::
Tensor
csc_indices
,
torch
::
Tensor
seed_timestamp
,
torch
::
Tensor
csc_indices
,
const
std
::
vector
<
int64_t
>&
fanouts
,
bool
replace
,
const
std
::
vector
<
int64_t
>&
fanouts
,
bool
replace
,
...
@@ -444,30 +445,31 @@ auto GetTemporalPickFn(
...
@@ -444,30 +445,31 @@ auto GetTemporalPickFn(
const
torch
::
optional
<
torch
::
Tensor
>&
type_per_edge
,
const
torch
::
optional
<
torch
::
Tensor
>&
type_per_edge
,
const
torch
::
optional
<
torch
::
Tensor
>&
probs_or_mask
,
const
torch
::
optional
<
torch
::
Tensor
>&
probs_or_mask
,
const
torch
::
optional
<
torch
::
Tensor
>&
node_timestamp
,
const
torch
::
optional
<
torch
::
Tensor
>&
node_timestamp
,
const
torch
::
optional
<
torch
::
Tensor
>&
edge_timestamp
)
{
const
torch
::
optional
<
torch
::
Tensor
>&
edge_timestamp
,
SamplerArgs
<
S
>
args
)
{
return
[
&
seed_timestamp
,
&
csc_indices
,
&
fanouts
,
replace
,
&
options
,
return
&
type_per_edge
,
&
probs_or_mask
,
&
node_timestamp
,
&
edge_timestamp
](
[
&
seed_timestamp
,
&
csc_indices
,
&
fanouts
,
replace
,
&
options
,
int64_t
seed_offset
,
int64_t
offset
,
int64_t
num_neighbors
,
&
type_per_edge
,
&
probs_or_mask
,
&
node_timestamp
,
&
edge_timestamp
,
args
](
auto
picked_data_ptr
)
{
int64_t
seed_offset
,
int64_t
offset
,
int64_t
num_neighbors
,
// If fanouts.size() > 1, perform sampling for each edge type of each
auto
picked_data_ptr
)
{
// node; otherwise just sample once for each node with no regard of edge
// If fanouts.size() > 1, perform sampling for each edge type of each
// types.
// node; otherwise just sample once for each node with no regard of edge
if
(
fanouts
.
size
()
>
1
)
{
// types.
return
TemporalPickByEtype
(
if
(
fanouts
.
size
()
>
1
)
{
seed_timestamp
,
csc_indices
,
seed_offset
,
offset
,
num_neighbors
,
return
TemporalPickByEtype
(
fanouts
,
replace
,
options
,
type_per_edge
.
value
(),
probs_or_mask
,
seed_timestamp
,
csc_indices
,
seed_offset
,
offset
,
num_neighbors
,
node_timestamp
,
edge_timestamp
,
picked_data_ptr
);
fanouts
,
replace
,
options
,
type_per_edge
.
value
(),
probs_or_mask
,
}
else
{
node_timestamp
,
edge_timestamp
,
args
,
picked_data_ptr
);
int64_t
num_sampled
=
TemporalPick
(
}
else
{
seed_timestamp
,
csc_indices
,
seed_offset
,
offset
,
num_neighbors
,
int64_t
num_sampled
=
TemporalPick
(
fanouts
[
0
],
replace
,
options
,
probs_or_mask
,
node_timestamp
,
seed_timestamp
,
csc_indices
,
seed_offset
,
offset
,
num_neighbors
,
edge_timestamp
,
picked_data_ptr
);
fanouts
[
0
],
replace
,
options
,
probs_or_mask
,
node_timestamp
,
if
(
type_per_edge
)
{
edge_timestamp
,
args
,
picked_data_ptr
);
std
::
sort
(
picked_data_ptr
,
picked_data_ptr
+
num_sampled
);
if
(
type_per_edge
.
has_value
())
{
}
std
::
sort
(
picked_data_ptr
,
picked_data_ptr
+
num_sampled
);
return
num_sampled
;
}
}
return
num_sampled
;
};
}
};
}
}
template
<
typename
NumPickFn
,
typename
PickFn
>
template
<
typename
NumPickFn
,
typename
PickFn
>
...
@@ -664,8 +666,8 @@ c10::intrusive_ptr<FusedSampledSubgraph>
...
@@ -664,8 +666,8 @@ c10::intrusive_ptr<FusedSampledSubgraph>
FusedCSCSamplingGraph
::
TemporalSampleNeighbors
(
FusedCSCSamplingGraph
::
TemporalSampleNeighbors
(
const
torch
::
Tensor
&
input_nodes
,
const
torch
::
Tensor
&
input_nodes
,
const
torch
::
Tensor
&
input_nodes_timestamp
,
const
torch
::
Tensor
&
input_nodes_timestamp
,
const
std
::
vector
<
int64_t
>&
fanouts
,
bool
replace
,
bool
return_eids
,
const
std
::
vector
<
int64_t
>&
fanouts
,
bool
replace
,
bool
layer
,
torch
::
optional
<
std
::
string
>
probs_name
,
bool
return_eids
,
torch
::
optional
<
std
::
string
>
probs_name
,
torch
::
optional
<
std
::
string
>
node_timestamp_attr_name
,
torch
::
optional
<
std
::
string
>
node_timestamp_attr_name
,
torch
::
optional
<
std
::
string
>
edge_timestamp_attr_name
)
const
{
torch
::
optional
<
std
::
string
>
edge_timestamp_attr_name
)
const
{
// 1. Get probs_or_mask.
// 1. Get probs_or_mask.
...
@@ -684,15 +686,31 @@ FusedCSCSamplingGraph::TemporalSampleNeighbors(
...
@@ -684,15 +686,31 @@ FusedCSCSamplingGraph::TemporalSampleNeighbors(
// 3. Get the timestamp attribute for edges of the graph
// 3. Get the timestamp attribute for edges of the graph
auto
edge_timestamp
=
this
->
EdgeAttribute
(
edge_timestamp_attr_name
);
auto
edge_timestamp
=
this
->
EdgeAttribute
(
edge_timestamp_attr_name
);
// 4. Call SampleNeighborsImpl
// 4. Call SampleNeighborsImpl
return
SampleNeighborsImpl
(
if
(
layer
)
{
input_nodes
,
return_eids
,
const
int64_t
random_seed
=
RandomEngine
::
ThreadLocal
()
->
RandInt
(
GetTemporalNumPickFn
(
static_cast
<
int64_t
>
(
0
),
std
::
numeric_limits
<
int64_t
>::
max
());
input_nodes_timestamp
,
this
->
indices_
,
fanouts
,
replace
,
SamplerArgs
<
SamplerType
::
LABOR
>
args
{
indices_
,
random_seed
,
NumNodes
()};
type_per_edge_
,
probs_or_mask
,
node_timestamp
,
edge_timestamp
),
return
SampleNeighborsImpl
(
GetTemporalPickFn
(
input_nodes
,
return_eids
,
input_nodes_timestamp
,
this
->
indices_
,
fanouts
,
replace
,
GetTemporalNumPickFn
(
indptr_
.
options
(),
type_per_edge_
,
probs_or_mask
,
node_timestamp
,
input_nodes_timestamp
,
this
->
indices_
,
fanouts
,
replace
,
edge_timestamp
));
type_per_edge_
,
probs_or_mask
,
node_timestamp
,
edge_timestamp
),
GetTemporalPickFn
(
input_nodes_timestamp
,
this
->
indices_
,
fanouts
,
replace
,
indptr_
.
options
(),
type_per_edge_
,
probs_or_mask
,
node_timestamp
,
edge_timestamp
,
args
));
}
else
{
SamplerArgs
<
SamplerType
::
NEIGHBOR
>
args
;
return
SampleNeighborsImpl
(
input_nodes
,
return_eids
,
GetTemporalNumPickFn
(
input_nodes_timestamp
,
this
->
indices_
,
fanouts
,
replace
,
type_per_edge_
,
probs_or_mask
,
node_timestamp
,
edge_timestamp
),
GetTemporalPickFn
(
input_nodes_timestamp
,
this
->
indices_
,
fanouts
,
replace
,
indptr_
.
options
(),
type_per_edge_
,
probs_or_mask
,
node_timestamp
,
edge_timestamp
,
args
));
}
}
}
std
::
tuple
<
torch
::
Tensor
,
torch
::
Tensor
>
std
::
tuple
<
torch
::
Tensor
,
torch
::
Tensor
>
...
@@ -1130,11 +1148,12 @@ static torch::Tensor NonUniformPickOp(
...
@@ -1130,11 +1148,12 @@ static torch::Tensor NonUniformPickOp(
template
<
typename
PickedType
>
template
<
typename
PickedType
>
inline
int64_t
NonUniformPick
(
inline
int64_t
NonUniformPick
(
int64_t
offset
,
int64_t
num_neighbors
,
int64_t
fanout
,
bool
replace
,
int64_t
offset
,
int64_t
num_neighbors
,
int64_t
fanout
,
bool
replace
,
const
torch
::
TensorOptions
&
options
,
const
torch
::
TensorOptions
&
options
,
const
torch
::
Tensor
&
probs_or_mask
,
const
torch
::
optional
<
torch
::
Tensor
>&
probs_or_mask
,
PickedType
*
picked_data_ptr
)
{
PickedType
*
picked_data_ptr
)
{
auto
local_probs
=
auto
local_probs
=
probs_or_mask
.
value
().
slice
(
0
,
offset
,
offset
+
num_neighbors
);
probs_or_mask
.
size
(
0
)
>
num_neighbors
?
probs_or_mask
.
slice
(
0
,
offset
,
offset
+
num_neighbors
)
:
probs_or_mask
;
auto
picked_indices
=
NonUniformPickOp
(
local_probs
,
fanout
,
replace
);
auto
picked_indices
=
NonUniformPickOp
(
local_probs
,
fanout
,
replace
);
auto
picked_indices_ptr
=
picked_indices
.
data_ptr
<
int64_t
>
();
auto
picked_indices_ptr
=
picked_indices
.
data_ptr
<
int64_t
>
();
for
(
int
i
=
0
;
i
<
picked_indices
.
numel
();
++
i
)
{
for
(
int
i
=
0
;
i
<
picked_indices
.
numel
();
++
i
)
{
...
@@ -1152,7 +1171,7 @@ int64_t Pick(
...
@@ -1152,7 +1171,7 @@ int64_t Pick(
SamplerArgs
<
SamplerType
::
NEIGHBOR
>
args
,
PickedType
*
picked_data_ptr
)
{
SamplerArgs
<
SamplerType
::
NEIGHBOR
>
args
,
PickedType
*
picked_data_ptr
)
{
if
(
probs_or_mask
.
has_value
())
{
if
(
probs_or_mask
.
has_value
())
{
return
NonUniformPick
(
return
NonUniformPick
(
offset
,
num_neighbors
,
fanout
,
replace
,
options
,
probs_or_mask
,
offset
,
num_neighbors
,
fanout
,
replace
,
options
,
probs_or_mask
.
value
()
,
picked_data_ptr
);
picked_data_ptr
);
}
else
{
}
else
{
return
UniformPick
(
return
UniformPick
(
...
@@ -1160,14 +1179,14 @@ int64_t Pick(
...
@@ -1160,14 +1179,14 @@ int64_t Pick(
}
}
}
}
template
<
typename
PickedType
>
template
<
SamplerType
S
,
typename
PickedType
>
int64_t
TemporalPick
(
int64_t
TemporalPick
(
torch
::
Tensor
seed_timestamp
,
torch
::
Tensor
csc_indices
,
torch
::
Tensor
seed_timestamp
,
torch
::
Tensor
csc_indices
,
int64_t
seed_offset
,
int64_t
offset
,
int64_t
num_neighbors
,
int64_t
fanout
,
int64_t
seed_offset
,
int64_t
offset
,
int64_t
num_neighbors
,
int64_t
fanout
,
bool
replace
,
const
torch
::
TensorOptions
&
options
,
bool
replace
,
const
torch
::
TensorOptions
&
options
,
const
torch
::
optional
<
torch
::
Tensor
>&
probs_or_mask
,
const
torch
::
optional
<
torch
::
Tensor
>&
probs_or_mask
,
const
torch
::
optional
<
torch
::
Tensor
>&
node_timestamp
,
const
torch
::
optional
<
torch
::
Tensor
>&
node_timestamp
,
const
torch
::
optional
<
torch
::
Tensor
>&
edge_timestamp
,
const
torch
::
optional
<
torch
::
Tensor
>&
edge_timestamp
,
SamplerArgs
<
S
>
args
,
PickedType
*
picked_data_ptr
)
{
PickedType
*
picked_data_ptr
)
{
auto
mask
=
TemporalMask
(
auto
mask
=
TemporalMask
(
utils
::
GetValueByIndex
<
int64_t
>
(
seed_timestamp
,
seed_offset
),
csc_indices
,
utils
::
GetValueByIndex
<
int64_t
>
(
seed_timestamp
,
seed_offset
),
csc_indices
,
...
@@ -1180,13 +1199,20 @@ int64_t TemporalPick(
...
@@ -1180,13 +1199,20 @@ int64_t TemporalPick(
}
else
{
}
else
{
masked_prob
=
mask
.
to
(
torch
::
kFloat32
);
masked_prob
=
mask
.
to
(
torch
::
kFloat32
);
}
}
auto
picked_indices
=
NonUniformPickOp
(
masked_prob
,
fanout
,
replace
);
if
constexpr
(
S
==
SamplerType
::
NEIGHBOR
)
{
auto
picked_indices_ptr
=
picked_indices
.
data_ptr
<
int64_t
>
();
auto
picked_indices
=
NonUniformPickOp
(
masked_prob
,
fanout
,
replace
);
for
(
int
i
=
0
;
i
<
picked_indices
.
numel
();
++
i
)
{
auto
picked_indices_ptr
=
picked_indices
.
data_ptr
<
int64_t
>
();
picked_data_ptr
[
i
]
=
for
(
int
i
=
0
;
i
<
picked_indices
.
numel
();
++
i
)
{
static_cast
<
PickedType
>
(
picked_indices_ptr
[
i
])
+
offset
;
picked_data_ptr
[
i
]
=
static_cast
<
PickedType
>
(
picked_indices_ptr
[
i
])
+
offset
;
}
return
picked_indices
.
numel
();
}
if
constexpr
(
S
==
SamplerType
::
LABOR
)
{
return
Pick
(
offset
,
num_neighbors
,
fanout
,
replace
,
options
,
masked_prob
,
args
,
picked_data_ptr
);
}
}
return
picked_indices
.
numel
();
}
}
template
<
SamplerType
S
,
typename
PickedType
>
template
<
SamplerType
S
,
typename
PickedType
>
...
@@ -1226,7 +1252,7 @@ int64_t PickByEtype(
...
@@ -1226,7 +1252,7 @@ int64_t PickByEtype(
return
pick_offset
;
return
pick_offset
;
}
}
template
<
typename
PickedType
>
template
<
SamplerType
S
,
typename
PickedType
>
int64_t
TemporalPickByEtype
(
int64_t
TemporalPickByEtype
(
torch
::
Tensor
seed_timestamp
,
torch
::
Tensor
csc_indices
,
torch
::
Tensor
seed_timestamp
,
torch
::
Tensor
csc_indices
,
int64_t
seed_offset
,
int64_t
offset
,
int64_t
num_neighbors
,
int64_t
seed_offset
,
int64_t
offset
,
int64_t
num_neighbors
,
...
@@ -1234,7 +1260,7 @@ int64_t TemporalPickByEtype(
...
@@ -1234,7 +1260,7 @@ int64_t TemporalPickByEtype(
const
torch
::
TensorOptions
&
options
,
const
torch
::
Tensor
&
type_per_edge
,
const
torch
::
TensorOptions
&
options
,
const
torch
::
Tensor
&
type_per_edge
,
const
torch
::
optional
<
torch
::
Tensor
>&
probs_or_mask
,
const
torch
::
optional
<
torch
::
Tensor
>&
probs_or_mask
,
const
torch
::
optional
<
torch
::
Tensor
>&
node_timestamp
,
const
torch
::
optional
<
torch
::
Tensor
>&
node_timestamp
,
const
torch
::
optional
<
torch
::
Tensor
>&
edge_timestamp
,
const
torch
::
optional
<
torch
::
Tensor
>&
edge_timestamp
,
SamplerArgs
<
S
>
args
,
PickedType
*
picked_data_ptr
)
{
PickedType
*
picked_data_ptr
)
{
int64_t
etype_begin
=
offset
;
int64_t
etype_begin
=
offset
;
int64_t
etype_end
=
offset
;
int64_t
etype_end
=
offset
;
...
@@ -1258,7 +1284,7 @@ int64_t TemporalPickByEtype(
...
@@ -1258,7 +1284,7 @@ int64_t TemporalPickByEtype(
int64_t
picked_count
=
TemporalPick
(
int64_t
picked_count
=
TemporalPick
(
seed_timestamp
,
csc_indices
,
seed_offset
,
etype_begin
,
seed_timestamp
,
csc_indices
,
seed_offset
,
etype_begin
,
etype_end
-
etype_begin
,
fanout
,
replace
,
options
,
etype_end
-
etype_begin
,
fanout
,
replace
,
options
,
probs_or_mask
,
node_timestamp
,
edge_timestamp
,
probs_or_mask
,
node_timestamp
,
edge_timestamp
,
args
,
picked_data_ptr
+
pick_offset
);
picked_data_ptr
+
pick_offset
);
pick_offset
+=
picked_count
;
pick_offset
+=
picked_count
;
}
}
...
@@ -1278,8 +1304,8 @@ int64_t Pick(
...
@@ -1278,8 +1304,8 @@ int64_t Pick(
if
(
probs_or_mask
.
has_value
())
{
if
(
probs_or_mask
.
has_value
())
{
if
(
fanout
<
0
)
{
if
(
fanout
<
0
)
{
return
NonUniformPick
(
return
NonUniformPick
(
offset
,
num_neighbors
,
fanout
,
replace
,
options
,
probs_or_mask
,
offset
,
num_neighbors
,
fanout
,
replace
,
options
,
picked_data_ptr
);
probs_or_mask
.
value
(),
picked_data_ptr
);
}
else
{
}
else
{
int64_t
picked_count
;
int64_t
picked_count
;
AT_DISPATCH_FLOATING_TYPES
(
AT_DISPATCH_FLOATING_TYPES
(
...
@@ -1365,6 +1391,9 @@ inline int64_t LaborPick(
...
@@ -1365,6 +1391,9 @@ inline int64_t LaborPick(
const
ProbsType
*
local_probs_data
=
const
ProbsType
*
local_probs_data
=
NonUniform
?
probs_or_mask
.
value
().
data_ptr
<
ProbsType
>
()
+
offset
NonUniform
?
probs_or_mask
.
value
().
data_ptr
<
ProbsType
>
()
+
offset
:
nullptr
;
:
nullptr
;
if
(
NonUniform
&&
probs_or_mask
.
value
().
size
(
0
)
<=
num_neighbors
)
{
local_probs_data
-=
offset
;
}
AT_DISPATCH_INTEGRAL_TYPES
(
AT_DISPATCH_INTEGRAL_TYPES
(
args
.
indices
.
scalar_type
(),
"LaborPickMain"
,
([
&
]
{
args
.
indices
.
scalar_type
(),
"LaborPickMain"
,
([
&
]
{
const
scalar_t
*
local_indices_data
=
const
scalar_t
*
local_indices_data
=
...
...
python/dgl/graphbolt/impl/fused_csc_sampling_graph.py
View file @
61504ec5
...
@@ -860,6 +860,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
...
@@ -860,6 +860,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
input_nodes_timestamp
,
input_nodes_timestamp
,
fanouts
.
tolist
(),
fanouts
.
tolist
(),
replace
,
replace
,
False
,
has_original_eids
,
has_original_eids
,
probs_name
,
probs_name
,
node_timestamp_attr_name
,
node_timestamp_attr_name
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment