Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
7c51cd16
Unverified
Commit
7c51cd16
authored
Feb 22, 2024
by
Muhammed Fatih BALIN
Committed by
GitHub
Feb 22, 2024
Browse files
[GraphBolt] Cast sampled data to minimum dtype. (#7131)
parent
8909d1ff
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
85 additions
and
63 deletions
+85
-63
python/dgl/graphbolt/minibatch.py
python/dgl/graphbolt/minibatch.py
+17
-2
tests/python/pytorch/graphbolt/test_integration.py
tests/python/pytorch/graphbolt/test_integration.py
+24
-24
tests/python/pytorch/graphbolt/test_minibatch.py
tests/python/pytorch/graphbolt/test_minibatch.py
+44
-37
No files found.
python/dgl/graphbolt/minibatch.py
View file @
7c51cd16
...
@@ -8,7 +8,7 @@ import torch
...
@@ -8,7 +8,7 @@ import torch
import
dgl
import
dgl
from
dgl.utils
import
recursive_apply
from
dgl.utils
import
recursive_apply
from
.base
import
etype_str_to_tuple
,
expand_indptr
from
.base
import
CSCFormatBase
,
etype_str_to_tuple
,
expand_indptr
from
.internal
import
get_attributes
from
.internal
import
get_attributes
from
.sampled_subgraph
import
SampledSubgraph
from
.sampled_subgraph
import
SampledSubgraph
...
@@ -231,6 +231,19 @@ class MiniBatch:
...
@@ -231,6 +231,19 @@ class MiniBatch:
self
.
sampled_subgraphs
[
0
].
sampled_csc
,
Dict
self
.
sampled_subgraphs
[
0
].
sampled_csc
,
Dict
)
)
# casts to minimum dtype in-place and returns self.
def
cast_to_minimum_dtype
(
v
:
CSCFormatBase
):
# Checks if number of vertices and edges fit into an int32.
dtype
=
(
torch
.
int32
if
max
(
v
.
indptr
.
size
(
0
)
-
2
,
v
.
indices
.
size
(
0
))
<=
torch
.
iinfo
(
torch
.
int32
).
max
else
torch
.
int64
)
v
.
indptr
=
v
.
indptr
.
to
(
dtype
)
v
.
indices
=
v
.
indices
.
to
(
dtype
)
return
v
blocks
=
[]
blocks
=
[]
for
subgraph
in
self
.
sampled_subgraphs
:
for
subgraph
in
self
.
sampled_subgraphs
:
original_row_node_ids
=
subgraph
.
original_row_node_ids
original_row_node_ids
=
subgraph
.
original_row_node_ids
...
@@ -242,6 +255,8 @@ class MiniBatch:
...
@@ -242,6 +255,8 @@ class MiniBatch:
original_column_node_ids
is
not
None
original_column_node_ids
is
not
None
),
"Missing `original_column_node_ids` in sampled subgraph."
),
"Missing `original_column_node_ids` in sampled subgraph."
if
is_heterogeneous
:
if
is_heterogeneous
:
for
v
in
subgraph
.
sampled_csc
.
values
():
cast_to_minimum_dtype
(
v
)
sampled_csc
=
{
sampled_csc
=
{
etype_str_to_tuple
(
etype
):
(
etype_str_to_tuple
(
etype
):
(
"csc"
,
"csc"
,
...
@@ -267,7 +282,7 @@ class MiniBatch:
...
@@ -267,7 +282,7 @@ class MiniBatch:
for
ntype
,
nodes
in
original_column_node_ids
.
items
()
for
ntype
,
nodes
in
original_column_node_ids
.
items
()
}
}
else
:
else
:
sampled_csc
=
subgraph
.
sampled_csc
sampled_csc
=
cast_to_minimum_dtype
(
subgraph
.
sampled_csc
)
sampled_csc
=
(
sampled_csc
=
(
"csc"
,
"csc"
,
(
(
...
...
tests/python/pytorch/graphbolt/test_integration.py
View file @
7c51cd16
...
@@ -62,15 +62,15 @@ def test_integration_link_prediction():
...
@@ -62,15 +62,15 @@ def test_integration_link_prediction():
str
(
str
(
"""MiniBatch(seeds=None,
"""MiniBatch(seeds=None,
seed_nodes=None,
seed_nodes=None,
sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 1, 1, 1, 1, 2]),
sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 1, 1, 1, 1, 2]
, dtype=torch.int32
),
indices=tensor([0, 4]),
indices=tensor([0, 4]
, dtype=torch.int32
),
),
),
original_row_node_ids=tensor([5, 3, 1, 2, 0, 4]),
original_row_node_ids=tensor([5, 3, 1, 2, 0, 4]),
original_edge_ids=None,
original_edge_ids=None,
original_column_node_ids=tensor([5, 3, 1, 2, 0, 4]),
original_column_node_ids=tensor([5, 3, 1, 2, 0, 4]),
),
),
SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 1, 1, 1, 1, 2]),
SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 1, 1, 1, 1, 2]
, dtype=torch.int32
),
indices=tensor([5, 4]),
indices=tensor([5, 4]
, dtype=torch.int32
),
),
),
original_row_node_ids=tensor([5, 3, 1, 2, 0, 4]),
original_row_node_ids=tensor([5, 3, 1, 2, 0, 4]),
original_edge_ids=None,
original_edge_ids=None,
...
@@ -121,15 +121,15 @@ def test_integration_link_prediction():
...
@@ -121,15 +121,15 @@ def test_integration_link_prediction():
str
(
str
(
"""MiniBatch(seeds=None,
"""MiniBatch(seeds=None,
seed_nodes=None,
seed_nodes=None,
sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 0, 0, 0, 1, 2, 3]),
sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 0, 0, 0, 1, 2, 3]
, dtype=torch.int32
),
indices=tensor([4, 1, 0]),
indices=tensor([4, 1, 0]
, dtype=torch.int32
),
),
),
original_row_node_ids=tensor([3, 4, 0, 1, 5, 2]),
original_row_node_ids=tensor([3, 4, 0, 1, 5, 2]),
original_edge_ids=None,
original_edge_ids=None,
original_column_node_ids=tensor([3, 4, 0, 1, 5, 2]),
original_column_node_ids=tensor([3, 4, 0, 1, 5, 2]),
),
),
SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 0, 0, 0, 1, 2, 3]),
SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 0, 0, 0, 1, 2, 3]
, dtype=torch.int32
),
indices=tensor([4, 4, 0]),
indices=tensor([4, 4, 0]
, dtype=torch.int32
),
),
),
original_row_node_ids=tensor([3, 4, 0, 1, 5, 2]),
original_row_node_ids=tensor([3, 4, 0, 1, 5, 2]),
original_edge_ids=None,
original_edge_ids=None,
...
@@ -180,15 +180,15 @@ def test_integration_link_prediction():
...
@@ -180,15 +180,15 @@ def test_integration_link_prediction():
str
(
str
(
"""MiniBatch(seeds=None,
"""MiniBatch(seeds=None,
seed_nodes=None,
seed_nodes=None,
sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 0, 1, 1, 2]),
sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 0, 1, 1, 2]
, dtype=torch.int32
),
indices=tensor([1, 0]),
indices=tensor([1, 0]
, dtype=torch.int32
),
),
),
original_row_node_ids=tensor([5, 4, 0, 1]),
original_row_node_ids=tensor([5, 4, 0, 1]),
original_edge_ids=None,
original_edge_ids=None,
original_column_node_ids=tensor([5, 4, 0, 1]),
original_column_node_ids=tensor([5, 4, 0, 1]),
),
),
SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 0, 1, 1, 2]),
SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 0, 1, 1, 2]
, dtype=torch.int32
),
indices=tensor([1, 0]),
indices=tensor([1, 0]
, dtype=torch.int32
),
),
),
original_row_node_ids=tensor([5, 4, 0, 1]),
original_row_node_ids=tensor([5, 4, 0, 1]),
original_edge_ids=None,
original_edge_ids=None,
...
@@ -287,15 +287,15 @@ def test_integration_node_classification():
...
@@ -287,15 +287,15 @@ def test_integration_node_classification():
str
(
str
(
"""MiniBatch(seeds=None,
"""MiniBatch(seeds=None,
seed_nodes=None,
seed_nodes=None,
sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 2, 3, 4]),
sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 2, 3, 4]
, dtype=torch.int32
),
indices=tensor([4, 1, 0, 1]),
indices=tensor([4, 1, 0, 1]
, dtype=torch.int32
),
),
),
original_row_node_ids=tensor([5, 3, 1, 2, 4]),
original_row_node_ids=tensor([5, 3, 1, 2, 4]),
original_edge_ids=None,
original_edge_ids=None,
original_column_node_ids=tensor([5, 3, 1, 2]),
original_column_node_ids=tensor([5, 3, 1, 2]),
),
),
SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 2, 3, 4]),
SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 2, 3, 4]
, dtype=torch.int32
),
indices=tensor([0, 1, 0, 1]),
indices=tensor([0, 1, 0, 1]
, dtype=torch.int32
),
),
),
original_row_node_ids=tensor([5, 3, 1, 2]),
original_row_node_ids=tensor([5, 3, 1, 2]),
original_edge_ids=None,
original_edge_ids=None,
...
@@ -331,15 +331,15 @@ def test_integration_node_classification():
...
@@ -331,15 +331,15 @@ def test_integration_node_classification():
str
(
str
(
"""MiniBatch(seeds=None,
"""MiniBatch(seeds=None,
seed_nodes=None,
seed_nodes=None,
sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 2, 2]),
sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 2, 2]
, dtype=torch.int32
),
indices=tensor([0, 2]),
indices=tensor([0, 2]
, dtype=torch.int32
),
),
),
original_row_node_ids=tensor([3, 4, 0]),
original_row_node_ids=tensor([3, 4, 0]),
original_edge_ids=None,
original_edge_ids=None,
original_column_node_ids=tensor([3, 4, 0]),
original_column_node_ids=tensor([3, 4, 0]),
),
),
SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 2, 2]),
SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 2, 2]
, dtype=torch.int32
),
indices=tensor([0, 2]),
indices=tensor([0, 2]
, dtype=torch.int32
),
),
),
original_row_node_ids=tensor([3, 4, 0]),
original_row_node_ids=tensor([3, 4, 0]),
original_edge_ids=None,
original_edge_ids=None,
...
@@ -373,15 +373,15 @@ def test_integration_node_classification():
...
@@ -373,15 +373,15 @@ def test_integration_node_classification():
str
(
str
(
"""MiniBatch(seeds=None,
"""MiniBatch(seeds=None,
seed_nodes=None,
seed_nodes=None,
sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 2]),
sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 2]
, dtype=torch.int32
),
indices=tensor([0, 2]),
indices=tensor([0, 2]
, dtype=torch.int32
),
),
),
original_row_node_ids=tensor([5, 4, 0]),
original_row_node_ids=tensor([5, 4, 0]),
original_edge_ids=None,
original_edge_ids=None,
original_column_node_ids=tensor([5, 4]),
original_column_node_ids=tensor([5, 4]),
),
),
SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 2]),
SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 2]
, dtype=torch.int32
),
indices=tensor([1, 1]),
indices=tensor([1, 1]
, dtype=torch.int32
),
),
),
original_row_node_ids=tensor([5, 4]),
original_row_node_ids=tensor([5, 4]),
original_edge_ids=None,
original_edge_ids=None,
...
...
tests/python/pytorch/graphbolt/test_minibatch.py
View file @
7c51cd16
...
@@ -8,15 +8,17 @@ relation = "A:r:B"
...
@@ -8,15 +8,17 @@ relation = "A:r:B"
reverse_relation
=
"B:rr:A"
reverse_relation
=
"B:rr:A"
def
test_minibatch_representation_homo
():
@
pytest
.
mark
.
parametrize
(
"indptr_dtype"
,
[
torch
.
int32
,
torch
.
int64
])
@
pytest
.
mark
.
parametrize
(
"indices_dtype"
,
[
torch
.
int32
,
torch
.
int64
])
def
test_minibatch_representation_homo
(
indptr_dtype
,
indices_dtype
):
csc_formats
=
[
csc_formats
=
[
gb
.
CSCFormatBase
(
gb
.
CSCFormatBase
(
indptr
=
torch
.
tensor
([
0
,
1
,
3
,
5
,
6
]),
indptr
=
torch
.
tensor
([
0
,
1
,
3
,
5
,
6
]
,
dtype
=
indptr_dtype
),
indices
=
torch
.
tensor
([
0
,
1
,
2
,
2
,
1
,
2
]),
indices
=
torch
.
tensor
([
0
,
1
,
2
,
2
,
1
,
2
]
,
dtype
=
indices_dtype
),
),
),
gb
.
CSCFormatBase
(
gb
.
CSCFormatBase
(
indptr
=
torch
.
tensor
([
0
,
2
,
3
]),
indptr
=
torch
.
tensor
([
0
,
2
,
3
]
,
dtype
=
indptr_dtype
),
indices
=
torch
.
tensor
([
1
,
2
,
0
]),
indices
=
torch
.
tensor
([
1
,
2
,
0
]
,
dtype
=
indices_dtype
),
),
),
]
]
original_column_node_ids
=
[
original_column_node_ids
=
[
...
@@ -98,15 +100,15 @@ def test_minibatch_representation_homo():
...
@@ -98,15 +100,15 @@ def test_minibatch_representation_homo():
expect_result
=
str
(
expect_result
=
str
(
"""MiniBatch(seeds=None,
"""MiniBatch(seeds=None,
seed_nodes=None,
seed_nodes=None,
sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 3, 5, 6]),
sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 3, 5, 6]
, dtype=torch.int32
),
indices=tensor([0, 1, 2, 2, 1, 2]),
indices=tensor([0, 1, 2, 2, 1, 2]
, dtype=torch.int32
),
),
),
original_row_node_ids=tensor([10, 11, 12, 13]),
original_row_node_ids=tensor([10, 11, 12, 13]),
original_edge_ids=tensor([19, 20, 21, 22, 25, 30]),
original_edge_ids=tensor([19, 20, 21, 22, 25, 30]),
original_column_node_ids=tensor([10, 11, 12, 13]),
original_column_node_ids=tensor([10, 11, 12, 13]),
),
),
SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 2, 3]),
SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 2, 3]
, dtype=torch.int32
),
indices=tensor([1, 2, 0]),
indices=tensor([1, 2, 0]
, dtype=torch.int32
),
),
),
original_row_node_ids=tensor([10, 11, 12]),
original_row_node_ids=tensor([10, 11, 12]),
original_edge_ids=tensor([10, 15, 17]),
original_edge_ids=tensor([10, 15, 17]),
...
@@ -119,11 +121,11 @@ def test_minibatch_representation_homo():
...
@@ -119,11 +121,11 @@ def test_minibatch_representation_homo():
indices=tensor([3, 4, 5]),
indices=tensor([3, 4, 5]),
),
),
tensor([0., 1., 2.])),
tensor([0., 1., 2.])),
node_pairs=[CSCFormatBase(indptr=tensor([0, 1, 3, 5, 6]),
node_pairs=[CSCFormatBase(indptr=tensor([0, 1, 3, 5, 6]
, dtype=torch.int32
),
indices=tensor([0, 1, 2, 2, 1, 2]),
indices=tensor([0, 1, 2, 2, 1, 2]
, dtype=torch.int32
),
),
),
CSCFormatBase(indptr=tensor([0, 2, 3]),
CSCFormatBase(indptr=tensor([0, 2, 3]
, dtype=torch.int32
),
indices=tensor([1, 2, 0]),
indices=tensor([1, 2, 0]
, dtype=torch.int32
),
)],
)],
node_features={'x': tensor([5, 0, 2, 1])},
node_features={'x': tensor([5, 0, 2, 1])},
negative_srcs=tensor([[8],
negative_srcs=tensor([[8],
...
@@ -161,21 +163,24 @@ def test_minibatch_representation_homo():
...
@@ -161,21 +163,24 @@ def test_minibatch_representation_homo():
assert
result
==
expect_result
,
print
(
expect_result
,
result
)
assert
result
==
expect_result
,
print
(
expect_result
,
result
)
def
test_minibatch_representation_hetero
():
@
pytest
.
mark
.
parametrize
(
"indptr_dtype"
,
[
torch
.
int32
,
torch
.
int64
])
@
pytest
.
mark
.
parametrize
(
"indices_dtype"
,
[
torch
.
int32
,
torch
.
int64
])
def
test_minibatch_representation_hetero
(
indptr_dtype
,
indices_dtype
):
csc_formats
=
[
csc_formats
=
[
{
{
relation
:
gb
.
CSCFormatBase
(
relation
:
gb
.
CSCFormatBase
(
indptr
=
torch
.
tensor
([
0
,
1
,
2
,
3
]),
indptr
=
torch
.
tensor
([
0
,
1
,
2
,
3
]
,
dtype
=
indptr_dtype
),
indices
=
torch
.
tensor
([
0
,
1
,
1
]),
indices
=
torch
.
tensor
([
0
,
1
,
1
]
,
dtype
=
indices_dtype
),
),
),
reverse_relation
:
gb
.
CSCFormatBase
(
reverse_relation
:
gb
.
CSCFormatBase
(
indptr
=
torch
.
tensor
([
0
,
0
,
0
,
1
,
2
]),
indptr
=
torch
.
tensor
([
0
,
0
,
0
,
1
,
2
]
,
dtype
=
indptr_dtype
),
indices
=
torch
.
tensor
([
1
,
0
]),
indices
=
torch
.
tensor
([
1
,
0
]
,
dtype
=
indices_dtype
),
),
),
},
},
{
{
relation
:
gb
.
CSCFormatBase
(
relation
:
gb
.
CSCFormatBase
(
indptr
=
torch
.
tensor
([
0
,
1
,
2
]),
indices
=
torch
.
tensor
([
1
,
0
])
indptr
=
torch
.
tensor
([
0
,
1
,
2
],
dtype
=
indptr_dtype
),
indices
=
torch
.
tensor
([
1
,
0
],
dtype
=
indices_dtype
),
)
)
},
},
]
]
...
@@ -250,17 +255,17 @@ def test_minibatch_representation_hetero():
...
@@ -250,17 +255,17 @@ def test_minibatch_representation_hetero():
expect_result
=
str
(
expect_result
=
str
(
"""MiniBatch(seeds=None,
"""MiniBatch(seeds=None,
seed_nodes={'B': tensor([10, 15])},
seed_nodes={'B': tensor([10, 15])},
sampled_subgraphs=[SampledSubgraphImpl(sampled_csc={'A:r:B': CSCFormatBase(indptr=tensor([0, 1, 2, 3]),
sampled_subgraphs=[SampledSubgraphImpl(sampled_csc={'A:r:B': CSCFormatBase(indptr=tensor([0, 1, 2, 3]
, dtype=torch.int32
),
indices=tensor([0, 1, 1]),
indices=tensor([0, 1, 1]
, dtype=torch.int32
),
), 'B:rr:A': CSCFormatBase(indptr=tensor([0, 0, 0, 1, 2]),
), 'B:rr:A': CSCFormatBase(indptr=tensor([0, 0, 0, 1, 2]
, dtype=torch.int32
),
indices=tensor([1, 0]),
indices=tensor([1, 0]
, dtype=torch.int32
),
)},
)},
original_row_node_ids={'A': tensor([ 5, 7, 9, 11]), 'B': tensor([10, 11, 12])},
original_row_node_ids={'A': tensor([ 5, 7, 9, 11]), 'B': tensor([10, 11, 12])},
original_edge_ids={'A:r:B': tensor([19, 20, 21]), 'B:rr:A': tensor([23, 26])},
original_edge_ids={'A:r:B': tensor([19, 20, 21]), 'B:rr:A': tensor([23, 26])},
original_column_node_ids={'B': tensor([10, 11, 12]), 'A': tensor([ 5, 7, 9, 11])},
original_column_node_ids={'B': tensor([10, 11, 12]), 'A': tensor([ 5, 7, 9, 11])},
),
),
SampledSubgraphImpl(sampled_csc={'A:r:B': CSCFormatBase(indptr=tensor([0, 1, 2]),
SampledSubgraphImpl(sampled_csc={'A:r:B': CSCFormatBase(indptr=tensor([0, 1, 2]
, dtype=torch.int32
),
indices=tensor([1, 0]),
indices=tensor([1, 0]
, dtype=torch.int32
),
)},
)},
original_row_node_ids={'A': tensor([5, 7]), 'B': tensor([10, 11])},
original_row_node_ids={'A': tensor([5, 7]), 'B': tensor([10, 11])},
original_edge_ids={'A:r:B': tensor([10, 12])},
original_edge_ids={'A:r:B': tensor([10, 12])},
...
@@ -277,13 +282,13 @@ def test_minibatch_representation_hetero():
...
@@ -277,13 +282,13 @@ def test_minibatch_representation_hetero():
indices=tensor([0, 1]),
indices=tensor([0, 1]),
)},
)},
{'B': tensor([2, 5])}),
{'B': tensor([2, 5])}),
node_pairs=[{'A:r:B': CSCFormatBase(indptr=tensor([0, 1, 2, 3]),
node_pairs=[{'A:r:B': CSCFormatBase(indptr=tensor([0, 1, 2, 3]
, dtype=torch.int32
),
indices=tensor([0, 1, 1]),
indices=tensor([0, 1, 1]
, dtype=torch.int32
),
), 'B:rr:A': CSCFormatBase(indptr=tensor([0, 0, 0, 1, 2]),
), 'B:rr:A': CSCFormatBase(indptr=tensor([0, 0, 0, 1, 2]
, dtype=torch.int32
),
indices=tensor([1, 0]),
indices=tensor([1, 0]
, dtype=torch.int32
),
)},
)},
{'A:r:B': CSCFormatBase(indptr=tensor([0, 1, 2]),
{'A:r:B': CSCFormatBase(indptr=tensor([0, 1, 2]
, dtype=torch.int32
),
indices=tensor([1, 0]),
indices=tensor([1, 0]
, dtype=torch.int32
),
)}],
)}],
node_features={('A', 'x'): tensor([6, 4, 0, 1])},
node_features={('A', 'x'): tensor([6, 4, 0, 1])},
negative_srcs={'B': tensor([[8],
negative_srcs={'B': tensor([[8],
...
@@ -325,10 +330,12 @@ def test_minibatch_representation_hetero():
...
@@ -325,10 +330,12 @@ def test_minibatch_representation_hetero():
)"""
)"""
)
)
result
=
str
(
minibatch
)
result
=
str
(
minibatch
)
assert
result
==
expect_result
,
print
(
result
)
assert
result
==
expect_result
,
print
(
expect_result
,
result
)
def
test_get_dgl_blocks_homo
():
@
pytest
.
mark
.
parametrize
(
"indptr_dtype"
,
[
torch
.
int32
,
torch
.
int64
])
@
pytest
.
mark
.
parametrize
(
"indices_dtype"
,
[
torch
.
int32
,
torch
.
int64
])
def
test_get_dgl_blocks_homo
(
indptr_dtype
,
indices_dtype
):
node_pairs
=
[
node_pairs
=
[
(
(
torch
.
tensor
([
0
,
1
,
2
,
2
,
2
,
1
]),
torch
.
tensor
([
0
,
1
,
2
,
2
,
2
,
1
]),
...
@@ -341,12 +348,12 @@ def test_get_dgl_blocks_homo():
...
@@ -341,12 +348,12 @@ def test_get_dgl_blocks_homo():
]
]
csc_formats
=
[
csc_formats
=
[
gb
.
CSCFormatBase
(
gb
.
CSCFormatBase
(
indptr
=
torch
.
tensor
([
0
,
1
,
3
,
5
,
6
]),
indptr
=
torch
.
tensor
([
0
,
1
,
3
,
5
,
6
]
,
dtype
=
indptr_dtype
),
indices
=
torch
.
tensor
([
0
,
1
,
2
,
2
,
1
,
2
]),
indices
=
torch
.
tensor
([
0
,
1
,
2
,
2
,
1
,
2
]
,
dtype
=
indices_dtype
),
),
),
gb
.
CSCFormatBase
(
gb
.
CSCFormatBase
(
indptr
=
torch
.
tensor
([
0
,
1
,
3
]),
indptr
=
torch
.
tensor
([
0
,
1
,
3
]
,
dtype
=
indptr_dtype
),
indices
=
torch
.
tensor
([
0
,
1
,
2
]),
indices
=
torch
.
tensor
([
0
,
1
,
2
]
,
dtype
=
indices_dtype
),
),
),
]
]
original_column_node_ids
=
[
original_column_node_ids
=
[
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment