Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
57281e9f
Unverified
Commit
57281e9f
authored
Jan 03, 2024
by
Rhett Ying
Committed by
GitHub
Jan 03, 2024
Browse files
[GraphBolt] sample with unknown etype (#6888)
parent
e9deff7d
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
89 additions
and
0 deletions
+89
-0
tests/python/pytorch/graphbolt/test_subgraph_sampler.py
tests/python/pytorch/graphbolt/test_subgraph_sampler.py
+89
-0
No files found.
tests/python/pytorch/graphbolt/test_subgraph_sampler.py
View file @
57281e9f
...
@@ -326,6 +326,95 @@ def test_SubgraphSampler_Link_Hetero_With_Negative(sampler_type):
...
@@ -326,6 +326,95 @@ def test_SubgraphSampler_Link_Hetero_With_Negative(sampler_type):
assert
len
(
list
(
datapipe
))
==
5
assert
len
(
list
(
datapipe
))
==
5
@
pytest
.
mark
.
parametrize
(
"sampler_type"
,
[
SamplerType
.
Normal
,
SamplerType
.
Layer
,
SamplerType
.
Temporal
],
)
def
test_SubgraphSampler_Link_Hetero_Unknown_Etype
(
sampler_type
):
graph
=
get_hetero_graph
().
to
(
F
.
ctx
())
first_items
=
torch
.
LongTensor
([[
0
,
0
,
1
,
1
],
[
0
,
2
,
0
,
1
]]).
T
first_names
=
"node_pairs"
second_items
=
torch
.
LongTensor
([[
0
,
0
,
1
,
1
,
2
,
2
],
[
0
,
1
,
1
,
0
,
0
,
1
]]).
T
second_names
=
"node_pairs"
if
sampler_type
==
SamplerType
.
Temporal
:
graph
.
node_attributes
=
{
"timestamp"
:
torch
.
arange
(
graph
.
csc_indptr
.
numel
()
-
1
).
to
(
F
.
ctx
())
}
graph
.
edge_attributes
=
{
"timestamp"
:
torch
.
arange
(
graph
.
indices
.
numel
()).
to
(
F
.
ctx
())
}
first_items
=
(
first_items
,
torch
.
randint
(
0
,
10
,
(
4
,)))
first_names
=
(
first_names
,
"timestamp"
)
second_items
=
(
second_items
,
torch
.
randint
(
0
,
10
,
(
6
,)))
second_names
=
(
second_names
,
"timestamp"
)
# "e11" and "e22" are not valid edge types.
itemset
=
gb
.
ItemSetDict
(
{
"n1:e11:n2"
:
gb
.
ItemSet
(
first_items
,
names
=
first_names
,
),
"n2:e22:n1"
:
gb
.
ItemSet
(
second_items
,
names
=
second_names
,
),
}
)
datapipe
=
gb
.
ItemSampler
(
itemset
,
batch_size
=
2
).
copy_to
(
F
.
ctx
())
num_layer
=
2
fanouts
=
[
torch
.
LongTensor
([
2
])
for
_
in
range
(
num_layer
)]
sampler
=
_get_sampler
(
sampler_type
)
datapipe
=
sampler
(
datapipe
,
graph
,
fanouts
)
datapipe
=
datapipe
.
transform
(
partial
(
gb
.
exclude_seed_edges
))
assert
len
(
list
(
datapipe
))
==
5
@
pytest
.
mark
.
parametrize
(
"sampler_type"
,
[
SamplerType
.
Normal
,
SamplerType
.
Layer
,
SamplerType
.
Temporal
],
)
def
test_SubgraphSampler_Link_Hetero_With_Negative_Unknown_Etype
(
sampler_type
):
graph
=
get_hetero_graph
().
to
(
F
.
ctx
())
first_items
=
torch
.
LongTensor
([[
0
,
0
,
1
,
1
],
[
0
,
2
,
0
,
1
]]).
T
first_names
=
"node_pairs"
second_items
=
torch
.
LongTensor
([[
0
,
0
,
1
,
1
,
2
,
2
],
[
0
,
1
,
1
,
0
,
0
,
1
]]).
T
second_names
=
"node_pairs"
if
sampler_type
==
SamplerType
.
Temporal
:
graph
.
node_attributes
=
{
"timestamp"
:
torch
.
arange
(
graph
.
csc_indptr
.
numel
()
-
1
).
to
(
F
.
ctx
())
}
graph
.
edge_attributes
=
{
"timestamp"
:
torch
.
arange
(
graph
.
indices
.
numel
()).
to
(
F
.
ctx
())
}
first_items
=
(
first_items
,
torch
.
randint
(
0
,
10
,
(
4
,)))
first_names
=
(
first_names
,
"timestamp"
)
second_items
=
(
second_items
,
torch
.
randint
(
0
,
10
,
(
6
,)))
second_names
=
(
second_names
,
"timestamp"
)
# "e11" and "e22" are not valid edge types.
itemset
=
gb
.
ItemSetDict
(
{
"n1:e11:n2"
:
gb
.
ItemSet
(
first_items
,
names
=
first_names
,
),
"n2:e22:n1"
:
gb
.
ItemSet
(
second_items
,
names
=
second_names
,
),
}
)
datapipe
=
gb
.
ItemSampler
(
itemset
,
batch_size
=
2
).
copy_to
(
F
.
ctx
())
num_layer
=
2
fanouts
=
[
torch
.
LongTensor
([
2
])
for
_
in
range
(
num_layer
)]
datapipe
=
gb
.
UniformNegativeSampler
(
datapipe
,
graph
,
1
)
sampler
=
_get_sampler
(
sampler_type
)
datapipe
=
sampler
(
datapipe
,
graph
,
fanouts
)
datapipe
=
datapipe
.
transform
(
partial
(
gb
.
exclude_seed_edges
))
assert
len
(
list
(
datapipe
))
==
5
@
unittest
.
skipIf
(
@
unittest
.
skipIf
(
F
.
_default_context_str
!=
"cpu"
,
F
.
_default_context_str
!=
"cpu"
,
reason
=
"Sampling with replacement not yet supported on GPU."
,
reason
=
"Sampling with replacement not yet supported on GPU."
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment