Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
3742d5ff
"git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "9827e48176ff8e9e9404e558992a2778ba5bf02c"
Unverified
Commit
3742d5ff
authored
Sep 04, 2023
by
Rhett Ying
Committed by
GitHub
Sep 04, 2023
Browse files
[GraphBolt] rename data attributes in MiniBatch (#6274)
parent
72683697
Changes
9
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
131 additions
and
131 deletions
+131
-131
python/dgl/graphbolt/impl/neighbor_sampler.py
python/dgl/graphbolt/impl/neighbor_sampler.py
+4
-4
python/dgl/graphbolt/item_sampler.py
python/dgl/graphbolt/item_sampler.py
+3
-3
python/dgl/graphbolt/minibatch.py
python/dgl/graphbolt/minibatch.py
+24
-24
python/dgl/graphbolt/negative_sampler.py
python/dgl/graphbolt/negative_sampler.py
+20
-20
python/dgl/graphbolt/subgraph_sampler.py
python/dgl/graphbolt/subgraph_sampler.py
+28
-28
tests/python/pytorch/graphbolt/gb_test_utils.py
tests/python/pytorch/graphbolt/gb_test_utils.py
+2
-2
tests/python/pytorch/graphbolt/impl/test_negative_sampler.py
tests/python/pytorch/graphbolt/impl/test_negative_sampler.py
+11
-11
tests/python/pytorch/graphbolt/impl/test_ondisk_dataset.py
tests/python/pytorch/graphbolt/impl/test_ondisk_dataset.py
+38
-38
tests/python/pytorch/graphbolt/test_subgraph_sampler.py
tests/python/pytorch/graphbolt/test_subgraph_sampler.py
+1
-1
No files found.
python/dgl/graphbolt/impl/neighbor_sampler.py
View file @
3742d5ff
...
@@ -54,7 +54,7 @@ class NeighborSampler(SubgraphSampler):
...
@@ -54,7 +54,7 @@ class NeighborSampler(SubgraphSampler):
>>> import dgl.graphbolt as gb
>>> import dgl.graphbolt as gb
>>> from torchdata.datapipes.iter import Mapper
>>> from torchdata.datapipes.iter import Mapper
>>> def minibatch_link_collator(data):
>>> def minibatch_link_collator(data):
... minibatch = gb.MiniBatch(node_pair=data)
... minibatch = gb.MiniBatch(node_pair
s
=data)
... return minibatch
... return minibatch
...
...
>>> from dgl import graphbolt as gb
>>> from dgl import graphbolt as gb
...
@@ -76,7 +76,7 @@ class NeighborSampler(SubgraphSampler):
...
@@ -76,7 +76,7 @@ class NeighborSampler(SubgraphSampler):
>>> subgraph_sampler = gb.NeighborSampler(
>>> subgraph_sampler = gb.NeighborSampler(
...neg_sampler, graph, fanouts)
...neg_sampler, graph, fanouts)
>>> for data in subgraph_sampler:
>>> for data in subgraph_sampler:
... print(data.compacted_node_pair)
... print(data.compacted_node_pair
s
)
... print(len(data.sampled_subgraphs))
... print(len(data.sampled_subgraphs))
(tensor([0, 0, 0]), tensor([1, 0, 2]))
(tensor([0, 0, 0]), tensor([1, 0, 2]))
3
3
...
@@ -166,7 +166,7 @@ class LayerNeighborSampler(NeighborSampler):
...
@@ -166,7 +166,7 @@ class LayerNeighborSampler(NeighborSampler):
>>> import dgl.graphbolt as gb
>>> import dgl.graphbolt as gb
>>> from torchdata.datapipes.iter import Mapper
>>> from torchdata.datapipes.iter import Mapper
>>> def minibatch_link_collator(data):
>>> def minibatch_link_collator(data):
... minibatch = gb.MiniBatch(node_pair=data)
... minibatch = gb.MiniBatch(node_pair
s
=data)
... return minibatch
... return minibatch
...
...
>>> from dgl import graphbolt as gb
>>> from dgl import graphbolt as gb
...
@@ -188,7 +188,7 @@ class LayerNeighborSampler(NeighborSampler):
...
@@ -188,7 +188,7 @@ class LayerNeighborSampler(NeighborSampler):
>>> subgraph_sampler = gb.LayerNeighborSampler(
>>> subgraph_sampler = gb.LayerNeighborSampler(
...neg_sampler, graph, fanouts)
...neg_sampler, graph, fanouts)
>>> for data in subgraph_sampler:
>>> for data in subgraph_sampler:
... print(data.compacted_node_pair)
... print(data.compacted_node_pair
s
)
... print(len(data.sampled_subgraphs))
... print(len(data.sampled_subgraphs))
(tensor([0, 0, 0]), tensor([1, 0, 2]))
(tensor([0, 0, 0]), tensor([1, 0, 2]))
3
3
...
...
python/dgl/graphbolt/item_sampler.py
View file @
3742d5ff
...
@@ -17,9 +17,9 @@ __all__ = ["ItemSampler"]
...
@@ -17,9 +17,9 @@ __all__ = ["ItemSampler"]
class
ItemSampler
(
IterDataPipe
):
class
ItemSampler
(
IterDataPipe
):
"""Item Sampler.
"""Item Sampler.
Creates item subset of data which could be node
/edge
IDs, node pairs with
Creates item subset of data which could be node IDs, node pairs with
or
or
without labels,
head/tail/negative_tails, DGLGraphs and heterogeneou
s
without labels,
node pairs with negative sources/destinations, DGLGraph
s
counterparts.
and heterogeneous
counterparts.
Note: This class `ItemSampler` is not decorated with
Note: This class `ItemSampler` is not decorated with
`torchdata.datapipes.functional_datapipe` on purpose. This indicates it
`torchdata.datapipes.functional_datapipe` on purpose. This indicates it
...
...
python/dgl/graphbolt/minibatch.py
View file @
3742d5ff
...
@@ -55,74 +55,74 @@ class MiniBatch:
...
@@ -55,74 +55,74 @@ class MiniBatch:
value should be corresponding heterogeneous node id.
value should be corresponding heterogeneous node id.
"""
"""
seed_node
:
Union
[
torch
.
Tensor
,
Dict
[
str
,
torch
.
Tensor
]]
=
None
seed_node
s
:
Union
[
torch
.
Tensor
,
Dict
[
str
,
torch
.
Tensor
]]
=
None
"""
"""
Representation of seed nodes used for sampling in the graph.
Representation of seed nodes used for sampling in the graph.
- If `seed_node` is a tensor: It indicates the graph is homogeneous.
- If `seed_node
s
` is a tensor: It indicates the graph is homogeneous.
- If `seed_node` is a dictionary: The keys should be node type and the
- If `seed_node
s
` is a dictionary: The keys should be node type and the
value should be corresponding heterogeneous node ids.
value should be corresponding heterogeneous node ids.
"""
"""
label
:
Union
[
torch
.
Tensor
,
Dict
[
str
,
torch
.
Tensor
]]
=
None
label
s
:
Union
[
torch
.
Tensor
,
Dict
[
str
,
torch
.
Tensor
]]
=
None
"""
"""
L
abels associated with seed nodes in the graph.
l
abels
s
associated with seed nodes in the graph.
- If `label` is a tensor: It indicates the graph is homogeneous. The value
- If `label
s
` is a tensor: It indicates the graph is homogeneous. The value
should be corresponding labels to given 'seed_node' or 'node_pair'.
should be corresponding labels
s
to given 'seed_node
s
' or 'node_pair
s
'.
- If `label` is a dictionary: The keys should be node or edge type and the
- If `label
s
` is a dictionary: The keys should be node or edge type and the
value should be corresponding labels to given 'seed_node' or 'node_pair'.
value should be corresponding labels
s
to given 'seed_node
s
' or 'node_pair
s
'.
"""
"""
node_pair
:
Union
[
node_pair
s
:
Union
[
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
],
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
],
Dict
[
str
,
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]],
Dict
[
str
,
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]],
]
=
None
]
=
None
"""
"""
Representation of seed node pairs utilized in link prediction tasks.
Representation of seed node pairs utilized in link prediction tasks.
- If `node_pair` is a tuple: It indicates a homogeneous graph where each
- If `node_pair
s
` is a tuple: It indicates a homogeneous graph where each
tuple contains two tensors representing source-destination node pairs.
tuple contains two tensors representing source-destination node pairs.
- If `node_pair` is a dictionary: The keys should be edge type, and the
- If `node_pair
s
` is a dictionary: The keys should be edge type, and the
value should be a tuple of tensors representing node pairs of the given
value should be a tuple of tensors representing node pairs of the given
type.
type.
"""
"""
negative_
head
:
Union
[
torch
.
Tensor
,
Dict
[
str
,
torch
.
Tensor
]]
=
None
negative_
srcs
:
Union
[
torch
.
Tensor
,
Dict
[
str
,
torch
.
Tensor
]]
=
None
"""
"""
Representation of negative samples for the head nodes in the link
Representation of negative samples for the head nodes in the link
prediction task.
prediction task.
- If `negative_
head
` is a tensor: It indicates a homogeneous graph.
- If `negative_
srcs
` is a tensor: It indicates a homogeneous graph.
- If `negative_
head
` is a dictionary: The key should be edge type, and the
- If `negative_
srcs
` is a dictionary: The key should be edge type, and the
value should correspond to the negative samples for head nodes of the
value should correspond to the negative samples for head nodes of the
given type.
given type.
"""
"""
negative_
tail
:
Union
[
torch
.
Tensor
,
Dict
[
str
,
torch
.
Tensor
]]
=
None
negative_
dsts
:
Union
[
torch
.
Tensor
,
Dict
[
str
,
torch
.
Tensor
]]
=
None
"""
"""
Representation of negative samples for the tail nodes in the link
Representation of negative samples for the tail nodes in the link
prediction task.
prediction task.
- If `negative_
tail
` is a tensor: It indicates a homogeneous graph.
- If `negative_
dsts
` is a tensor: It indicates a homogeneous graph.
- If `negative_
tail
` is a dictionary: The key should be edge type, and the
- If `negative_
dsts
` is a dictionary: The key should be edge type, and the
value should correspond to the negative samples for head nodes of the
value should correspond to the negative samples for head nodes of the
given type.
given type.
"""
"""
compacted_node_pair
:
Union
[
compacted_node_pair
s
:
Union
[
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
],
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
],
Dict
[
str
,
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]],
Dict
[
str
,
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]],
]
=
None
]
=
None
"""
"""
Representation of compacted node pairs corresponding to 'node_pair', where
Representation of compacted node pairs corresponding to 'node_pair
s
', where
all node ids inside are compacted.
all node ids inside are compacted.
"""
"""
compacted_negative_
head
:
Union
[
torch
.
Tensor
,
Dict
[
str
,
torch
.
Tensor
]]
=
None
compacted_negative_
srcs
:
Union
[
torch
.
Tensor
,
Dict
[
str
,
torch
.
Tensor
]]
=
None
"""
"""
Representation of compacted nodes corresponding to 'negative_
head
', where
Representation of compacted nodes corresponding to 'negative_
srcs
', where
all node ids inside are compacted.
all node ids inside are compacted.
"""
"""
compacted_negative_
tail
:
Union
[
torch
.
Tensor
,
Dict
[
str
,
torch
.
Tensor
]]
=
None
compacted_negative_
dsts
:
Union
[
torch
.
Tensor
,
Dict
[
str
,
torch
.
Tensor
]]
=
None
"""
"""
Representation of compacted nodes corresponding to 'negative_
tail
', where
Representation of compacted nodes corresponding to 'negative_
dsts
', where
all node ids inside are compacted.
all node ids inside are compacted.
"""
"""
...
...
python/dgl/graphbolt/negative_sampler.py
View file @
3742d5ff
...
@@ -44,9 +44,9 @@ class NegativeSampler(Mapper):
...
@@ -44,9 +44,9 @@ class NegativeSampler(Mapper):
Parameters
Parameters
----------
----------
minibatch : MiniBatch
minibatch : MiniBatch
An instance of 'MiniBatch' class requires the 'node_pair' field.
An instance of 'MiniBatch' class requires the 'node_pair
s
' field.
This function is responsible for generating negative edges
This function is responsible for generating negative edges
corresponding to the positive edges defined by the 'node_pair'. In
corresponding to the positive edges defined by the 'node_pair
s
'. In
cases where negative edges already exist, this function will
cases where negative edges already exist, this function will
overwrite them.
overwrite them.
...
@@ -56,21 +56,21 @@ class NegativeSampler(Mapper):
...
@@ -56,21 +56,21 @@ class NegativeSampler(Mapper):
An instance of 'MiniBatch' encompasses both positive and negative
An instance of 'MiniBatch' encompasses both positive and negative
samples.
samples.
"""
"""
node_pairs
=
minibatch
.
node_pair
node_pairs
=
minibatch
.
node_pair
s
assert
node_pairs
is
not
None
assert
node_pairs
is
not
None
if
isinstance
(
node_pairs
,
Mapping
):
if
isinstance
(
node_pairs
,
Mapping
):
if
self
.
output_format
==
LinkPredictionEdgeFormat
.
INDEPENDENT
:
if
self
.
output_format
==
LinkPredictionEdgeFormat
.
INDEPENDENT
:
minibatch
.
label
=
{}
minibatch
.
label
s
=
{}
else
:
else
:
minibatch
.
negative_
head
,
minibatch
.
negative_
tail
=
{},
{}
minibatch
.
negative_
srcs
,
minibatch
.
negative_
dsts
=
{},
{}
for
etype
,
pos_pairs
in
node_pairs
.
items
():
for
etype
,
pos_pairs
in
node_pairs
.
items
():
self
.
_collate
(
self
.
_collate
(
minibatch
,
self
.
_sample_with_etype
(
pos_pairs
,
etype
),
etype
minibatch
,
self
.
_sample_with_etype
(
pos_pairs
,
etype
),
etype
)
)
if
self
.
output_format
==
LinkPredictionEdgeFormat
.
HEAD_CONDITIONED
:
if
self
.
output_format
==
LinkPredictionEdgeFormat
.
HEAD_CONDITIONED
:
minibatch
.
negative_
tail
=
None
minibatch
.
negative_
dsts
=
None
if
self
.
output_format
==
LinkPredictionEdgeFormat
.
TAIL_CONDITIONED
:
if
self
.
output_format
==
LinkPredictionEdgeFormat
.
TAIL_CONDITIONED
:
minibatch
.
negative_
head
=
None
minibatch
.
negative_
srcs
=
None
else
:
else
:
self
.
_collate
(
minibatch
,
self
.
_sample_with_etype
(
node_pairs
))
self
.
_collate
(
minibatch
,
self
.
_sample_with_etype
(
node_pairs
))
return
minibatch
return
minibatch
...
@@ -111,23 +111,23 @@ class NegativeSampler(Mapper):
...
@@ -111,23 +111,23 @@ class NegativeSampler(Mapper):
Canonical edge type.
Canonical edge type.
"""
"""
pos_src
,
pos_dst
=
(
pos_src
,
pos_dst
=
(
minibatch
.
node_pair
[
etype
]
minibatch
.
node_pair
s
[
etype
]
if
etype
is
not
None
if
etype
is
not
None
else
minibatch
.
node_pair
else
minibatch
.
node_pair
s
)
)
neg_src
,
neg_dst
=
neg_pairs
neg_src
,
neg_dst
=
neg_pairs
if
self
.
output_format
==
LinkPredictionEdgeFormat
.
INDEPENDENT
:
if
self
.
output_format
==
LinkPredictionEdgeFormat
.
INDEPENDENT
:
pos_label
=
torch
.
ones_like
(
pos_src
)
pos_label
s
=
torch
.
ones_like
(
pos_src
)
neg_label
=
torch
.
zeros_like
(
neg_src
)
neg_label
s
=
torch
.
zeros_like
(
neg_src
)
src
=
torch
.
cat
([
pos_src
,
neg_src
])
src
=
torch
.
cat
([
pos_src
,
neg_src
])
dst
=
torch
.
cat
([
pos_dst
,
neg_dst
])
dst
=
torch
.
cat
([
pos_dst
,
neg_dst
])
label
=
torch
.
cat
([
pos_label
,
neg_label
])
label
s
=
torch
.
cat
([
pos_label
s
,
neg_label
s
])
if
etype
is
not
None
:
if
etype
is
not
None
:
minibatch
.
node_pair
[
etype
]
=
(
src
,
dst
)
minibatch
.
node_pair
s
[
etype
]
=
(
src
,
dst
)
minibatch
.
label
[
etype
]
=
label
minibatch
.
label
s
[
etype
]
=
label
s
else
:
else
:
minibatch
.
node_pair
=
(
src
,
dst
)
minibatch
.
node_pair
s
=
(
src
,
dst
)
minibatch
.
label
=
label
minibatch
.
label
s
=
label
s
else
:
else
:
if
self
.
output_format
==
LinkPredictionEdgeFormat
.
CONDITIONED
:
if
self
.
output_format
==
LinkPredictionEdgeFormat
.
CONDITIONED
:
neg_src
=
neg_src
.
view
(
-
1
,
self
.
negative_ratio
)
neg_src
=
neg_src
.
view
(
-
1
,
self
.
negative_ratio
)
...
@@ -147,8 +147,8 @@ class NegativeSampler(Mapper):
...
@@ -147,8 +147,8 @@ class NegativeSampler(Mapper):
f
"Unsupported output format
{
self
.
output_format
}
."
f
"Unsupported output format
{
self
.
output_format
}
."
)
)
if
etype
is
not
None
:
if
etype
is
not
None
:
minibatch
.
negative_
head
[
etype
]
=
neg_src
minibatch
.
negative_
srcs
[
etype
]
=
neg_src
minibatch
.
negative_
tail
[
etype
]
=
neg_dst
minibatch
.
negative_
dsts
[
etype
]
=
neg_dst
else
:
else
:
minibatch
.
negative_
head
=
neg_src
minibatch
.
negative_
srcs
=
neg_src
minibatch
.
negative_
tail
=
neg_dst
minibatch
.
negative_
dsts
=
neg_dst
python/dgl/graphbolt/subgraph_sampler.py
View file @
3742d5ff
...
@@ -28,19 +28,19 @@ class SubgraphSampler(Mapper):
...
@@ -28,19 +28,19 @@ class SubgraphSampler(Mapper):
super
().
__init__
(
datapipe
,
self
.
_sample
)
super
().
__init__
(
datapipe
,
self
.
_sample
)
def
_sample
(
self
,
minibatch
):
def
_sample
(
self
,
minibatch
):
if
minibatch
.
node_pair
is
not
None
:
if
minibatch
.
node_pair
s
is
not
None
:
(
(
seeds
,
seeds
,
minibatch
.
compacted_node_pair
,
minibatch
.
compacted_node_pair
s
,
minibatch
.
compacted_negative_
head
,
minibatch
.
compacted_negative_
srcs
,
minibatch
.
compacted_negative_
tail
,
minibatch
.
compacted_negative_
dsts
,
)
=
self
.
_node_pair_preprocess
(
minibatch
)
)
=
self
.
_node_pair
s
_preprocess
(
minibatch
)
elif
minibatch
.
seed_node
is
not
None
:
elif
minibatch
.
seed_node
s
is
not
None
:
seeds
=
minibatch
.
seed_node
seeds
=
minibatch
.
seed_node
s
else
:
else
:
raise
ValueError
(
raise
ValueError
(
f
"Invalid minibatch
{
minibatch
}
: Either 'node_pair' or
\
f
"Invalid minibatch
{
minibatch
}
: Either 'node_pair
s
' or
\
'seed_node' should have a value."
'seed_node
s
' should have a value."
)
)
(
(
minibatch
.
input_nodes
,
minibatch
.
input_nodes
,
...
@@ -48,16 +48,16 @@ class SubgraphSampler(Mapper):
...
@@ -48,16 +48,16 @@ class SubgraphSampler(Mapper):
)
=
self
.
_sample_subgraphs
(
seeds
)
)
=
self
.
_sample_subgraphs
(
seeds
)
return
minibatch
return
minibatch
def
_node_pair_preprocess
(
self
,
minibatch
):
def
_node_pair
s
_preprocess
(
self
,
minibatch
):
node_pair
=
minibatch
.
node_pair
node_pair
s
=
minibatch
.
node_pair
s
neg_src
,
neg_dst
=
minibatch
.
negative_
head
,
minibatch
.
negative_
tail
neg_src
,
neg_dst
=
minibatch
.
negative_
srcs
,
minibatch
.
negative_
dsts
has_neg_src
=
neg_src
is
not
None
has_neg_src
=
neg_src
is
not
None
has_neg_dst
=
neg_dst
is
not
None
has_neg_dst
=
neg_dst
is
not
None
is_heterogeneous
=
isinstance
(
node_pair
,
Dict
)
is_heterogeneous
=
isinstance
(
node_pair
s
,
Dict
)
if
is_heterogeneous
:
if
is_heterogeneous
:
# Collect nodes from all types of input.
# Collect nodes from all types of input.
nodes
=
defaultdict
(
list
)
nodes
=
defaultdict
(
list
)
for
etype
,
(
src
,
dst
)
in
node_pair
.
items
():
for
etype
,
(
src
,
dst
)
in
node_pair
s
.
items
():
src_type
,
_
,
dst_type
=
etype_str_to_tuple
(
etype
)
src_type
,
_
,
dst_type
=
etype_str_to_tuple
(
etype
)
nodes
[
src_type
].
append
(
src
)
nodes
[
src_type
].
append
(
src
)
nodes
[
dst_type
].
append
(
dst
)
nodes
[
dst_type
].
append
(
dst
)
...
@@ -72,27 +72,27 @@ class SubgraphSampler(Mapper):
...
@@ -72,27 +72,27 @@ class SubgraphSampler(Mapper):
# Unique and compact the collected nodes.
# Unique and compact the collected nodes.
seeds
,
compacted
=
unique_and_compact
(
nodes
)
seeds
,
compacted
=
unique_and_compact
(
nodes
)
(
(
compacted_node_pair
,
compacted_node_pair
s
,
compacted_negative_
head
,
compacted_negative_
srcs
,
compacted_negative_
tail
,
compacted_negative_
dsts
,
)
=
({},
{},
{})
)
=
({},
{},
{})
# Map back in same order as collect.
# Map back in same order as collect.
for
etype
,
_
in
node_pair
.
items
():
for
etype
,
_
in
node_pair
s
.
items
():
src_type
,
_
,
dst_type
=
etype_str_to_tuple
(
etype
)
src_type
,
_
,
dst_type
=
etype_str_to_tuple
(
etype
)
src
=
compacted
[
src_type
].
pop
(
0
)
src
=
compacted
[
src_type
].
pop
(
0
)
dst
=
compacted
[
dst_type
].
pop
(
0
)
dst
=
compacted
[
dst_type
].
pop
(
0
)
compacted_node_pair
[
etype
]
=
(
src
,
dst
)
compacted_node_pair
s
[
etype
]
=
(
src
,
dst
)
if
has_neg_src
:
if
has_neg_src
:
for
etype
,
_
in
neg_src
.
items
():
for
etype
,
_
in
neg_src
.
items
():
src_type
,
_
,
_
=
etype_str_to_tuple
(
etype
)
src_type
,
_
,
_
=
etype_str_to_tuple
(
etype
)
compacted_negative_
head
[
etype
]
=
compacted
[
src_type
].
pop
(
0
)
compacted_negative_
srcs
[
etype
]
=
compacted
[
src_type
].
pop
(
0
)
if
has_neg_dst
:
if
has_neg_dst
:
for
etype
,
_
in
neg_dst
.
items
():
for
etype
,
_
in
neg_dst
.
items
():
_
,
_
,
dst_type
=
etype_str_to_tuple
(
etype
)
_
,
_
,
dst_type
=
etype_str_to_tuple
(
etype
)
compacted_negative_
tail
[
etype
]
=
compacted
[
dst_type
].
pop
(
0
)
compacted_negative_
dsts
[
etype
]
=
compacted
[
dst_type
].
pop
(
0
)
else
:
else
:
# Collect nodes from all types of input.
# Collect nodes from all types of input.
nodes
=
list
(
node_pair
)
nodes
=
list
(
node_pair
s
)
if
has_neg_src
:
if
has_neg_src
:
nodes
.
append
(
neg_src
.
view
(
-
1
))
nodes
.
append
(
neg_src
.
view
(
-
1
))
if
has_neg_dst
:
if
has_neg_dst
:
...
@@ -100,17 +100,17 @@ class SubgraphSampler(Mapper):
...
@@ -100,17 +100,17 @@ class SubgraphSampler(Mapper):
# Unique and compact the collected nodes.
# Unique and compact the collected nodes.
seeds
,
compacted
=
unique_and_compact
(
nodes
)
seeds
,
compacted
=
unique_and_compact
(
nodes
)
# Map back in same order as collect.
# Map back in same order as collect.
compacted_node_pair
=
tuple
(
compacted
[:
2
])
compacted_node_pair
s
=
tuple
(
compacted
[:
2
])
compacted
=
compacted
[
2
:]
compacted
=
compacted
[
2
:]
if
has_neg_src
:
if
has_neg_src
:
compacted_negative_
head
=
compacted
.
pop
(
0
)
compacted_negative_
srcs
=
compacted
.
pop
(
0
)
if
has_neg_dst
:
if
has_neg_dst
:
compacted_negative_
tail
=
compacted
.
pop
(
0
)
compacted_negative_
dsts
=
compacted
.
pop
(
0
)
return
(
return
(
seeds
,
seeds
,
compacted_node_pair
,
compacted_node_pair
s
,
compacted_negative_
head
if
has_neg_src
else
None
,
compacted_negative_
srcs
if
has_neg_src
else
None
,
compacted_negative_
tail
if
has_neg_dst
else
None
,
compacted_negative_
dsts
if
has_neg_dst
else
None
,
)
)
def
_sample_subgraphs
(
self
,
seeds
):
def
_sample_subgraphs
(
self
,
seeds
):
...
...
tests/python/pytorch/graphbolt/gb_test_utils.py
View file @
3742d5ff
...
@@ -9,12 +9,12 @@ import torch
...
@@ -9,12 +9,12 @@ import torch
def
minibatch_node_collator
(
data
):
def
minibatch_node_collator
(
data
):
minibatch
=
gb
.
MiniBatch
(
seed_node
=
data
)
minibatch
=
gb
.
MiniBatch
(
seed_node
s
=
data
)
return
minibatch
return
minibatch
def
minibatch_link_collator
(
data
):
def
minibatch_link_collator
(
data
):
minibatch
=
gb
.
MiniBatch
(
node_pair
=
data
)
minibatch
=
gb
.
MiniBatch
(
node_pair
s
=
data
)
return
minibatch
return
minibatch
...
...
tests/python/pytorch/graphbolt/impl/test_negative_sampler.py
View file @
3742d5ff
...
@@ -30,14 +30,14 @@ def test_NegativeSampler_Independent_Format(negative_ratio):
...
@@ -30,14 +30,14 @@ def test_NegativeSampler_Independent_Format(negative_ratio):
)
)
# Perform Negative sampling.
# Perform Negative sampling.
for
data
in
negative_sampler
:
for
data
in
negative_sampler
:
src
,
dst
=
data
.
node_pair
src
,
dst
=
data
.
node_pair
s
label
=
data
.
label
label
s
=
data
.
label
s
# Assertation
# Assertation
assert
len
(
src
)
==
batch_size
*
(
negative_ratio
+
1
)
assert
len
(
src
)
==
batch_size
*
(
negative_ratio
+
1
)
assert
len
(
dst
)
==
batch_size
*
(
negative_ratio
+
1
)
assert
len
(
dst
)
==
batch_size
*
(
negative_ratio
+
1
)
assert
len
(
label
)
==
batch_size
*
(
negative_ratio
+
1
)
assert
len
(
label
s
)
==
batch_size
*
(
negative_ratio
+
1
)
assert
torch
.
all
(
torch
.
eq
(
label
[:
batch_size
],
1
))
assert
torch
.
all
(
torch
.
eq
(
label
s
[:
batch_size
],
1
))
assert
torch
.
all
(
torch
.
eq
(
label
[
batch_size
:],
0
))
assert
torch
.
all
(
torch
.
eq
(
label
s
[
batch_size
:],
0
))
@
pytest
.
mark
.
parametrize
(
"negative_ratio"
,
[
1
,
5
,
10
,
20
])
@
pytest
.
mark
.
parametrize
(
"negative_ratio"
,
[
1
,
5
,
10
,
20
])
...
@@ -65,8 +65,8 @@ def test_NegativeSampler_Conditioned_Format(negative_ratio):
...
@@ -65,8 +65,8 @@ def test_NegativeSampler_Conditioned_Format(negative_ratio):
)
)
# Perform Negative sampling.
# Perform Negative sampling.
for
data
in
negative_sampler
:
for
data
in
negative_sampler
:
pos_src
,
pos_dst
=
data
.
node_pair
pos_src
,
pos_dst
=
data
.
node_pair
s
neg_src
,
neg_dst
=
data
.
negative_
head
,
data
.
negative_
tail
neg_src
,
neg_dst
=
data
.
negative_
srcs
,
data
.
negative_
dsts
# Assertation
# Assertation
assert
len
(
pos_src
)
==
batch_size
assert
len
(
pos_src
)
==
batch_size
assert
len
(
pos_dst
)
==
batch_size
assert
len
(
pos_dst
)
==
batch_size
...
@@ -103,8 +103,8 @@ def test_NegativeSampler_Head_Conditioned_Format(negative_ratio):
...
@@ -103,8 +103,8 @@ def test_NegativeSampler_Head_Conditioned_Format(negative_ratio):
)
)
# Perform Negative sampling.
# Perform Negative sampling.
for
data
in
negative_sampler
:
for
data
in
negative_sampler
:
pos_src
,
pos_dst
=
data
.
node_pair
pos_src
,
pos_dst
=
data
.
node_pair
s
neg_src
=
data
.
negative_
head
neg_src
=
data
.
negative_
srcs
# Assertation
# Assertation
assert
len
(
pos_src
)
==
batch_size
assert
len
(
pos_src
)
==
batch_size
assert
len
(
pos_dst
)
==
batch_size
assert
len
(
pos_dst
)
==
batch_size
...
@@ -139,8 +139,8 @@ def test_NegativeSampler_Tail_Conditioned_Format(negative_ratio):
...
@@ -139,8 +139,8 @@ def test_NegativeSampler_Tail_Conditioned_Format(negative_ratio):
)
)
# Perform Negative sampling.
# Perform Negative sampling.
for
data
in
negative_sampler
:
for
data
in
negative_sampler
:
pos_src
,
pos_dst
=
data
.
node_pair
pos_src
,
pos_dst
=
data
.
node_pair
s
neg_dst
=
data
.
negative_
tail
neg_dst
=
data
.
negative_
dsts
# Assertation
# Assertation
assert
len
(
pos_src
)
==
batch_size
assert
len
(
pos_src
)
==
batch_size
assert
len
(
pos_dst
)
==
batch_size
assert
len
(
pos_dst
)
==
batch_size
...
...
tests/python/pytorch/graphbolt/impl/test_ondisk_dataset.py
View file @
3742d5ff
...
@@ -78,11 +78,11 @@ def test_OnDiskDataset_TVTSet_ItemSet_names():
...
@@ -78,11 +78,11 @@ def test_OnDiskDataset_TVTSet_ItemSet_names():
train_set:
train_set:
- type: null
- type: null
data:
data:
- name: seed_node
- name: seed_node
s
format: numpy
format: numpy
in_memory: true
in_memory: true
path:
{
train_ids_path
}
path:
{
train_ids_path
}
- name: label
- name: label
s
format: numpy
format: numpy
in_memory: true
in_memory: true
path:
{
train_labels_path
}
path:
{
train_labels_path
}
...
@@ -104,7 +104,7 @@ def test_OnDiskDataset_TVTSet_ItemSet_names():
...
@@ -104,7 +104,7 @@ def test_OnDiskDataset_TVTSet_ItemSet_names():
for
i
,
(
id
,
label
,
_
)
in
enumerate
(
train_set
):
for
i
,
(
id
,
label
,
_
)
in
enumerate
(
train_set
):
assert
id
==
train_ids
[
i
]
assert
id
==
train_ids
[
i
]
assert
label
==
train_labels
[
i
]
assert
label
==
train_labels
[
i
]
assert
train_set
.
names
==
(
"seed_node"
,
"label"
,
None
)
assert
train_set
.
names
==
(
"seed_node
s
"
,
"label
s
"
,
None
)
train_set
=
None
train_set
=
None
...
@@ -125,11 +125,11 @@ def test_OnDiskDataset_TVTSet_ItemSetDict_names():
...
@@ -125,11 +125,11 @@ def test_OnDiskDataset_TVTSet_ItemSetDict_names():
train_set:
train_set:
- type: "author:writes:paper"
- type: "author:writes:paper"
data:
data:
- name: seed_node
- name: seed_node
s
format: numpy
format: numpy
in_memory: true
in_memory: true
path:
{
train_ids_path
}
path:
{
train_ids_path
}
- name: label
- name: label
s
format: numpy
format: numpy
in_memory: true
in_memory: true
path:
{
train_labels_path
}
path:
{
train_labels_path
}
...
@@ -154,7 +154,7 @@ def test_OnDiskDataset_TVTSet_ItemSetDict_names():
...
@@ -154,7 +154,7 @@ def test_OnDiskDataset_TVTSet_ItemSetDict_names():
id
,
label
,
_
=
item
[
"author:writes:paper"
]
id
,
label
,
_
=
item
[
"author:writes:paper"
]
assert
id
==
train_ids
[
i
]
assert
id
==
train_ids
[
i
]
assert
label
==
train_labels
[
i
]
assert
label
==
train_labels
[
i
]
assert
train_set
.
names
==
(
"seed_node"
,
"label"
,
None
)
assert
train_set
.
names
==
(
"seed_node
s
"
,
"label
s
"
,
None
)
train_set
=
None
train_set
=
None
...
@@ -193,32 +193,32 @@ def test_OnDiskDataset_TVTSet_ItemSet_id_label():
...
@@ -193,32 +193,32 @@ def test_OnDiskDataset_TVTSet_ItemSet_id_label():
train_set:
train_set:
- type: null
- type: null
data:
data:
- name: seed_node
- name: seed_node
s
format: numpy
format: numpy
in_memory: true
in_memory: true
path:
{
train_ids_path
}
path:
{
train_ids_path
}
- name: label
- name: label
s
format: numpy
format: numpy
in_memory: true
in_memory: true
path:
{
train_labels_path
}
path:
{
train_labels_path
}
validation_set:
validation_set:
- data:
- data:
- name: seed_node
- name: seed_node
s
format: numpy
format: numpy
in_memory: true
in_memory: true
path:
{
validation_ids_path
}
path:
{
validation_ids_path
}
- name: label
- name: label
s
format: numpy
format: numpy
in_memory: true
in_memory: true
path:
{
validation_labels_path
}
path:
{
validation_labels_path
}
test_set:
test_set:
- type: null
- type: null
data:
data:
- name: seed_node
- name: seed_node
s
format: numpy
format: numpy
in_memory: true
in_memory: true
path:
{
test_ids_path
}
path:
{
test_ids_path
}
- name: label
- name: label
s
format: numpy
format: numpy
in_memory: true
in_memory: true
path:
{
test_labels_path
}
path:
{
test_labels_path
}
...
@@ -242,7 +242,7 @@ def test_OnDiskDataset_TVTSet_ItemSet_id_label():
...
@@ -242,7 +242,7 @@ def test_OnDiskDataset_TVTSet_ItemSet_id_label():
for
i
,
(
id
,
label
)
in
enumerate
(
train_set
):
for
i
,
(
id
,
label
)
in
enumerate
(
train_set
):
assert
id
==
train_ids
[
i
]
assert
id
==
train_ids
[
i
]
assert
label
==
train_labels
[
i
]
assert
label
==
train_labels
[
i
]
assert
train_set
.
names
==
(
"seed_node"
,
"label"
)
assert
train_set
.
names
==
(
"seed_node
s
"
,
"label
s
"
)
train_set
=
None
train_set
=
None
# Verify validation set.
# Verify validation set.
...
@@ -252,7 +252,7 @@ def test_OnDiskDataset_TVTSet_ItemSet_id_label():
...
@@ -252,7 +252,7 @@ def test_OnDiskDataset_TVTSet_ItemSet_id_label():
for
i
,
(
id
,
label
)
in
enumerate
(
validation_set
):
for
i
,
(
id
,
label
)
in
enumerate
(
validation_set
):
assert
id
==
validation_ids
[
i
]
assert
id
==
validation_ids
[
i
]
assert
label
==
validation_labels
[
i
]
assert
label
==
validation_labels
[
i
]
assert
validation_set
.
names
==
(
"seed_node"
,
"label"
)
assert
validation_set
.
names
==
(
"seed_node
s
"
,
"label
s
"
)
validation_set
=
None
validation_set
=
None
# Verify test set.
# Verify test set.
...
@@ -262,7 +262,7 @@ def test_OnDiskDataset_TVTSet_ItemSet_id_label():
...
@@ -262,7 +262,7 @@ def test_OnDiskDataset_TVTSet_ItemSet_id_label():
for
i
,
(
id
,
label
)
in
enumerate
(
test_set
):
for
i
,
(
id
,
label
)
in
enumerate
(
test_set
):
assert
id
==
test_ids
[
i
]
assert
id
==
test_ids
[
i
]
assert
label
==
test_labels
[
i
]
assert
label
==
test_labels
[
i
]
assert
test_set
.
names
==
(
"seed_node"
,
"label"
)
assert
test_set
.
names
==
(
"seed_node
s
"
,
"label
s
"
)
test_set
=
None
test_set
=
None
dataset
=
None
dataset
=
None
...
@@ -334,7 +334,7 @@ def test_OnDiskDataset_TVTSet_ItemSet_node_pair_label():
...
@@ -334,7 +334,7 @@ def test_OnDiskDataset_TVTSet_ItemSet_node_pair_label():
format: numpy
format: numpy
in_memory: true
in_memory: true
path:
{
train_dst_path
}
path:
{
train_dst_path
}
- name: label
- name: label
s
format: numpy
format: numpy
in_memory: true
in_memory: true
path:
{
train_labels_path
}
path:
{
train_labels_path
}
...
@@ -348,7 +348,7 @@ def test_OnDiskDataset_TVTSet_ItemSet_node_pair_label():
...
@@ -348,7 +348,7 @@ def test_OnDiskDataset_TVTSet_ItemSet_node_pair_label():
format: numpy
format: numpy
in_memory: true
in_memory: true
path:
{
validation_dst_path
}
path:
{
validation_dst_path
}
- name: label
- name: label
s
format: numpy
format: numpy
in_memory: true
in_memory: true
path:
{
validation_labels_path
}
path:
{
validation_labels_path
}
...
@@ -363,7 +363,7 @@ def test_OnDiskDataset_TVTSet_ItemSet_node_pair_label():
...
@@ -363,7 +363,7 @@ def test_OnDiskDataset_TVTSet_ItemSet_node_pair_label():
format: numpy
format: numpy
in_memory: true
in_memory: true
path:
{
test_dst_path
}
path:
{
test_dst_path
}
- name: label
- name: label
s
format: numpy
format: numpy
in_memory: true
in_memory: true
path:
{
test_labels_path
}
path:
{
test_labels_path
}
...
@@ -383,7 +383,7 @@ def test_OnDiskDataset_TVTSet_ItemSet_node_pair_label():
...
@@ -383,7 +383,7 @@ def test_OnDiskDataset_TVTSet_ItemSet_node_pair_label():
assert
src
==
train_src
[
i
]
assert
src
==
train_src
[
i
]
assert
dst
==
train_dst
[
i
]
assert
dst
==
train_dst
[
i
]
assert
label
==
train_labels
[
i
]
assert
label
==
train_labels
[
i
]
assert
train_set
.
names
==
(
"src"
,
"dst"
,
"label"
)
assert
train_set
.
names
==
(
"src"
,
"dst"
,
"label
s
"
)
train_set
=
None
train_set
=
None
# Verify validation set.
# Verify validation set.
...
@@ -394,7 +394,7 @@ def test_OnDiskDataset_TVTSet_ItemSet_node_pair_label():
...
@@ -394,7 +394,7 @@ def test_OnDiskDataset_TVTSet_ItemSet_node_pair_label():
assert
src
==
validation_src
[
i
]
assert
src
==
validation_src
[
i
]
assert
dst
==
validation_dst
[
i
]
assert
dst
==
validation_dst
[
i
]
assert
label
==
validation_labels
[
i
]
assert
label
==
validation_labels
[
i
]
assert
validation_set
.
names
==
(
"src"
,
"dst"
,
"label"
)
assert
validation_set
.
names
==
(
"src"
,
"dst"
,
"label
s
"
)
validation_set
=
None
validation_set
=
None
# Verify test set.
# Verify test set.
...
@@ -405,7 +405,7 @@ def test_OnDiskDataset_TVTSet_ItemSet_node_pair_label():
...
@@ -405,7 +405,7 @@ def test_OnDiskDataset_TVTSet_ItemSet_node_pair_label():
assert
src
==
test_src
[
i
]
assert
src
==
test_src
[
i
]
assert
dst
==
test_dst
[
i
]
assert
dst
==
test_dst
[
i
]
assert
label
==
test_labels
[
i
]
assert
label
==
test_labels
[
i
]
assert
test_set
.
names
==
(
"src"
,
"dst"
,
"label"
)
assert
test_set
.
names
==
(
"src"
,
"dst"
,
"label
s
"
)
test_set
=
None
test_set
=
None
dataset
=
None
dataset
=
None
...
@@ -564,36 +564,36 @@ def test_OnDiskDataset_TVTSet_ItemSetDict_id_label():
...
@@ -564,36 +564,36 @@ def test_OnDiskDataset_TVTSet_ItemSetDict_id_label():
train_set:
train_set:
- type: paper
- type: paper
data:
data:
- name: seed_node
- name: seed_node
s
format: numpy
format: numpy
in_memory: true
in_memory: true
path:
{
train_path
}
path:
{
train_path
}
- type: author
- type: author
data:
data:
- name: seed_node
- name: seed_node
s
format: numpy
format: numpy
path:
{
train_path
}
path:
{
train_path
}
validation_set:
validation_set:
- type: paper
- type: paper
data:
data:
- name: seed_node
- name: seed_node
s
format: numpy
format: numpy
path:
{
validation_path
}
path:
{
validation_path
}
- type: author
- type: author
data:
data:
- name: seed_node
- name: seed_node
s
format: numpy
format: numpy
path:
{
validation_path
}
path:
{
validation_path
}
test_set:
test_set:
- type: paper
- type: paper
data:
data:
- name: seed_node
- name: seed_node
s
format: numpy
format: numpy
in_memory: false
in_memory: false
path:
{
test_path
}
path:
{
test_path
}
- type: author
- type: author
data:
data:
- name: seed_node
- name: seed_node
s
format: numpy
format: numpy
path:
{
test_path
}
path:
{
test_path
}
"""
"""
...
@@ -616,7 +616,7 @@ def test_OnDiskDataset_TVTSet_ItemSetDict_id_label():
...
@@ -616,7 +616,7 @@ def test_OnDiskDataset_TVTSet_ItemSetDict_id_label():
id
,
label
=
item
[
key
]
id
,
label
=
item
[
key
]
assert
id
==
train_ids
[
i
%
1000
]
assert
id
==
train_ids
[
i
%
1000
]
assert
label
==
train_labels
[
i
%
1000
]
assert
label
==
train_labels
[
i
%
1000
]
assert
train_set
.
names
==
(
"seed_node"
,)
assert
train_set
.
names
==
(
"seed_node
s
"
,)
train_set
=
None
train_set
=
None
# Verify validation set.
# Verify validation set.
...
@@ -631,7 +631,7 @@ def test_OnDiskDataset_TVTSet_ItemSetDict_id_label():
...
@@ -631,7 +631,7 @@ def test_OnDiskDataset_TVTSet_ItemSetDict_id_label():
id
,
label
=
item
[
key
]
id
,
label
=
item
[
key
]
assert
id
==
validation_ids
[
i
%
1000
]
assert
id
==
validation_ids
[
i
%
1000
]
assert
label
==
validation_labels
[
i
%
1000
]
assert
label
==
validation_labels
[
i
%
1000
]
assert
validation_set
.
names
==
(
"seed_node"
,)
assert
validation_set
.
names
==
(
"seed_node
s
"
,)
validation_set
=
None
validation_set
=
None
# Verify test set.
# Verify test set.
...
@@ -646,7 +646,7 @@ def test_OnDiskDataset_TVTSet_ItemSetDict_id_label():
...
@@ -646,7 +646,7 @@ def test_OnDiskDataset_TVTSet_ItemSetDict_id_label():
id
,
label
=
item
[
key
]
id
,
label
=
item
[
key
]
assert
id
==
test_ids
[
i
%
1000
]
assert
id
==
test_ids
[
i
%
1000
]
assert
label
==
test_labels
[
i
%
1000
]
assert
label
==
test_labels
[
i
%
1000
]
assert
test_set
.
names
==
(
"seed_node"
,)
assert
test_set
.
names
==
(
"seed_node
s
"
,)
test_set
=
None
test_set
=
None
dataset
=
None
dataset
=
None
...
@@ -798,7 +798,7 @@ def test_OnDiskDataset_Feature_heterograph():
...
@@ -798,7 +798,7 @@ def test_OnDiskDataset_Feature_heterograph():
path:
{
node_data_paper_path
}
path:
{
node_data_paper_path
}
- domain: node
- domain: node
type: paper
type: paper
name: label
name: label
s
format: numpy
format: numpy
in_memory: true
in_memory: true
path:
{
node_data_label_path
}
path:
{
node_data_label_path
}
...
@@ -810,7 +810,7 @@ def test_OnDiskDataset_Feature_heterograph():
...
@@ -810,7 +810,7 @@ def test_OnDiskDataset_Feature_heterograph():
path:
{
edge_data_writes_path
}
path:
{
edge_data_writes_path
}
- domain: edge
- domain: edge
type: "author:writes:paper"
type: "author:writes:paper"
name: label
name: label
s
format: numpy
format: numpy
in_memory: true
in_memory: true
path:
{
edge_data_label_path
}
path:
{
edge_data_label_path
}
...
@@ -832,7 +832,7 @@ def test_OnDiskDataset_Feature_heterograph():
...
@@ -832,7 +832,7 @@ def test_OnDiskDataset_Feature_heterograph():
torch
.
tensor
(
node_data_paper
),
torch
.
tensor
(
node_data_paper
),
)
)
assert
torch
.
equal
(
assert
torch
.
equal
(
feature_data
.
read
(
"node"
,
"paper"
,
"label"
),
feature_data
.
read
(
"node"
,
"paper"
,
"label
s
"
),
torch
.
tensor
(
node_data_label
),
torch
.
tensor
(
node_data_label
),
)
)
...
@@ -842,7 +842,7 @@ def test_OnDiskDataset_Feature_heterograph():
...
@@ -842,7 +842,7 @@ def test_OnDiskDataset_Feature_heterograph():
torch
.
tensor
(
edge_data_writes
),
torch
.
tensor
(
edge_data_writes
),
)
)
assert
torch
.
equal
(
assert
torch
.
equal
(
feature_data
.
read
(
"edge"
,
"author:writes:paper"
,
"label"
),
feature_data
.
read
(
"edge"
,
"author:writes:paper"
,
"label
s
"
),
torch
.
tensor
(
edge_data_label
),
torch
.
tensor
(
edge_data_label
),
)
)
...
@@ -879,7 +879,7 @@ def test_OnDiskDataset_Feature_homograph():
...
@@ -879,7 +879,7 @@ def test_OnDiskDataset_Feature_homograph():
in_memory: false
in_memory: false
path:
{
node_data_feat_path
}
path:
{
node_data_feat_path
}
- domain: node
- domain: node
name: label
name: label
s
format: numpy
format: numpy
in_memory: true
in_memory: true
path:
{
node_data_label_path
}
path:
{
node_data_label_path
}
...
@@ -889,7 +889,7 @@ def test_OnDiskDataset_Feature_homograph():
...
@@ -889,7 +889,7 @@ def test_OnDiskDataset_Feature_homograph():
in_memory: false
in_memory: false
path:
{
edge_data_feat_path
}
path:
{
edge_data_feat_path
}
- domain: edge
- domain: edge
name: label
name: label
s
format: numpy
format: numpy
in_memory: true
in_memory: true
path:
{
edge_data_label_path
}
path:
{
edge_data_label_path
}
...
@@ -911,7 +911,7 @@ def test_OnDiskDataset_Feature_homograph():
...
@@ -911,7 +911,7 @@ def test_OnDiskDataset_Feature_homograph():
torch
.
tensor
(
node_data_feat
),
torch
.
tensor
(
node_data_feat
),
)
)
assert
torch
.
equal
(
assert
torch
.
equal
(
feature_data
.
read
(
"node"
,
None
,
"label"
),
feature_data
.
read
(
"node"
,
None
,
"label
s
"
),
torch
.
tensor
(
node_data_label
),
torch
.
tensor
(
node_data_label
),
)
)
...
@@ -921,7 +921,7 @@ def test_OnDiskDataset_Feature_homograph():
...
@@ -921,7 +921,7 @@ def test_OnDiskDataset_Feature_homograph():
torch
.
tensor
(
edge_data_feat
),
torch
.
tensor
(
edge_data_feat
),
)
)
assert
torch
.
equal
(
assert
torch
.
equal
(
feature_data
.
read
(
"edge"
,
None
,
"label"
),
feature_data
.
read
(
"edge"
,
None
,
"label
s
"
),
torch
.
tensor
(
edge_data_label
),
torch
.
tensor
(
edge_data_label
),
)
)
...
...
tests/python/pytorch/graphbolt/test_subgraph_sampler.py
View file @
3742d5ff
...
@@ -22,7 +22,7 @@ def test_SubgraphSampler_Node(labor):
...
@@ -22,7 +22,7 @@ def test_SubgraphSampler_Node(labor):
def
to_link_batch
(
data
):
def
to_link_batch
(
data
):
block
=
gb
.
MiniBatch
(
node_pair
=
data
)
block
=
gb
.
MiniBatch
(
node_pair
s
=
data
)
return
block
return
block
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment