Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
3742d5ff
Unverified
Commit
3742d5ff
authored
Sep 04, 2023
by
Rhett Ying
Committed by
GitHub
Sep 04, 2023
Browse files
[GraphBolt] rename data attributes in MiniBatch (#6274)
parent
72683697
Changes
9
Show whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
131 additions
and
131 deletions
+131
-131
python/dgl/graphbolt/impl/neighbor_sampler.py
python/dgl/graphbolt/impl/neighbor_sampler.py
+4
-4
python/dgl/graphbolt/item_sampler.py
python/dgl/graphbolt/item_sampler.py
+3
-3
python/dgl/graphbolt/minibatch.py
python/dgl/graphbolt/minibatch.py
+24
-24
python/dgl/graphbolt/negative_sampler.py
python/dgl/graphbolt/negative_sampler.py
+20
-20
python/dgl/graphbolt/subgraph_sampler.py
python/dgl/graphbolt/subgraph_sampler.py
+28
-28
tests/python/pytorch/graphbolt/gb_test_utils.py
tests/python/pytorch/graphbolt/gb_test_utils.py
+2
-2
tests/python/pytorch/graphbolt/impl/test_negative_sampler.py
tests/python/pytorch/graphbolt/impl/test_negative_sampler.py
+11
-11
tests/python/pytorch/graphbolt/impl/test_ondisk_dataset.py
tests/python/pytorch/graphbolt/impl/test_ondisk_dataset.py
+38
-38
tests/python/pytorch/graphbolt/test_subgraph_sampler.py
tests/python/pytorch/graphbolt/test_subgraph_sampler.py
+1
-1
No files found.
python/dgl/graphbolt/impl/neighbor_sampler.py
View file @
3742d5ff
...
...
@@ -54,7 +54,7 @@ class NeighborSampler(SubgraphSampler):
>>> import dgl.graphbolt as gb
>>> from torchdata.datapipes.iter import Mapper
>>> def minibatch_link_collator(data):
... minibatch = gb.MiniBatch(node_pair=data)
... minibatch = gb.MiniBatch(node_pair
s
=data)
... return minibatch
...
>>> from dgl import graphbolt as gb
...
...
@@ -76,7 +76,7 @@ class NeighborSampler(SubgraphSampler):
>>> subgraph_sampler = gb.NeighborSampler(
...neg_sampler, graph, fanouts)
>>> for data in subgraph_sampler:
... print(data.compacted_node_pair)
... print(data.compacted_node_pair
s
)
... print(len(data.sampled_subgraphs))
(tensor([0, 0, 0]), tensor([1, 0, 2]))
3
...
...
@@ -166,7 +166,7 @@ class LayerNeighborSampler(NeighborSampler):
>>> import dgl.graphbolt as gb
>>> from torchdata.datapipes.iter import Mapper
>>> def minibatch_link_collator(data):
... minibatch = gb.MiniBatch(node_pair=data)
... minibatch = gb.MiniBatch(node_pair
s
=data)
... return minibatch
...
>>> from dgl import graphbolt as gb
...
...
@@ -188,7 +188,7 @@ class LayerNeighborSampler(NeighborSampler):
>>> subgraph_sampler = gb.LayerNeighborSampler(
...neg_sampler, graph, fanouts)
>>> for data in subgraph_sampler:
... print(data.compacted_node_pair)
... print(data.compacted_node_pair
s
)
... print(len(data.sampled_subgraphs))
(tensor([0, 0, 0]), tensor([1, 0, 2]))
3
...
...
python/dgl/graphbolt/item_sampler.py
View file @
3742d5ff
...
...
@@ -17,9 +17,9 @@ __all__ = ["ItemSampler"]
class
ItemSampler
(
IterDataPipe
):
"""Item Sampler.
Creates item subset of data which could be node
/edge
IDs, node pairs with
or
without labels,
head/tail/negative_tails, DGLGraphs and heterogeneou
s
counterparts.
Creates item subset of data which could be node IDs, node pairs with
or
without labels,
node pairs with negative sources/destinations, DGLGraph
s
and heterogeneous
counterparts.
Note: This class `ItemSampler` is not decorated with
`torchdata.datapipes.functional_datapipe` on purpose. This indicates it
...
...
python/dgl/graphbolt/minibatch.py
View file @
3742d5ff
...
...
@@ -55,74 +55,74 @@ class MiniBatch:
value should be corresponding heterogeneous node id.
"""
seed_node
:
Union
[
torch
.
Tensor
,
Dict
[
str
,
torch
.
Tensor
]]
=
None
seed_node
s
:
Union
[
torch
.
Tensor
,
Dict
[
str
,
torch
.
Tensor
]]
=
None
"""
Representation of seed nodes used for sampling in the graph.
- If `seed_node` is a tensor: It indicates the graph is homogeneous.
- If `seed_node` is a dictionary: The keys should be node type and the
- If `seed_node
s
` is a tensor: It indicates the graph is homogeneous.
- If `seed_node
s
` is a dictionary: The keys should be node type and the
value should be corresponding heterogeneous node ids.
"""
label
:
Union
[
torch
.
Tensor
,
Dict
[
str
,
torch
.
Tensor
]]
=
None
label
s
:
Union
[
torch
.
Tensor
,
Dict
[
str
,
torch
.
Tensor
]]
=
None
"""
L
abels associated with seed nodes in the graph.
- If `label` is a tensor: It indicates the graph is homogeneous. The value
should be corresponding labels to given 'seed_node' or 'node_pair'.
- If `label` is a dictionary: The keys should be node or edge type and the
value should be corresponding labels to given 'seed_node' or 'node_pair'.
l
abels
s
associated with seed nodes in the graph.
- If `label
s
` is a tensor: It indicates the graph is homogeneous. The value
should be corresponding labels
s
to given 'seed_node
s
' or 'node_pair
s
'.
- If `label
s
` is a dictionary: The keys should be node or edge type and the
value should be corresponding labels
s
to given 'seed_node
s
' or 'node_pair
s
'.
"""
node_pair
:
Union
[
node_pair
s
:
Union
[
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
],
Dict
[
str
,
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]],
]
=
None
"""
Representation of seed node pairs utilized in link prediction tasks.
- If `node_pair` is a tuple: It indicates a homogeneous graph where each
- If `node_pair
s
` is a tuple: It indicates a homogeneous graph where each
tuple contains two tensors representing source-destination node pairs.
- If `node_pair` is a dictionary: The keys should be edge type, and the
- If `node_pair
s
` is a dictionary: The keys should be edge type, and the
value should be a tuple of tensors representing node pairs of the given
type.
"""
negative_
head
:
Union
[
torch
.
Tensor
,
Dict
[
str
,
torch
.
Tensor
]]
=
None
negative_
srcs
:
Union
[
torch
.
Tensor
,
Dict
[
str
,
torch
.
Tensor
]]
=
None
"""
Representation of negative samples for the head nodes in the link
prediction task.
- If `negative_
head
` is a tensor: It indicates a homogeneous graph.
- If `negative_
head
` is a dictionary: The key should be edge type, and the
- If `negative_
srcs
` is a tensor: It indicates a homogeneous graph.
- If `negative_
srcs
` is a dictionary: The key should be edge type, and the
value should correspond to the negative samples for head nodes of the
given type.
"""
negative_
tail
:
Union
[
torch
.
Tensor
,
Dict
[
str
,
torch
.
Tensor
]]
=
None
negative_
dsts
:
Union
[
torch
.
Tensor
,
Dict
[
str
,
torch
.
Tensor
]]
=
None
"""
Representation of negative samples for the tail nodes in the link
prediction task.
- If `negative_
tail
` is a tensor: It indicates a homogeneous graph.
- If `negative_
tail
` is a dictionary: The key should be edge type, and the
- If `negative_
dsts
` is a tensor: It indicates a homogeneous graph.
- If `negative_
dsts
` is a dictionary: The key should be edge type, and the
value should correspond to the negative samples for head nodes of the
given type.
"""
compacted_node_pair
:
Union
[
compacted_node_pair
s
:
Union
[
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
],
Dict
[
str
,
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]],
]
=
None
"""
Representation of compacted node pairs corresponding to 'node_pair', where
Representation of compacted node pairs corresponding to 'node_pair
s
', where
all node ids inside are compacted.
"""
compacted_negative_
head
:
Union
[
torch
.
Tensor
,
Dict
[
str
,
torch
.
Tensor
]]
=
None
compacted_negative_
srcs
:
Union
[
torch
.
Tensor
,
Dict
[
str
,
torch
.
Tensor
]]
=
None
"""
Representation of compacted nodes corresponding to 'negative_
head
', where
Representation of compacted nodes corresponding to 'negative_
srcs
', where
all node ids inside are compacted.
"""
compacted_negative_
tail
:
Union
[
torch
.
Tensor
,
Dict
[
str
,
torch
.
Tensor
]]
=
None
compacted_negative_
dsts
:
Union
[
torch
.
Tensor
,
Dict
[
str
,
torch
.
Tensor
]]
=
None
"""
Representation of compacted nodes corresponding to 'negative_
tail
', where
Representation of compacted nodes corresponding to 'negative_
dsts
', where
all node ids inside are compacted.
"""
...
...
python/dgl/graphbolt/negative_sampler.py
View file @
3742d5ff
...
...
@@ -44,9 +44,9 @@ class NegativeSampler(Mapper):
Parameters
----------
minibatch : MiniBatch
An instance of 'MiniBatch' class requires the 'node_pair' field.
An instance of 'MiniBatch' class requires the 'node_pair
s
' field.
This function is responsible for generating negative edges
corresponding to the positive edges defined by the 'node_pair'. In
corresponding to the positive edges defined by the 'node_pair
s
'. In
cases where negative edges already exist, this function will
overwrite them.
...
...
@@ -56,21 +56,21 @@ class NegativeSampler(Mapper):
An instance of 'MiniBatch' encompasses both positive and negative
samples.
"""
node_pairs
=
minibatch
.
node_pair
node_pairs
=
minibatch
.
node_pair
s
assert
node_pairs
is
not
None
if
isinstance
(
node_pairs
,
Mapping
):
if
self
.
output_format
==
LinkPredictionEdgeFormat
.
INDEPENDENT
:
minibatch
.
label
=
{}
minibatch
.
label
s
=
{}
else
:
minibatch
.
negative_
head
,
minibatch
.
negative_
tail
=
{},
{}
minibatch
.
negative_
srcs
,
minibatch
.
negative_
dsts
=
{},
{}
for
etype
,
pos_pairs
in
node_pairs
.
items
():
self
.
_collate
(
minibatch
,
self
.
_sample_with_etype
(
pos_pairs
,
etype
),
etype
)
if
self
.
output_format
==
LinkPredictionEdgeFormat
.
HEAD_CONDITIONED
:
minibatch
.
negative_
tail
=
None
minibatch
.
negative_
dsts
=
None
if
self
.
output_format
==
LinkPredictionEdgeFormat
.
TAIL_CONDITIONED
:
minibatch
.
negative_
head
=
None
minibatch
.
negative_
srcs
=
None
else
:
self
.
_collate
(
minibatch
,
self
.
_sample_with_etype
(
node_pairs
))
return
minibatch
...
...
@@ -111,23 +111,23 @@ class NegativeSampler(Mapper):
Canonical edge type.
"""
pos_src
,
pos_dst
=
(
minibatch
.
node_pair
[
etype
]
minibatch
.
node_pair
s
[
etype
]
if
etype
is
not
None
else
minibatch
.
node_pair
else
minibatch
.
node_pair
s
)
neg_src
,
neg_dst
=
neg_pairs
if
self
.
output_format
==
LinkPredictionEdgeFormat
.
INDEPENDENT
:
pos_label
=
torch
.
ones_like
(
pos_src
)
neg_label
=
torch
.
zeros_like
(
neg_src
)
pos_label
s
=
torch
.
ones_like
(
pos_src
)
neg_label
s
=
torch
.
zeros_like
(
neg_src
)
src
=
torch
.
cat
([
pos_src
,
neg_src
])
dst
=
torch
.
cat
([
pos_dst
,
neg_dst
])
label
=
torch
.
cat
([
pos_label
,
neg_label
])
label
s
=
torch
.
cat
([
pos_label
s
,
neg_label
s
])
if
etype
is
not
None
:
minibatch
.
node_pair
[
etype
]
=
(
src
,
dst
)
minibatch
.
label
[
etype
]
=
label
minibatch
.
node_pair
s
[
etype
]
=
(
src
,
dst
)
minibatch
.
label
s
[
etype
]
=
label
s
else
:
minibatch
.
node_pair
=
(
src
,
dst
)
minibatch
.
label
=
label
minibatch
.
node_pair
s
=
(
src
,
dst
)
minibatch
.
label
s
=
label
s
else
:
if
self
.
output_format
==
LinkPredictionEdgeFormat
.
CONDITIONED
:
neg_src
=
neg_src
.
view
(
-
1
,
self
.
negative_ratio
)
...
...
@@ -147,8 +147,8 @@ class NegativeSampler(Mapper):
f
"Unsupported output format
{
self
.
output_format
}
."
)
if
etype
is
not
None
:
minibatch
.
negative_
head
[
etype
]
=
neg_src
minibatch
.
negative_
tail
[
etype
]
=
neg_dst
minibatch
.
negative_
srcs
[
etype
]
=
neg_src
minibatch
.
negative_
dsts
[
etype
]
=
neg_dst
else
:
minibatch
.
negative_
head
=
neg_src
minibatch
.
negative_
tail
=
neg_dst
minibatch
.
negative_
srcs
=
neg_src
minibatch
.
negative_
dsts
=
neg_dst
python/dgl/graphbolt/subgraph_sampler.py
View file @
3742d5ff
...
...
@@ -28,19 +28,19 @@ class SubgraphSampler(Mapper):
super
().
__init__
(
datapipe
,
self
.
_sample
)
def
_sample
(
self
,
minibatch
):
if
minibatch
.
node_pair
is
not
None
:
if
minibatch
.
node_pair
s
is
not
None
:
(
seeds
,
minibatch
.
compacted_node_pair
,
minibatch
.
compacted_negative_
head
,
minibatch
.
compacted_negative_
tail
,
)
=
self
.
_node_pair_preprocess
(
minibatch
)
elif
minibatch
.
seed_node
is
not
None
:
seeds
=
minibatch
.
seed_node
minibatch
.
compacted_node_pair
s
,
minibatch
.
compacted_negative_
srcs
,
minibatch
.
compacted_negative_
dsts
,
)
=
self
.
_node_pair
s
_preprocess
(
minibatch
)
elif
minibatch
.
seed_node
s
is
not
None
:
seeds
=
minibatch
.
seed_node
s
else
:
raise
ValueError
(
f
"Invalid minibatch
{
minibatch
}
: Either 'node_pair' or
\
'seed_node' should have a value."
f
"Invalid minibatch
{
minibatch
}
: Either 'node_pair
s
' or
\
'seed_node
s
' should have a value."
)
(
minibatch
.
input_nodes
,
...
...
@@ -48,16 +48,16 @@ class SubgraphSampler(Mapper):
)
=
self
.
_sample_subgraphs
(
seeds
)
return
minibatch
def
_node_pair_preprocess
(
self
,
minibatch
):
node_pair
=
minibatch
.
node_pair
neg_src
,
neg_dst
=
minibatch
.
negative_
head
,
minibatch
.
negative_
tail
def
_node_pair
s
_preprocess
(
self
,
minibatch
):
node_pair
s
=
minibatch
.
node_pair
s
neg_src
,
neg_dst
=
minibatch
.
negative_
srcs
,
minibatch
.
negative_
dsts
has_neg_src
=
neg_src
is
not
None
has_neg_dst
=
neg_dst
is
not
None
is_heterogeneous
=
isinstance
(
node_pair
,
Dict
)
is_heterogeneous
=
isinstance
(
node_pair
s
,
Dict
)
if
is_heterogeneous
:
# Collect nodes from all types of input.
nodes
=
defaultdict
(
list
)
for
etype
,
(
src
,
dst
)
in
node_pair
.
items
():
for
etype
,
(
src
,
dst
)
in
node_pair
s
.
items
():
src_type
,
_
,
dst_type
=
etype_str_to_tuple
(
etype
)
nodes
[
src_type
].
append
(
src
)
nodes
[
dst_type
].
append
(
dst
)
...
...
@@ -72,27 +72,27 @@ class SubgraphSampler(Mapper):
# Unique and compact the collected nodes.
seeds
,
compacted
=
unique_and_compact
(
nodes
)
(
compacted_node_pair
,
compacted_negative_
head
,
compacted_negative_
tail
,
compacted_node_pair
s
,
compacted_negative_
srcs
,
compacted_negative_
dsts
,
)
=
({},
{},
{})
# Map back in same order as collect.
for
etype
,
_
in
node_pair
.
items
():
for
etype
,
_
in
node_pair
s
.
items
():
src_type
,
_
,
dst_type
=
etype_str_to_tuple
(
etype
)
src
=
compacted
[
src_type
].
pop
(
0
)
dst
=
compacted
[
dst_type
].
pop
(
0
)
compacted_node_pair
[
etype
]
=
(
src
,
dst
)
compacted_node_pair
s
[
etype
]
=
(
src
,
dst
)
if
has_neg_src
:
for
etype
,
_
in
neg_src
.
items
():
src_type
,
_
,
_
=
etype_str_to_tuple
(
etype
)
compacted_negative_
head
[
etype
]
=
compacted
[
src_type
].
pop
(
0
)
compacted_negative_
srcs
[
etype
]
=
compacted
[
src_type
].
pop
(
0
)
if
has_neg_dst
:
for
etype
,
_
in
neg_dst
.
items
():
_
,
_
,
dst_type
=
etype_str_to_tuple
(
etype
)
compacted_negative_
tail
[
etype
]
=
compacted
[
dst_type
].
pop
(
0
)
compacted_negative_
dsts
[
etype
]
=
compacted
[
dst_type
].
pop
(
0
)
else
:
# Collect nodes from all types of input.
nodes
=
list
(
node_pair
)
nodes
=
list
(
node_pair
s
)
if
has_neg_src
:
nodes
.
append
(
neg_src
.
view
(
-
1
))
if
has_neg_dst
:
...
...
@@ -100,17 +100,17 @@ class SubgraphSampler(Mapper):
# Unique and compact the collected nodes.
seeds
,
compacted
=
unique_and_compact
(
nodes
)
# Map back in same order as collect.
compacted_node_pair
=
tuple
(
compacted
[:
2
])
compacted_node_pair
s
=
tuple
(
compacted
[:
2
])
compacted
=
compacted
[
2
:]
if
has_neg_src
:
compacted_negative_
head
=
compacted
.
pop
(
0
)
compacted_negative_
srcs
=
compacted
.
pop
(
0
)
if
has_neg_dst
:
compacted_negative_
tail
=
compacted
.
pop
(
0
)
compacted_negative_
dsts
=
compacted
.
pop
(
0
)
return
(
seeds
,
compacted_node_pair
,
compacted_negative_
head
if
has_neg_src
else
None
,
compacted_negative_
tail
if
has_neg_dst
else
None
,
compacted_node_pair
s
,
compacted_negative_
srcs
if
has_neg_src
else
None
,
compacted_negative_
dsts
if
has_neg_dst
else
None
,
)
def
_sample_subgraphs
(
self
,
seeds
):
...
...
tests/python/pytorch/graphbolt/gb_test_utils.py
View file @
3742d5ff
...
...
@@ -9,12 +9,12 @@ import torch
def
minibatch_node_collator
(
data
):
minibatch
=
gb
.
MiniBatch
(
seed_node
=
data
)
minibatch
=
gb
.
MiniBatch
(
seed_node
s
=
data
)
return
minibatch
def
minibatch_link_collator
(
data
):
minibatch
=
gb
.
MiniBatch
(
node_pair
=
data
)
minibatch
=
gb
.
MiniBatch
(
node_pair
s
=
data
)
return
minibatch
...
...
tests/python/pytorch/graphbolt/impl/test_negative_sampler.py
View file @
3742d5ff
...
...
@@ -30,14 +30,14 @@ def test_NegativeSampler_Independent_Format(negative_ratio):
)
# Perform Negative sampling.
for
data
in
negative_sampler
:
src
,
dst
=
data
.
node_pair
label
=
data
.
label
src
,
dst
=
data
.
node_pair
s
label
s
=
data
.
label
s
# Assertation
assert
len
(
src
)
==
batch_size
*
(
negative_ratio
+
1
)
assert
len
(
dst
)
==
batch_size
*
(
negative_ratio
+
1
)
assert
len
(
label
)
==
batch_size
*
(
negative_ratio
+
1
)
assert
torch
.
all
(
torch
.
eq
(
label
[:
batch_size
],
1
))
assert
torch
.
all
(
torch
.
eq
(
label
[
batch_size
:],
0
))
assert
len
(
label
s
)
==
batch_size
*
(
negative_ratio
+
1
)
assert
torch
.
all
(
torch
.
eq
(
label
s
[:
batch_size
],
1
))
assert
torch
.
all
(
torch
.
eq
(
label
s
[
batch_size
:],
0
))
@
pytest
.
mark
.
parametrize
(
"negative_ratio"
,
[
1
,
5
,
10
,
20
])
...
...
@@ -65,8 +65,8 @@ def test_NegativeSampler_Conditioned_Format(negative_ratio):
)
# Perform Negative sampling.
for
data
in
negative_sampler
:
pos_src
,
pos_dst
=
data
.
node_pair
neg_src
,
neg_dst
=
data
.
negative_
head
,
data
.
negative_
tail
pos_src
,
pos_dst
=
data
.
node_pair
s
neg_src
,
neg_dst
=
data
.
negative_
srcs
,
data
.
negative_
dsts
# Assertation
assert
len
(
pos_src
)
==
batch_size
assert
len
(
pos_dst
)
==
batch_size
...
...
@@ -103,8 +103,8 @@ def test_NegativeSampler_Head_Conditioned_Format(negative_ratio):
)
# Perform Negative sampling.
for
data
in
negative_sampler
:
pos_src
,
pos_dst
=
data
.
node_pair
neg_src
=
data
.
negative_
head
pos_src
,
pos_dst
=
data
.
node_pair
s
neg_src
=
data
.
negative_
srcs
# Assertation
assert
len
(
pos_src
)
==
batch_size
assert
len
(
pos_dst
)
==
batch_size
...
...
@@ -139,8 +139,8 @@ def test_NegativeSampler_Tail_Conditioned_Format(negative_ratio):
)
# Perform Negative sampling.
for
data
in
negative_sampler
:
pos_src
,
pos_dst
=
data
.
node_pair
neg_dst
=
data
.
negative_
tail
pos_src
,
pos_dst
=
data
.
node_pair
s
neg_dst
=
data
.
negative_
dsts
# Assertation
assert
len
(
pos_src
)
==
batch_size
assert
len
(
pos_dst
)
==
batch_size
...
...
tests/python/pytorch/graphbolt/impl/test_ondisk_dataset.py
View file @
3742d5ff
...
...
@@ -78,11 +78,11 @@ def test_OnDiskDataset_TVTSet_ItemSet_names():
train_set:
- type: null
data:
- name: seed_node
- name: seed_node
s
format: numpy
in_memory: true
path:
{
train_ids_path
}
- name: label
- name: label
s
format: numpy
in_memory: true
path:
{
train_labels_path
}
...
...
@@ -104,7 +104,7 @@ def test_OnDiskDataset_TVTSet_ItemSet_names():
for
i
,
(
id
,
label
,
_
)
in
enumerate
(
train_set
):
assert
id
==
train_ids
[
i
]
assert
label
==
train_labels
[
i
]
assert
train_set
.
names
==
(
"seed_node"
,
"label"
,
None
)
assert
train_set
.
names
==
(
"seed_node
s
"
,
"label
s
"
,
None
)
train_set
=
None
...
...
@@ -125,11 +125,11 @@ def test_OnDiskDataset_TVTSet_ItemSetDict_names():
train_set:
- type: "author:writes:paper"
data:
- name: seed_node
- name: seed_node
s
format: numpy
in_memory: true
path:
{
train_ids_path
}
- name: label
- name: label
s
format: numpy
in_memory: true
path:
{
train_labels_path
}
...
...
@@ -154,7 +154,7 @@ def test_OnDiskDataset_TVTSet_ItemSetDict_names():
id
,
label
,
_
=
item
[
"author:writes:paper"
]
assert
id
==
train_ids
[
i
]
assert
label
==
train_labels
[
i
]
assert
train_set
.
names
==
(
"seed_node"
,
"label"
,
None
)
assert
train_set
.
names
==
(
"seed_node
s
"
,
"label
s
"
,
None
)
train_set
=
None
...
...
@@ -193,32 +193,32 @@ def test_OnDiskDataset_TVTSet_ItemSet_id_label():
train_set:
- type: null
data:
- name: seed_node
- name: seed_node
s
format: numpy
in_memory: true
path:
{
train_ids_path
}
- name: label
- name: label
s
format: numpy
in_memory: true
path:
{
train_labels_path
}
validation_set:
- data:
- name: seed_node
- name: seed_node
s
format: numpy
in_memory: true
path:
{
validation_ids_path
}
- name: label
- name: label
s
format: numpy
in_memory: true
path:
{
validation_labels_path
}
test_set:
- type: null
data:
- name: seed_node
- name: seed_node
s
format: numpy
in_memory: true
path:
{
test_ids_path
}
- name: label
- name: label
s
format: numpy
in_memory: true
path:
{
test_labels_path
}
...
...
@@ -242,7 +242,7 @@ def test_OnDiskDataset_TVTSet_ItemSet_id_label():
for
i
,
(
id
,
label
)
in
enumerate
(
train_set
):
assert
id
==
train_ids
[
i
]
assert
label
==
train_labels
[
i
]
assert
train_set
.
names
==
(
"seed_node"
,
"label"
)
assert
train_set
.
names
==
(
"seed_node
s
"
,
"label
s
"
)
train_set
=
None
# Verify validation set.
...
...
@@ -252,7 +252,7 @@ def test_OnDiskDataset_TVTSet_ItemSet_id_label():
for
i
,
(
id
,
label
)
in
enumerate
(
validation_set
):
assert
id
==
validation_ids
[
i
]
assert
label
==
validation_labels
[
i
]
assert
validation_set
.
names
==
(
"seed_node"
,
"label"
)
assert
validation_set
.
names
==
(
"seed_node
s
"
,
"label
s
"
)
validation_set
=
None
# Verify test set.
...
...
@@ -262,7 +262,7 @@ def test_OnDiskDataset_TVTSet_ItemSet_id_label():
for
i
,
(
id
,
label
)
in
enumerate
(
test_set
):
assert
id
==
test_ids
[
i
]
assert
label
==
test_labels
[
i
]
assert
test_set
.
names
==
(
"seed_node"
,
"label"
)
assert
test_set
.
names
==
(
"seed_node
s
"
,
"label
s
"
)
test_set
=
None
dataset
=
None
...
...
@@ -334,7 +334,7 @@ def test_OnDiskDataset_TVTSet_ItemSet_node_pair_label():
format: numpy
in_memory: true
path:
{
train_dst_path
}
- name: label
- name: label
s
format: numpy
in_memory: true
path:
{
train_labels_path
}
...
...
@@ -348,7 +348,7 @@ def test_OnDiskDataset_TVTSet_ItemSet_node_pair_label():
format: numpy
in_memory: true
path:
{
validation_dst_path
}
- name: label
- name: label
s
format: numpy
in_memory: true
path:
{
validation_labels_path
}
...
...
@@ -363,7 +363,7 @@ def test_OnDiskDataset_TVTSet_ItemSet_node_pair_label():
format: numpy
in_memory: true
path:
{
test_dst_path
}
- name: label
- name: label
s
format: numpy
in_memory: true
path:
{
test_labels_path
}
...
...
@@ -383,7 +383,7 @@ def test_OnDiskDataset_TVTSet_ItemSet_node_pair_label():
assert
src
==
train_src
[
i
]
assert
dst
==
train_dst
[
i
]
assert
label
==
train_labels
[
i
]
assert
train_set
.
names
==
(
"src"
,
"dst"
,
"label"
)
assert
train_set
.
names
==
(
"src"
,
"dst"
,
"label
s
"
)
train_set
=
None
# Verify validation set.
...
...
@@ -394,7 +394,7 @@ def test_OnDiskDataset_TVTSet_ItemSet_node_pair_label():
assert
src
==
validation_src
[
i
]
assert
dst
==
validation_dst
[
i
]
assert
label
==
validation_labels
[
i
]
assert
validation_set
.
names
==
(
"src"
,
"dst"
,
"label"
)
assert
validation_set
.
names
==
(
"src"
,
"dst"
,
"label
s
"
)
validation_set
=
None
# Verify test set.
...
...
@@ -405,7 +405,7 @@ def test_OnDiskDataset_TVTSet_ItemSet_node_pair_label():
assert
src
==
test_src
[
i
]
assert
dst
==
test_dst
[
i
]
assert
label
==
test_labels
[
i
]
assert
test_set
.
names
==
(
"src"
,
"dst"
,
"label"
)
assert
test_set
.
names
==
(
"src"
,
"dst"
,
"label
s
"
)
test_set
=
None
dataset
=
None
...
...
@@ -564,36 +564,36 @@ def test_OnDiskDataset_TVTSet_ItemSetDict_id_label():
train_set:
- type: paper
data:
- name: seed_node
- name: seed_node
s
format: numpy
in_memory: true
path:
{
train_path
}
- type: author
data:
- name: seed_node
- name: seed_node
s
format: numpy
path:
{
train_path
}
validation_set:
- type: paper
data:
- name: seed_node
- name: seed_node
s
format: numpy
path:
{
validation_path
}
- type: author
data:
- name: seed_node
- name: seed_node
s
format: numpy
path:
{
validation_path
}
test_set:
- type: paper
data:
- name: seed_node
- name: seed_node
s
format: numpy
in_memory: false
path:
{
test_path
}
- type: author
data:
- name: seed_node
- name: seed_node
s
format: numpy
path:
{
test_path
}
"""
...
...
@@ -616,7 +616,7 @@ def test_OnDiskDataset_TVTSet_ItemSetDict_id_label():
id
,
label
=
item
[
key
]
assert
id
==
train_ids
[
i
%
1000
]
assert
label
==
train_labels
[
i
%
1000
]
assert
train_set
.
names
==
(
"seed_node"
,)
assert
train_set
.
names
==
(
"seed_node
s
"
,)
train_set
=
None
# Verify validation set.
...
...
@@ -631,7 +631,7 @@ def test_OnDiskDataset_TVTSet_ItemSetDict_id_label():
id
,
label
=
item
[
key
]
assert
id
==
validation_ids
[
i
%
1000
]
assert
label
==
validation_labels
[
i
%
1000
]
assert
validation_set
.
names
==
(
"seed_node"
,)
assert
validation_set
.
names
==
(
"seed_node
s
"
,)
validation_set
=
None
# Verify test set.
...
...
@@ -646,7 +646,7 @@ def test_OnDiskDataset_TVTSet_ItemSetDict_id_label():
id
,
label
=
item
[
key
]
assert
id
==
test_ids
[
i
%
1000
]
assert
label
==
test_labels
[
i
%
1000
]
assert
test_set
.
names
==
(
"seed_node"
,)
assert
test_set
.
names
==
(
"seed_node
s
"
,)
test_set
=
None
dataset
=
None
...
...
@@ -798,7 +798,7 @@ def test_OnDiskDataset_Feature_heterograph():
path:
{
node_data_paper_path
}
- domain: node
type: paper
name: label
name: label
s
format: numpy
in_memory: true
path:
{
node_data_label_path
}
...
...
@@ -810,7 +810,7 @@ def test_OnDiskDataset_Feature_heterograph():
path:
{
edge_data_writes_path
}
- domain: edge
type: "author:writes:paper"
name: label
name: label
s
format: numpy
in_memory: true
path:
{
edge_data_label_path
}
...
...
@@ -832,7 +832,7 @@ def test_OnDiskDataset_Feature_heterograph():
torch
.
tensor
(
node_data_paper
),
)
assert
torch
.
equal
(
feature_data
.
read
(
"node"
,
"paper"
,
"label"
),
feature_data
.
read
(
"node"
,
"paper"
,
"label
s
"
),
torch
.
tensor
(
node_data_label
),
)
...
...
@@ -842,7 +842,7 @@ def test_OnDiskDataset_Feature_heterograph():
torch
.
tensor
(
edge_data_writes
),
)
assert
torch
.
equal
(
feature_data
.
read
(
"edge"
,
"author:writes:paper"
,
"label"
),
feature_data
.
read
(
"edge"
,
"author:writes:paper"
,
"label
s
"
),
torch
.
tensor
(
edge_data_label
),
)
...
...
@@ -879,7 +879,7 @@ def test_OnDiskDataset_Feature_homograph():
in_memory: false
path:
{
node_data_feat_path
}
- domain: node
name: label
name: label
s
format: numpy
in_memory: true
path:
{
node_data_label_path
}
...
...
@@ -889,7 +889,7 @@ def test_OnDiskDataset_Feature_homograph():
in_memory: false
path:
{
edge_data_feat_path
}
- domain: edge
name: label
name: label
s
format: numpy
in_memory: true
path:
{
edge_data_label_path
}
...
...
@@ -911,7 +911,7 @@ def test_OnDiskDataset_Feature_homograph():
torch
.
tensor
(
node_data_feat
),
)
assert
torch
.
equal
(
feature_data
.
read
(
"node"
,
None
,
"label"
),
feature_data
.
read
(
"node"
,
None
,
"label
s
"
),
torch
.
tensor
(
node_data_label
),
)
...
...
@@ -921,7 +921,7 @@ def test_OnDiskDataset_Feature_homograph():
torch
.
tensor
(
edge_data_feat
),
)
assert
torch
.
equal
(
feature_data
.
read
(
"edge"
,
None
,
"label"
),
feature_data
.
read
(
"edge"
,
None
,
"label
s
"
),
torch
.
tensor
(
edge_data_label
),
)
...
...
tests/python/pytorch/graphbolt/test_subgraph_sampler.py
View file @
3742d5ff
...
...
@@ -22,7 +22,7 @@ def test_SubgraphSampler_Node(labor):
def
to_link_batch
(
data
):
block
=
gb
.
MiniBatch
(
node_pair
=
data
)
block
=
gb
.
MiniBatch
(
node_pair
s
=
data
)
return
block
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment