Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
6068dc31
Unverified
Commit
6068dc31
authored
Sep 01, 2023
by
Rhett Ying
Committed by
GitHub
Sep 01, 2023
Browse files
[GraphBolt] add names for each item in ItemSet (#6254)
parent
bbc8ff62
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
296 additions
and
39 deletions
+296
-39
python/dgl/graphbolt/impl/ondisk_dataset.py
python/dgl/graphbolt/impl/ondisk_dataset.py
+4
-2
python/dgl/graphbolt/impl/ondisk_metadata.py
python/dgl/graphbolt/impl/ondisk_metadata.py
+1
-0
python/dgl/graphbolt/itemset.py
python/dgl/graphbolt/itemset.py
+30
-1
tests/python/pytorch/graphbolt/impl/test_ondisk_dataset.py
tests/python/pytorch/graphbolt/impl/test_ondisk_dataset.py
+184
-36
tests/python/pytorch/graphbolt/test_itemset.py
tests/python/pytorch/graphbolt/test_itemset.py
+77
-0
No files found.
python/dgl/graphbolt/impl/ondisk_dataset.py
View file @
6068dc31
...
@@ -446,7 +446,8 @@ class OnDiskDataset(Dataset):
...
@@ -446,7 +446,8 @@ class OnDiskDataset(Dataset):
tuple
(
tuple
(
read_data
(
data
.
path
,
data
.
format
,
data
.
in_memory
)
read_data
(
data
.
path
,
data
.
format
,
data
.
in_memory
)
for
data
in
tvt_set
[
0
].
data
for
data
in
tvt_set
[
0
].
data
)
),
names
=
tuple
(
data
.
name
for
data
in
tvt_set
[
0
].
data
),
)
)
else
:
else
:
data
=
{}
data
=
{}
...
@@ -455,7 +456,8 @@ class OnDiskDataset(Dataset):
...
@@ -455,7 +456,8 @@ class OnDiskDataset(Dataset):
tuple
(
tuple
(
read_data
(
data
.
path
,
data
.
format
,
data
.
in_memory
)
read_data
(
data
.
path
,
data
.
format
,
data
.
in_memory
)
for
data
in
tvt
.
data
for
data
in
tvt
.
data
)
),
names
=
tuple
(
data
.
name
for
data
in
tvt
.
data
),
)
)
ret
=
ItemSetDict
(
data
)
ret
=
ItemSetDict
(
data
)
return
ret
return
ret
python/dgl/graphbolt/impl/ondisk_metadata.py
View file @
6068dc31
...
@@ -28,6 +28,7 @@ class OnDiskFeatureDataFormat(str, Enum):
...
@@ -28,6 +28,7 @@ class OnDiskFeatureDataFormat(str, Enum):
class
OnDiskTVTSetData
(
pydantic
.
BaseModel
):
class
OnDiskTVTSetData
(
pydantic
.
BaseModel
):
"""Train-Validation-Test set data."""
"""Train-Validation-Test set data."""
name
:
Optional
[
str
]
=
None
format
:
OnDiskFeatureDataFormat
format
:
OnDiskFeatureDataFormat
in_memory
:
Optional
[
bool
]
=
True
in_memory
:
Optional
[
bool
]
=
True
path
:
str
path
:
str
...
...
python/dgl/graphbolt/itemset.py
View file @
6068dc31
...
@@ -47,11 +47,26 @@ class ItemSet:
...
@@ -47,11 +47,26 @@ class ItemSet:
(tensor(4), tensor(9), tensor([18, 19]))]
(tensor(4), tensor(9), tensor([18, 19]))]
"""
"""
def
__init__
(
self
,
items
:
Iterable
or
Tuple
[
Iterable
])
->
None
:
def
__init__
(
self
,
items
:
Iterable
or
Tuple
[
Iterable
],
names
:
str
or
Tuple
[
str
]
=
None
,
)
->
None
:
if
isinstance
(
items
,
tuple
):
if
isinstance
(
items
,
tuple
):
self
.
_items
=
items
self
.
_items
=
items
else
:
else
:
self
.
_items
=
(
items
,)
self
.
_items
=
(
items
,)
if
names
is
not
None
:
if
isinstance
(
names
,
tuple
):
self
.
_names
=
names
else
:
self
.
_names
=
(
names
,)
assert
len
(
self
.
_items
)
==
len
(
self
.
_names
),
(
f
"Number of items (
{
len
(
self
.
_items
)
}
) and "
f
"names (
{
len
(
self
.
_names
)
}
) must match."
)
else
:
self
.
_names
=
None
def
__iter__
(
self
)
->
Iterator
:
def
__iter__
(
self
)
->
Iterator
:
if
len
(
self
.
_items
)
==
1
:
if
len
(
self
.
_items
)
==
1
:
...
@@ -68,6 +83,11 @@ class ItemSet:
...
@@ -68,6 +83,11 @@ class ItemSet:
f
"
{
type
(
self
).
__name__
}
instance doesn't have valid length."
f
"
{
type
(
self
).
__name__
}
instance doesn't have valid length."
)
)
@
property
def
names
(
self
)
->
Tuple
[
str
]:
"""Return the names of the items."""
return
self
.
_names
class
ItemSetDict
:
class
ItemSetDict
:
r
"""An iterable ItemsetDict.
r
"""An iterable ItemsetDict.
...
@@ -127,6 +147,10 @@ class ItemSetDict:
...
@@ -127,6 +147,10 @@ class ItemSetDict:
def
__init__
(
self
,
itemsets
:
Dict
[
str
,
ItemSet
])
->
None
:
def
__init__
(
self
,
itemsets
:
Dict
[
str
,
ItemSet
])
->
None
:
self
.
_itemsets
=
itemsets
self
.
_itemsets
=
itemsets
self
.
_names
=
itemsets
[
list
(
itemsets
.
keys
())[
0
]].
names
assert
all
(
self
.
_names
==
itemset
.
names
for
itemset
in
itemsets
.
values
()
),
"All itemsets must have the same names."
def
__iter__
(
self
)
->
Iterator
:
def
__iter__
(
self
)
->
Iterator
:
for
key
,
itemset
in
self
.
_itemsets
.
items
():
for
key
,
itemset
in
self
.
_itemsets
.
items
():
...
@@ -135,3 +159,8 @@ class ItemSetDict:
...
@@ -135,3 +159,8 @@ class ItemSetDict:
def
__len__
(
self
)
->
int
:
def
__len__
(
self
)
->
int
:
return
sum
(
len
(
itemset
)
for
itemset
in
self
.
_itemsets
.
values
())
return
sum
(
len
(
itemset
)
for
itemset
in
self
.
_itemsets
.
values
())
@
property
def
names
(
self
)
->
Tuple
[
str
]:
"""Return the names of the items."""
return
self
.
_names
tests/python/pytorch/graphbolt/impl/test_ondisk_dataset.py
View file @
6068dc31
...
@@ -61,6 +61,103 @@ def test_OnDiskDataset_TVTSet_exceptions():
...
@@ -61,6 +61,103 @@ def test_OnDiskDataset_TVTSet_exceptions():
_
=
gb
.
OnDiskDataset
(
test_dir
).
load
()
_
=
gb
.
OnDiskDataset
(
test_dir
).
load
()
def
test_OnDiskDataset_TVTSet_ItemSet_names
():
"""Test TVTSet which returns ItemSet with IDs, labels and corresponding names."""
with
tempfile
.
TemporaryDirectory
()
as
test_dir
:
train_ids
=
np
.
arange
(
1000
)
train_ids_path
=
os
.
path
.
join
(
test_dir
,
"train_ids.npy"
)
np
.
save
(
train_ids_path
,
train_ids
)
train_labels
=
np
.
random
.
randint
(
0
,
10
,
size
=
1000
)
train_labels_path
=
os
.
path
.
join
(
test_dir
,
"train_labels.npy"
)
np
.
save
(
train_labels_path
,
train_labels
)
yaml_content
=
f
"""
tasks:
- name: node_classification
num_classes: 10
train_set:
- type: null
data:
- name: seed_node
format: numpy
in_memory: true
path:
{
train_ids_path
}
- name: label
format: numpy
in_memory: true
path:
{
train_labels_path
}
- format: numpy
in_memory: true
path:
{
train_labels_path
}
"""
os
.
makedirs
(
os
.
path
.
join
(
test_dir
,
"preprocessed"
),
exist_ok
=
True
)
yaml_file
=
os
.
path
.
join
(
test_dir
,
"preprocessed/metadata.yaml"
)
with
open
(
yaml_file
,
"w"
)
as
f
:
f
.
write
(
yaml_content
)
dataset
=
gb
.
OnDiskDataset
(
test_dir
).
load
()
# Verify train set.
train_set
=
dataset
.
tasks
[
0
].
train_set
assert
len
(
train_set
)
==
1000
assert
isinstance
(
train_set
,
gb
.
ItemSet
)
for
i
,
(
id
,
label
,
_
)
in
enumerate
(
train_set
):
assert
id
==
train_ids
[
i
]
assert
label
==
train_labels
[
i
]
assert
train_set
.
names
==
(
"seed_node"
,
"label"
,
None
)
train_set
=
None
def
test_OnDiskDataset_TVTSet_ItemSetDict_names
():
"""Test TVTSet which returns ItemSet with IDs, labels and corresponding names."""
with
tempfile
.
TemporaryDirectory
()
as
test_dir
:
train_ids
=
np
.
arange
(
1000
)
train_ids_path
=
os
.
path
.
join
(
test_dir
,
"train_ids.npy"
)
np
.
save
(
train_ids_path
,
train_ids
)
train_labels
=
np
.
random
.
randint
(
0
,
10
,
size
=
1000
)
train_labels_path
=
os
.
path
.
join
(
test_dir
,
"train_labels.npy"
)
np
.
save
(
train_labels_path
,
train_labels
)
yaml_content
=
f
"""
tasks:
- name: node_classification
num_classes: 10
train_set:
- type: "author:writes:paper"
data:
- name: seed_node
format: numpy
in_memory: true
path:
{
train_ids_path
}
- name: label
format: numpy
in_memory: true
path:
{
train_labels_path
}
- format: numpy
in_memory: true
path:
{
train_labels_path
}
"""
os
.
makedirs
(
os
.
path
.
join
(
test_dir
,
"preprocessed"
),
exist_ok
=
True
)
yaml_file
=
os
.
path
.
join
(
test_dir
,
"preprocessed/metadata.yaml"
)
with
open
(
yaml_file
,
"w"
)
as
f
:
f
.
write
(
yaml_content
)
dataset
=
gb
.
OnDiskDataset
(
test_dir
).
load
()
# Verify train set.
train_set
=
dataset
.
tasks
[
0
].
train_set
assert
len
(
train_set
)
==
1000
assert
isinstance
(
train_set
,
gb
.
ItemSetDict
)
for
i
,
item
in
enumerate
(
train_set
):
assert
isinstance
(
item
,
dict
)
assert
"author:writes:paper"
in
item
id
,
label
,
_
=
item
[
"author:writes:paper"
]
assert
id
==
train_ids
[
i
]
assert
label
==
train_labels
[
i
]
assert
train_set
.
names
==
(
"seed_node"
,
"label"
,
None
)
train_set
=
None
def
test_OnDiskDataset_TVTSet_ItemSet_id_label
():
def
test_OnDiskDataset_TVTSet_ItemSet_id_label
():
"""Test TVTSet which returns ItemSet with IDs and labels."""
"""Test TVTSet which returns ItemSet with IDs and labels."""
with
tempfile
.
TemporaryDirectory
()
as
test_dir
:
with
tempfile
.
TemporaryDirectory
()
as
test_dir
:
...
@@ -96,27 +193,33 @@ def test_OnDiskDataset_TVTSet_ItemSet_id_label():
...
@@ -96,27 +193,33 @@ def test_OnDiskDataset_TVTSet_ItemSet_id_label():
train_set:
train_set:
- type: null
- type: null
data:
data:
- format: numpy
- name: seed_node
format: numpy
in_memory: true
in_memory: true
path:
{
train_ids_path
}
path:
{
train_ids_path
}
- format: numpy
- name: label
format: numpy
in_memory: true
in_memory: true
path:
{
train_labels_path
}
path:
{
train_labels_path
}
validation_set:
validation_set:
- data:
- data:
- format: numpy
- name: seed_node
format: numpy
in_memory: true
in_memory: true
path:
{
validation_ids_path
}
path:
{
validation_ids_path
}
- format: numpy
- name: label
format: numpy
in_memory: true
in_memory: true
path:
{
validation_labels_path
}
path:
{
validation_labels_path
}
test_set:
test_set:
- type: null
- type: null
data:
data:
- format: numpy
- name: seed_node
format: numpy
in_memory: true
in_memory: true
path:
{
test_ids_path
}
path:
{
test_ids_path
}
- format: numpy
- name: label
format: numpy
in_memory: true
in_memory: true
path:
{
test_labels_path
}
path:
{
test_labels_path
}
"""
"""
...
@@ -139,6 +242,7 @@ def test_OnDiskDataset_TVTSet_ItemSet_id_label():
...
@@ -139,6 +242,7 @@ def test_OnDiskDataset_TVTSet_ItemSet_id_label():
for
i
,
(
id
,
label
)
in
enumerate
(
train_set
):
for
i
,
(
id
,
label
)
in
enumerate
(
train_set
):
assert
id
==
train_ids
[
i
]
assert
id
==
train_ids
[
i
]
assert
label
==
train_labels
[
i
]
assert
label
==
train_labels
[
i
]
assert
train_set
.
names
==
(
"seed_node"
,
"label"
)
train_set
=
None
train_set
=
None
# Verify validation set.
# Verify validation set.
...
@@ -148,6 +252,7 @@ def test_OnDiskDataset_TVTSet_ItemSet_id_label():
...
@@ -148,6 +252,7 @@ def test_OnDiskDataset_TVTSet_ItemSet_id_label():
for
i
,
(
id
,
label
)
in
enumerate
(
validation_set
):
for
i
,
(
id
,
label
)
in
enumerate
(
validation_set
):
assert
id
==
validation_ids
[
i
]
assert
id
==
validation_ids
[
i
]
assert
label
==
validation_labels
[
i
]
assert
label
==
validation_labels
[
i
]
assert
validation_set
.
names
==
(
"seed_node"
,
"label"
)
validation_set
=
None
validation_set
=
None
# Verify test set.
# Verify test set.
...
@@ -157,6 +262,7 @@ def test_OnDiskDataset_TVTSet_ItemSet_id_label():
...
@@ -157,6 +262,7 @@ def test_OnDiskDataset_TVTSet_ItemSet_id_label():
for
i
,
(
id
,
label
)
in
enumerate
(
test_set
):
for
i
,
(
id
,
label
)
in
enumerate
(
test_set
):
assert
id
==
test_ids
[
i
]
assert
id
==
test_ids
[
i
]
assert
label
==
test_labels
[
i
]
assert
label
==
test_labels
[
i
]
assert
test_set
.
names
==
(
"seed_node"
,
"label"
)
test_set
=
None
test_set
=
None
dataset
=
None
dataset
=
None
...
@@ -220,36 +326,45 @@ def test_OnDiskDataset_TVTSet_ItemSet_node_pair_label():
...
@@ -220,36 +326,45 @@ def test_OnDiskDataset_TVTSet_ItemSet_node_pair_label():
train_set:
train_set:
- type: null
- type: null
data:
data:
- format: numpy
- name: src
format: numpy
in_memory: true
in_memory: true
path:
{
train_src_path
}
path:
{
train_src_path
}
- format: numpy
- name: dst
format: numpy
in_memory: true
in_memory: true
path:
{
train_dst_path
}
path:
{
train_dst_path
}
- format: numpy
- name: label
format: numpy
in_memory: true
in_memory: true
path:
{
train_labels_path
}
path:
{
train_labels_path
}
validation_set:
validation_set:
- data:
- data:
- format: numpy
- name: src
format: numpy
in_memory: true
in_memory: true
path:
{
validation_src_path
}
path:
{
validation_src_path
}
- format: numpy
- name: dst
format: numpy
in_memory: true
in_memory: true
path:
{
validation_dst_path
}
path:
{
validation_dst_path
}
- format: numpy
- name: label
format: numpy
in_memory: true
in_memory: true
path:
{
validation_labels_path
}
path:
{
validation_labels_path
}
test_set:
test_set:
- type: null
- type: null
data:
data:
- format: numpy
- name: src
format: numpy
in_memory: true
in_memory: true
path:
{
test_src_path
}
path:
{
test_src_path
}
- format: numpy
- name: dst
format: numpy
in_memory: true
in_memory: true
path:
{
test_dst_path
}
path:
{
test_dst_path
}
- format: numpy
- name: label
format: numpy
in_memory: true
in_memory: true
path:
{
test_labels_path
}
path:
{
test_labels_path
}
"""
"""
...
@@ -268,6 +383,7 @@ def test_OnDiskDataset_TVTSet_ItemSet_node_pair_label():
...
@@ -268,6 +383,7 @@ def test_OnDiskDataset_TVTSet_ItemSet_node_pair_label():
assert
src
==
train_src
[
i
]
assert
src
==
train_src
[
i
]
assert
dst
==
train_dst
[
i
]
assert
dst
==
train_dst
[
i
]
assert
label
==
train_labels
[
i
]
assert
label
==
train_labels
[
i
]
assert
train_set
.
names
==
(
"src"
,
"dst"
,
"label"
)
train_set
=
None
train_set
=
None
# Verify validation set.
# Verify validation set.
...
@@ -278,6 +394,7 @@ def test_OnDiskDataset_TVTSet_ItemSet_node_pair_label():
...
@@ -278,6 +394,7 @@ def test_OnDiskDataset_TVTSet_ItemSet_node_pair_label():
assert
src
==
validation_src
[
i
]
assert
src
==
validation_src
[
i
]
assert
dst
==
validation_dst
[
i
]
assert
dst
==
validation_dst
[
i
]
assert
label
==
validation_labels
[
i
]
assert
label
==
validation_labels
[
i
]
assert
validation_set
.
names
==
(
"src"
,
"dst"
,
"label"
)
validation_set
=
None
validation_set
=
None
# Verify test set.
# Verify test set.
...
@@ -288,6 +405,7 @@ def test_OnDiskDataset_TVTSet_ItemSet_node_pair_label():
...
@@ -288,6 +405,7 @@ def test_OnDiskDataset_TVTSet_ItemSet_node_pair_label():
assert
src
==
test_src
[
i
]
assert
src
==
test_src
[
i
]
assert
dst
==
test_dst
[
i
]
assert
dst
==
test_dst
[
i
]
assert
label
==
test_labels
[
i
]
assert
label
==
test_labels
[
i
]
assert
test_set
.
names
==
(
"src"
,
"dst"
,
"label"
)
test_set
=
None
test_set
=
None
dataset
=
None
dataset
=
None
...
@@ -335,36 +453,45 @@ def test_OnDiskDataset_TVTSet_ItemSet_node_pair_negs():
...
@@ -335,36 +453,45 @@ def test_OnDiskDataset_TVTSet_ItemSet_node_pair_negs():
train_set:
train_set:
- type: null
- type: null
data:
data:
- format: numpy
- name: src
format: numpy
in_memory: true
in_memory: true
path:
{
train_src_path
}
path:
{
train_src_path
}
- format: numpy
- name: dst
format: numpy
in_memory: true
in_memory: true
path:
{
train_dst_path
}
path:
{
train_dst_path
}
- format: numpy
- name: negative_dst
format: numpy
in_memory: true
in_memory: true
path:
{
train_neg_dst_path
}
path:
{
train_neg_dst_path
}
validation_set:
validation_set:
- data:
- data:
- format: numpy
- name: src
format: numpy
in_memory: true
in_memory: true
path:
{
validation_src_path
}
path:
{
validation_src_path
}
- format: numpy
- name: dst
format: numpy
in_memory: true
in_memory: true
path:
{
validation_dst_path
}
path:
{
validation_dst_path
}
- format: numpy
- name: negative_dst
format: numpy
in_memory: true
in_memory: true
path:
{
validation_neg_dst_path
}
path:
{
validation_neg_dst_path
}
test_set:
test_set:
- type: null
- type: null
data:
data:
- format: numpy
- name: src
format: numpy
in_memory: true
in_memory: true
path:
{
test_src_path
}
path:
{
test_src_path
}
- format: numpy
- name: dst
format: numpy
in_memory: true
in_memory: true
path:
{
test_dst_path
}
path:
{
test_dst_path
}
- format: numpy
- name: negative_dst
format: numpy
in_memory: true
in_memory: true
path:
{
test_neg_dst_path
}
path:
{
test_neg_dst_path
}
"""
"""
...
@@ -383,6 +510,7 @@ def test_OnDiskDataset_TVTSet_ItemSet_node_pair_negs():
...
@@ -383,6 +510,7 @@ def test_OnDiskDataset_TVTSet_ItemSet_node_pair_negs():
assert
src
==
train_src
[
i
]
assert
src
==
train_src
[
i
]
assert
dst
==
train_dst
[
i
]
assert
dst
==
train_dst
[
i
]
assert
torch
.
equal
(
negs
,
torch
.
from_numpy
(
train_neg_dst
[
i
]))
assert
torch
.
equal
(
negs
,
torch
.
from_numpy
(
train_neg_dst
[
i
]))
assert
train_set
.
names
==
(
"src"
,
"dst"
,
"negative_dst"
)
train_set
=
None
train_set
=
None
# Verify validation set.
# Verify validation set.
...
@@ -393,6 +521,7 @@ def test_OnDiskDataset_TVTSet_ItemSet_node_pair_negs():
...
@@ -393,6 +521,7 @@ def test_OnDiskDataset_TVTSet_ItemSet_node_pair_negs():
assert
src
==
validation_src
[
i
]
assert
src
==
validation_src
[
i
]
assert
dst
==
validation_dst
[
i
]
assert
dst
==
validation_dst
[
i
]
assert
torch
.
equal
(
negs
,
torch
.
from_numpy
(
validation_neg_dst
[
i
]))
assert
torch
.
equal
(
negs
,
torch
.
from_numpy
(
validation_neg_dst
[
i
]))
assert
validation_set
.
names
==
(
"src"
,
"dst"
,
"negative_dst"
)
validation_set
=
None
validation_set
=
None
# Verify test set.
# Verify test set.
...
@@ -403,6 +532,7 @@ def test_OnDiskDataset_TVTSet_ItemSet_node_pair_negs():
...
@@ -403,6 +532,7 @@ def test_OnDiskDataset_TVTSet_ItemSet_node_pair_negs():
assert
src
==
test_src
[
i
]
assert
src
==
test_src
[
i
]
assert
dst
==
test_dst
[
i
]
assert
dst
==
test_dst
[
i
]
assert
torch
.
equal
(
negs
,
torch
.
from_numpy
(
test_neg_dst
[
i
]))
assert
torch
.
equal
(
negs
,
torch
.
from_numpy
(
test_neg_dst
[
i
]))
assert
test_set
.
names
==
(
"src"
,
"dst"
,
"negative_dst"
)
test_set
=
None
test_set
=
None
dataset
=
None
dataset
=
None
...
@@ -434,31 +564,37 @@ def test_OnDiskDataset_TVTSet_ItemSetDict_id_label():
...
@@ -434,31 +564,37 @@ def test_OnDiskDataset_TVTSet_ItemSetDict_id_label():
train_set:
train_set:
- type: paper
- type: paper
data:
data:
- format: numpy
- name: seed_node
format: numpy
in_memory: true
in_memory: true
path:
{
train_path
}
path:
{
train_path
}
- type: author
- type: author
data:
data:
- format: numpy
- name: seed_node
format: numpy
path:
{
train_path
}
path:
{
train_path
}
validation_set:
validation_set:
- type: paper
- type: paper
data:
data:
- format: numpy
- name: seed_node
format: numpy
path:
{
validation_path
}
path:
{
validation_path
}
- type: author
- type: author
data:
data:
- format: numpy
- name: seed_node
format: numpy
path:
{
validation_path
}
path:
{
validation_path
}
test_set:
test_set:
- type: paper
- type: paper
data:
data:
- format: numpy
- name: seed_node
format: numpy
in_memory: false
in_memory: false
path:
{
test_path
}
path:
{
test_path
}
- type: author
- type: author
data:
data:
- format: numpy
- name: seed_node
format: numpy
path:
{
test_path
}
path:
{
test_path
}
"""
"""
os
.
makedirs
(
os
.
path
.
join
(
test_dir
,
"preprocessed"
),
exist_ok
=
True
)
os
.
makedirs
(
os
.
path
.
join
(
test_dir
,
"preprocessed"
),
exist_ok
=
True
)
...
@@ -480,6 +616,7 @@ def test_OnDiskDataset_TVTSet_ItemSetDict_id_label():
...
@@ -480,6 +616,7 @@ def test_OnDiskDataset_TVTSet_ItemSetDict_id_label():
id
,
label
=
item
[
key
]
id
,
label
=
item
[
key
]
assert
id
==
train_ids
[
i
%
1000
]
assert
id
==
train_ids
[
i
%
1000
]
assert
label
==
train_labels
[
i
%
1000
]
assert
label
==
train_labels
[
i
%
1000
]
assert
train_set
.
names
==
(
"seed_node"
,)
train_set
=
None
train_set
=
None
# Verify validation set.
# Verify validation set.
...
@@ -494,6 +631,7 @@ def test_OnDiskDataset_TVTSet_ItemSetDict_id_label():
...
@@ -494,6 +631,7 @@ def test_OnDiskDataset_TVTSet_ItemSetDict_id_label():
id
,
label
=
item
[
key
]
id
,
label
=
item
[
key
]
assert
id
==
validation_ids
[
i
%
1000
]
assert
id
==
validation_ids
[
i
%
1000
]
assert
label
==
validation_labels
[
i
%
1000
]
assert
label
==
validation_labels
[
i
%
1000
]
assert
validation_set
.
names
==
(
"seed_node"
,)
validation_set
=
None
validation_set
=
None
# Verify test set.
# Verify test set.
...
@@ -508,6 +646,7 @@ def test_OnDiskDataset_TVTSet_ItemSetDict_id_label():
...
@@ -508,6 +646,7 @@ def test_OnDiskDataset_TVTSet_ItemSetDict_id_label():
id
,
label
=
item
[
key
]
id
,
label
=
item
[
key
]
assert
id
==
test_ids
[
i
%
1000
]
assert
id
==
test_ids
[
i
%
1000
]
assert
label
==
test_labels
[
i
%
1000
]
assert
label
==
test_labels
[
i
%
1000
]
assert
test_set
.
names
==
(
"seed_node"
,)
test_set
=
None
test_set
=
None
dataset
=
None
dataset
=
None
...
@@ -539,31 +678,37 @@ def test_OnDiskDataset_TVTSet_ItemSetDict_node_pair_label():
...
@@ -539,31 +678,37 @@ def test_OnDiskDataset_TVTSet_ItemSetDict_node_pair_label():
train_set:
train_set:
- type: paper
- type: paper
data:
data:
- format: numpy
- name: node_pair
format: numpy
in_memory: true
in_memory: true
path:
{
train_path
}
path:
{
train_path
}
- type: author
- type: author
data:
data:
- format: numpy
- name: node_pair
format: numpy
path:
{
train_path
}
path:
{
train_path
}
validation_set:
validation_set:
- type: paper
- type: paper
data:
data:
- format: numpy
- name: node_pair
format: numpy
path:
{
validation_path
}
path:
{
validation_path
}
- type: author
- type: author
data:
data:
- format: numpy
- name: node_pair
format: numpy
path:
{
validation_path
}
path:
{
validation_path
}
test_set:
test_set:
- type: paper
- type: paper
data:
data:
- format: numpy
- name: node_pair
format: numpy
in_memory: false
in_memory: false
path:
{
test_path
}
path:
{
test_path
}
- type: author
- type: author
data:
data:
- format: numpy
- name: node_pair
format: numpy
path:
{
test_path
}
path:
{
test_path
}
"""
"""
os
.
makedirs
(
os
.
path
.
join
(
test_dir
,
"preprocessed"
),
exist_ok
=
True
)
os
.
makedirs
(
os
.
path
.
join
(
test_dir
,
"preprocessed"
),
exist_ok
=
True
)
...
@@ -586,6 +731,7 @@ def test_OnDiskDataset_TVTSet_ItemSetDict_node_pair_label():
...
@@ -586,6 +731,7 @@ def test_OnDiskDataset_TVTSet_ItemSetDict_node_pair_label():
assert
src
==
train_pairs
[
0
][
i
%
1000
]
assert
src
==
train_pairs
[
0
][
i
%
1000
]
assert
dst
==
train_pairs
[
1
][
i
%
1000
]
assert
dst
==
train_pairs
[
1
][
i
%
1000
]
assert
label
==
train_labels
[
i
%
1000
]
assert
label
==
train_labels
[
i
%
1000
]
assert
train_set
.
names
==
(
"node_pair"
,)
train_set
=
None
train_set
=
None
# Verify validation set.
# Verify validation set.
...
@@ -601,6 +747,7 @@ def test_OnDiskDataset_TVTSet_ItemSetDict_node_pair_label():
...
@@ -601,6 +747,7 @@ def test_OnDiskDataset_TVTSet_ItemSetDict_node_pair_label():
assert
src
==
validation_pairs
[
0
][
i
%
1000
]
assert
src
==
validation_pairs
[
0
][
i
%
1000
]
assert
dst
==
validation_pairs
[
1
][
i
%
1000
]
assert
dst
==
validation_pairs
[
1
][
i
%
1000
]
assert
label
==
validation_labels
[
i
%
1000
]
assert
label
==
validation_labels
[
i
%
1000
]
assert
validation_set
.
names
==
(
"node_pair"
,)
validation_set
=
None
validation_set
=
None
# Verify test set.
# Verify test set.
...
@@ -616,6 +763,7 @@ def test_OnDiskDataset_TVTSet_ItemSetDict_node_pair_label():
...
@@ -616,6 +763,7 @@ def test_OnDiskDataset_TVTSet_ItemSetDict_node_pair_label():
assert
src
==
test_pairs
[
0
][
i
%
1000
]
assert
src
==
test_pairs
[
0
][
i
%
1000
]
assert
dst
==
test_pairs
[
1
][
i
%
1000
]
assert
dst
==
test_pairs
[
1
][
i
%
1000
]
assert
label
==
test_labels
[
i
%
1000
]
assert
label
==
test_labels
[
i
%
1000
]
assert
test_set
.
names
==
(
"node_pair"
,)
test_set
=
None
test_set
=
None
dataset
=
None
dataset
=
None
...
...
tests/python/pytorch/graphbolt/test_itemset.py
View file @
6068dc31
import
re
import
dgl
import
dgl
import
pytest
import
pytest
import
torch
import
torch
...
@@ -5,6 +7,81 @@ from dgl import graphbolt as gb
...
@@ -5,6 +7,81 @@ from dgl import graphbolt as gb
from
torch.testing
import
assert_close
from
torch.testing
import
assert_close
def
test_ItemSet_names
():
# ItemSet with single name.
item_set
=
gb
.
ItemSet
(
torch
.
arange
(
0
,
5
),
names
=
"seed_node"
)
assert
item_set
.
names
==
(
"seed_node"
,)
# ItemSet with multiple names.
item_set
=
gb
.
ItemSet
(
(
torch
.
arange
(
0
,
5
),
torch
.
arange
(
5
,
10
)),
names
=
(
"seed_node"
,
"label"
)
)
assert
item_set
.
names
==
(
"seed_node"
,
"label"
)
# ItemSet with no name.
item_set
=
gb
.
ItemSet
(
torch
.
arange
(
0
,
5
))
assert
item_set
.
names
is
None
# ItemSet with mismatched items and names.
with
pytest
.
raises
(
AssertionError
,
match
=
re
.
escape
(
"Number of items (1) and names (2) must match."
),
):
_
=
gb
.
ItemSet
(
torch
.
arange
(
0
,
5
),
names
=
(
"seed_node"
,
"label"
))
def
test_ItemSetDict_names
():
# ItemSetDict with single name.
item_set
=
gb
.
ItemSetDict
(
{
"user"
:
gb
.
ItemSet
(
torch
.
arange
(
0
,
5
),
names
=
"seed_node"
),
"item"
:
gb
.
ItemSet
(
torch
.
arange
(
5
,
10
),
names
=
"seed_node"
),
}
)
assert
item_set
.
names
==
(
"seed_node"
,)
# ItemSetDict with multiple names.
item_set
=
gb
.
ItemSetDict
(
{
"user"
:
gb
.
ItemSet
(
(
torch
.
arange
(
0
,
5
),
torch
.
arange
(
5
,
10
)),
names
=
(
"seed_node"
,
"label"
),
),
"item"
:
gb
.
ItemSet
(
(
torch
.
arange
(
5
,
10
),
torch
.
arange
(
10
,
15
)),
names
=
(
"seed_node"
,
"label"
),
),
}
)
assert
item_set
.
names
==
(
"seed_node"
,
"label"
)
# ItemSetDict with no name.
item_set
=
gb
.
ItemSetDict
(
{
"user"
:
gb
.
ItemSet
(
torch
.
arange
(
0
,
5
)),
"item"
:
gb
.
ItemSet
(
torch
.
arange
(
5
,
10
)),
}
)
assert
item_set
.
names
is
None
# ItemSetDict with mismatched items and names.
with
pytest
.
raises
(
AssertionError
,
match
=
re
.
escape
(
"All itemsets must have the same names."
),
):
_
=
gb
.
ItemSetDict
(
{
"user"
:
gb
.
ItemSet
(
(
torch
.
arange
(
0
,
5
),
torch
.
arange
(
5
,
10
)),
names
=
(
"seed_node"
,
"label"
),
),
"item"
:
gb
.
ItemSet
(
(
torch
.
arange
(
5
,
10
),),
names
=
(
"seed_node"
,)
),
}
)
def
test_ItemSet_valid_length
():
def
test_ItemSet_valid_length
():
# Single iterable.
# Single iterable.
ids
=
torch
.
arange
(
0
,
5
)
ids
=
torch
.
arange
(
0
,
5
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment