Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
898af658
Unverified
Commit
898af658
authored
Dec 27, 2023
by
czkkkkkk
Committed by
GitHub
Dec 27, 2023
Browse files
[Graphbolt] Add temporal sampling unittests. (#6795)
parent
1f9ae668
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
257 additions
and
0 deletions
+257
-0
tests/python/pytorch/graphbolt/impl/test_fused_csc_sampling_graph.py
...n/pytorch/graphbolt/impl/test_fused_csc_sampling_graph.py
+257
-0
No files found.
tests/python/pytorch/graphbolt/impl/test_fused_csc_sampling_graph.py
View file @
898af658
...
@@ -12,6 +12,8 @@ import dgl.graphbolt as gb
...
@@ -12,6 +12,8 @@ import dgl.graphbolt as gb
import
pytest
import
pytest
import
torch
import
torch
import
torch.multiprocessing
as
mp
import
torch.multiprocessing
as
mp
from
dgl.graphbolt.base
import
etype_str_to_tuple
from
scipy
import
sparse
as
spsp
from
scipy
import
sparse
as
spsp
from
..
import
gb_test_utils
as
gbt
from
..
import
gb_test_utils
as
gbt
...
@@ -1001,6 +1003,124 @@ def test_sample_neighbors_homo(labor, indptr_dtype, indices_dtype):
...
@@ -1001,6 +1003,124 @@ def test_sample_neighbors_homo(labor, indptr_dtype, indices_dtype):
assert
subgraph
.
original_edge_ids
is
None
assert
subgraph
.
original_edge_ids
is
None
@unittest.skipIf(
    F._default_context_str == "gpu",
    reason="Graph is CPU only at present.",
)
@pytest.mark.parametrize("indptr_dtype", [torch.int32, torch.int64])
@pytest.mark.parametrize("indices_dtype", [torch.int32, torch.int64])
@pytest.mark.parametrize("replace", [False, True])
@pytest.mark.parametrize("use_node_timestamp", [False, True])
@pytest.mark.parametrize("use_edge_timestamp", [False, True])
def test_temporal_sample_neighbors_homo(
    indptr_dtype,
    indices_dtype,
    replace,
    use_node_timestamp,
    use_edge_timestamp,
):
    """Temporal neighbor sampling on a homogeneous graph.

    The per-seed sample counts and the sampled neighbor ids are checked
    against a brute-force oracle computed from the raw CSC arrays and the
    randomly drawn timestamps.

    Original graph in COO:
    1   0   1   0   1
    1   0   1   1   0
    0   1   0   1   0
    0   1   0   0   1
    1   0   0   0   1
    """
    # Raw CSC data of the graph under test.
    total_num_nodes = 5
    total_num_edges = 12
    indptr = torch.tensor([0, 3, 5, 7, 9, 12], dtype=indptr_dtype)
    indices = torch.tensor(
        [0, 1, 4, 2, 3, 0, 1, 1, 2, 0, 3, 4], dtype=indices_dtype
    )
    assert indptr[-1] == total_num_edges
    assert indptr[-1] == len(indices)
    assert len(indptr) == total_num_nodes + 1

    # Build the FusedCSCSamplingGraph.
    graph = gb.fused_csc_sampling_graph(indptr, indices)

    # Sampling configuration: a single hop with fanout 2.
    fanouts = torch.LongTensor([2])
    sampler = graph.temporal_sample_neighbors

    seed_list = [1, 3, 4]
    seed_timestamp = torch.randint(
        0, 100, (len(seed_list),), dtype=torch.int64
    )
    if use_node_timestamp:
        node_timestamp = torch.randint(
            0, 100, (total_num_nodes,), dtype=torch.int64
        )
        graph.node_attributes = {"timestamp": node_timestamp}
    if use_edge_timestamp:
        edge_timestamp = torch.randint(
            0, 100, (total_num_edges,), dtype=torch.int64
        )
        graph.edge_attributes = {"timestamp": edge_timestamp}

    # Keyword arguments shared by both sampler invocations below.
    sampling_kwargs = {
        "replace": replace,
        "node_timestamp_attr_name": (
            "timestamp" if use_node_timestamp else None
        ),
        "edge_timestamp_attr_name": (
            "timestamp" if use_edge_timestamp else None
        ),
    }

    # Seeds whose dtype differs from the graph's indices must be rejected.
    if indices_dtype == torch.int32:
        mismatched_dtype = torch.int64
    else:
        mismatched_dtype = torch.int32
    mismatched_nodes = torch.tensor(seed_list, dtype=mismatched_dtype)
    expected_message = re.escape(
        "Data type of nodes must be consistent with indices.dtype"
    )
    with pytest.raises(AssertionError, match=expected_message):
        sampler(mismatched_nodes, seed_timestamp, fanouts, **sampling_kwargs)

    # Sample with correctly typed seeds; only the subgraph is inspected.
    nodes = torch.tensor(seed_list, dtype=indices_dtype)
    subgraph, _ = sampler(nodes, seed_timestamp, fanouts, **sampling_kwargs)
    sampled_count = torch.diff(subgraph.node_pairs.indptr).tolist()

    def _passes_time_filter(seed_pos, edge_pos):
        """Return False if an enabled timestamp exceeds the seed's."""
        seed_time = seed_timestamp[seed_pos]
        if use_node_timestamp:
            neighbor = indices[edge_pos].item()
            if (node_timestamp[neighbor] > seed_time).item():
                return False
        if use_edge_timestamp:
            if (edge_timestamp[edge_pos] > seed_time).item():
                return False
        return True

    # Brute-force oracle: for every seed, the neighbors whose node/edge
    # timestamps are not later than the seed's timestamp.
    available_neighbors = [
        [
            indices[edge_pos].item()
            for edge_pos in range(
                indptr[seed].item(), indptr[seed + 1].item()
            )
            if _passes_time_filter(seed_pos, edge_pos)
        ]
        for seed_pos, seed in enumerate(seed_list)
    ]

    # Check how many neighbors were sampled per seed.
    fanout = fanouts[0]
    for i, count in enumerate(sampled_count):
        candidates = available_neighbors[i]
        if replace:
            # With replacement the full fanout is drawn whenever at least
            # one candidate exists.
            expect_count = fanout if len(candidates) > 0 else 0
        else:
            expect_count = min(fanout, len(candidates))
        assert count == expect_count

    # Check that every sampled neighbor is one the oracle allows.
    sampled_neighbors = torch.split(
        subgraph.node_pairs.indices, sampled_count
    )
    for i, neighbors in enumerate(sampled_neighbors):
        assert set(neighbors.tolist()).issubset(set(available_neighbors[i]))
@
unittest
.
skipIf
(
@
unittest
.
skipIf
(
F
.
_default_context_str
==
"gpu"
,
F
.
_default_context_str
==
"gpu"
,
reason
=
"Graph is CPU only at present."
,
reason
=
"Graph is CPU only at present."
,
...
@@ -1137,6 +1257,143 @@ def test_sample_neighbors_hetero(labor, indptr_dtype, indices_dtype):
...
@@ -1137,6 +1257,143 @@ def test_sample_neighbors_hetero(labor, indptr_dtype, indices_dtype):
assert
subgraph
.
original_edge_ids
is
None
assert
subgraph
.
original_edge_ids
is
None
@unittest.skipIf(
    F._default_context_str == "gpu",
    reason="Graph is CPU only at present.",
)
@pytest.mark.parametrize("indptr_dtype", [torch.int32, torch.int64])
@pytest.mark.parametrize("indices_dtype", [torch.int32, torch.int64])
@pytest.mark.parametrize("replace", [False, True])
@pytest.mark.parametrize("use_node_timestamp", [False, True])
@pytest.mark.parametrize("use_edge_timestamp", [False, True])
def test_temporal_sample_neighbors_hetero(
    indptr_dtype,
    indices_dtype,
    replace,
    use_node_timestamp,
    use_edge_timestamp,
):
    """Temporal neighbor sampling on a heterogeneous graph.

    All fanouts are -1, so each seed is expected to receive exactly the
    neighbors that a brute-force oracle over the raw CSC arrays and the
    randomly drawn timestamps allows.

    Original graph in COO:
    "n1:e1:n2":[0, 0, 1, 1, 1], [0, 2, 0, 1, 2]
    "n2:e2:n1":[0, 0, 1, 2], [0, 1, 1 ,0]
    0   0   1   0   1
    0   0   1   1   1
    1   1   0   0   0
    0   1   0   0   0
    1   0   0   0   0
    """
    # Type metadata and raw CSC data of the graph under test.
    ntypes = {"n1": 0, "n2": 1}
    etypes = {"n1:e1:n2": 0, "n2:e2:n1": 1}
    ntypes_to_offset = {"n1": 0, "n2": 2}
    total_num_nodes = 5
    total_num_edges = 9
    indptr = torch.tensor([0, 2, 4, 6, 7, 9], dtype=indptr_dtype)
    indices = torch.tensor([2, 4, 2, 3, 0, 1, 1, 0, 1], dtype=indices_dtype)
    type_per_edge = torch.LongTensor([1, 1, 1, 1, 0, 0, 0, 0, 0])
    node_type_offset = torch.LongTensor([0, 2, 5])
    assert indptr[-1] == total_num_edges
    assert indptr[-1] == len(indices)

    # Build the FusedCSCSamplingGraph.
    graph = gb.fused_csc_sampling_graph(
        indptr,
        indices,
        node_type_offset=node_type_offset,
        type_per_edge=type_per_edge,
        node_type_to_id=ntypes,
        edge_type_to_id=etypes,
    )

    # Sampling configuration: one fanout entry of -1 per edge type.
    fanouts = torch.LongTensor([-1, -1])
    sampler = graph.temporal_sample_neighbors

    seeds = {
        "n1": torch.tensor([0], dtype=indices_dtype),
        "n2": torch.tensor([0], dtype=indices_dtype),
    }
    # For each edge type, the position (within the concatenated seed list)
    # of every destination node present in that edge type's sampled CSC.
    per_etype_destination_nodes = {
        "n1:e1:n2": torch.tensor([1], dtype=indices_dtype),
        "n2:e2:n1": torch.tensor([0], dtype=indices_dtype),
    }
    seed_timestamp = {
        "n1": torch.randint(0, 100, (1,), dtype=torch.int64),
        "n2": torch.randint(0, 100, (1,), dtype=torch.int64),
    }
    if use_node_timestamp:
        node_timestamp = torch.randint(
            0, 100, (total_num_nodes,), dtype=torch.int64
        )
        graph.node_attributes = {"timestamp": node_timestamp}
    if use_edge_timestamp:
        edge_timestamp = torch.randint(
            0, 100, (total_num_edges,), dtype=torch.int64
        )
        graph.edge_attributes = {"timestamp": edge_timestamp}

    # Only the subgraph is inspected; the second return value is unused.
    subgraph, _ = sampler(
        seeds,
        seed_timestamp,
        fanouts,
        replace=replace,
        node_timestamp_attr_name=(
            "timestamp" if use_node_timestamp else None
        ),
        edge_timestamp_attr_name=(
            "timestamp" if use_edge_timestamp else None
        ),
    )

    # Flatten the per-type seeds and timestamps into homogeneous ids by
    # shifting each node type by its offset.
    homo_seeds = torch.cat(
        [
            type_nodes + node_type_offset[ntypes[ntype]]
            for ntype, type_nodes in seeds.items()
        ]
    )
    homo_seed_timestamp = torch.cat(
        [seed_timestamp[ntype] for ntype in seeds]
    )

    def _passes_time_filter(seed_pos, edge_pos):
        """Return False if an enabled timestamp exceeds the seed's."""
        seed_time = homo_seed_timestamp[seed_pos]
        if use_node_timestamp:
            neighbor = indices[edge_pos].item()
            if (node_timestamp[neighbor] > seed_time).item():
                return False
        if use_edge_timestamp:
            if (edge_timestamp[edge_pos] > seed_time).item():
                return False
        return True

    # Brute-force oracle: for every homogeneous seed, the neighbors whose
    # node/edge timestamps are not later than the seed's timestamp.
    available_neighbors = [
        [
            indices[edge_pos].item()
            for edge_pos in range(
                indptr[seed].item(), indptr[seed + 1].item()
            )
            if _passes_time_filter(seed_pos, edge_pos)
        ]
        for seed_pos, seed in enumerate(homo_seeds)
    ]

    # Accumulate, per homogeneous seed, what the sampler returned across
    # all edge types.
    num_seeds = homo_seeds.numel()
    sampled_count = [0] * num_seeds
    sampled_neighbors = [[] for _ in range(num_seeds)]
    for etype, csc in subgraph.node_pairs.items():
        src_type, _, _ = etype_str_to_tuple(etype)
        src_offset = ntypes_to_offset[src_type]
        dest_positions = per_etype_destination_nodes[etype].tolist()
        for col, seed_pos in enumerate(dest_positions):
            begin = csc.indptr[col]
            end = csc.indptr[col + 1]
            # Shift type-local source ids back to homogeneous ids.
            homo_sources = csc.indices[begin:end] + src_offset
            sampled_neighbors[seed_pos].extend(homo_sources.tolist())
            sampled_count[seed_pos] += end - begin

    # With fanout -1, every available neighbor is expected to be sampled.
    for i, count in enumerate(sampled_count):
        assert count == len(available_neighbors[i])
        assert set(sampled_neighbors[i]).issubset(
            set(available_neighbors[i])
        )
@
unittest
.
skipIf
(
@
unittest
.
skipIf
(
F
.
_default_context_str
==
"gpu"
,
F
.
_default_context_str
==
"gpu"
,
reason
=
"Graph is CPU only at present."
,
reason
=
"Graph is CPU only at present."
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment