Commit e5ddc62b (unverified), authored by Rhett Ying on Jul 07, 2023; committed via GitHub on Jul 07, 2023.

[GraphBolt] add support to generate TVT in ItemSet or ItemSetDict format (#5958)
Parent: ca36441b

Showing 3 changed files with 516 additions and 40 deletions:

  python/dgl/graphbolt/dataset.py                  +57  -20
  python/dgl/graphbolt/utils.py                    +30  -0
  tests/python/pytorch/graphbolt/test_dataset.py   +429 -20
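For orientation before the diffs: the commit replaces the singular `train_set()` / `validation_set()` / `test_set()` accessors with plural `*_sets()` accessors that return `List[ItemSet]` for homogeneous graphs or `List[ItemSetDict]` for heterogeneous ones, driven by a YAML metadata file. A rough usage sketch (editor's illustration, not part of the commit; `metadata.yaml` and its contents are hypothetical, while the accessor names and YAML schema come from the diff below):

    from dgl import graphbolt as gb

    # "metadata.yaml" is a hypothetical file following the schema this
    # commit adds, e.g.:
    #   train_sets:
    #     - - type_name: null   # homogeneous graph -> ItemSet
    #         format: numpy
    #         path: set/train.npy
    dataset = gb.OnDiskDataset("metadata.yaml")
    for item_set in dataset.train_sets():  # List[ItemSet] or List[ItemSetDict]
        for item in item_set:              # e.g. an (id, label) pair per item
            pass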
python/dgl/graphbolt/dataset.py

@@ -7,6 +7,7 @@ import pydantic_yaml
 from .feature_store import FeatureStore
 from .itemset import ItemSet, ItemSetDict
+from .utils import read_data, tensor_to_tuple

 __all__ = ["Dataset", "OnDiskDataset"]
@@ -34,16 +35,16 @@ class Dataset:
     generate a subgraph.
     """

-    def train_set(self) -> ItemSet or ItemSetDict:
-        """Return the training set."""
+    def train_sets(self) -> List[ItemSet] or List[ItemSetDict]:
+        """Return the training sets."""
         raise NotImplementedError

-    def validation_set(self) -> ItemSet or ItemSetDict:
-        """Return the validation set."""
+    def validation_sets(self) -> List[ItemSet] or List[ItemSetDict]:
+        """Return the validation sets."""
         raise NotImplementedError

-    def test_set(self) -> ItemSet or ItemSetDict:
-        """Return the test set."""
+    def test_sets(self) -> List[ItemSet] or List[ItemSetDict]:
+        """Return the test sets."""
         raise NotImplementedError

     def graph(self) -> object:
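For a sense of the contract the abstract class now defines, a minimal concrete subclass could look like this (editor's sketch, not part of the commit; `ToyDataset` is hypothetical, and building an `ItemSet` from a tuple of column tensors mirrors `_init_tvt_sets` further down):

    import torch
    from typing import List
    from dgl import graphbolt as gb

    class ToyDataset(gb.Dataset):  # hypothetical subclass, for illustration
        """Keep a single homogeneous training split in memory."""

        def __init__(self, ids: torch.Tensor, labels: torch.Tensor):
            # One ItemSet over (ids, labels) columns.
            self._train_sets = [gb.ItemSet((ids, labels))]

        def train_sets(self) -> List[gb.ItemSet]:
            """Return the training sets."""
            return self._train_sets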
@@ -65,8 +66,9 @@ class OnDiskDataFormatEnum(pydantic_yaml.YamlStrEnum):
 class OnDiskTVTSet(pydantic.BaseModel):
     """Train-Validation-Test set."""

-    type_name: str
+    type_name: Optional[str]
     format: OnDiskDataFormatEnum
+    in_memory: Optional[bool] = True
     path: str
@@ -77,9 +79,9 @@ class OnDiskMetaData(pydantic_yaml.YamlModel):
     is a list of list of ``OnDiskTVTSet``.
     """

-    train_set: Optional[List[List[OnDiskTVTSet]]]
-    validation_set: Optional[List[List[OnDiskTVTSet]]]
-    test_set: Optional[List[List[OnDiskTVTSet]]]
+    train_sets: Optional[List[List[OnDiskTVTSet]]]
+    validation_sets: Optional[List[List[OnDiskTVTSet]]]
+    test_sets: Optional[List[List[OnDiskTVTSet]]]


 class OnDiskDataset(Dataset):
@@ -95,17 +97,20 @@ class OnDiskDataset(Dataset):
     .. code-block:: yaml

-        train_set:
-          - - type_name: paper
+        train_sets:
+          - - type_name: paper  # could be null for homogeneous graph.
               format: numpy
+              in_memory: true  # If not specified, default to true.
               path: set/paper-train.npy
-        validation_set:
+        validation_sets:
           - - type_name: paper
               format: numpy
+              in_memory: true
               path: set/paper-validation.npy
-        test_set:
+        test_sets:
           - - type_name: paper
               format: numpy
+              in_memory: true
               path: set/paper-test.npy

     Parameters
@@ -117,18 +122,21 @@ class OnDiskDataset(Dataset):
     def __init__(self, path: str) -> None:
         with open(path, "r") as f:
             self._meta = OnDiskMetaData.parse_raw(f.read(), proto="yaml")
+        self._train_sets = self._init_tvt_sets(self._meta.train_sets)
+        self._validation_sets = self._init_tvt_sets(self._meta.validation_sets)
+        self._test_sets = self._init_tvt_sets(self._meta.test_sets)

-    def train_set(self) -> ItemSet or ItemSetDict:
-        """Return the training set."""
-        raise NotImplementedError
+    def train_sets(self) -> List[ItemSet] or List[ItemSetDict]:
+        """Return the training sets."""
+        return self._train_sets

-    def validation_set(self) -> ItemSet or ItemSetDict:
-        """Return the validation set."""
-        raise NotImplementedError
+    def validation_sets(self) -> List[ItemSet] or List[ItemSetDict]:
+        """Return the validation sets."""
+        return self._validation_sets

-    def test_set(self) -> ItemSet or ItemSetDict:
-        """Return the test set."""
-        raise NotImplementedError
+    def test_sets(self) -> List[ItemSet] or List[ItemSetDict]:
+        """Return the test sets."""
+        return self._test_sets

     def graph(self) -> object:
         """Return the graph."""
@@ -137,3 +145,32 @@ class OnDiskDataset(Dataset):
     def feature(self) -> FeatureStore:
         """Return the feature."""
         raise NotImplementedError
+
+    def _init_tvt_sets(
+        self, tvt_sets: List[List[OnDiskTVTSet]]
+    ) -> List[ItemSet] or List[ItemSetDict]:
+        """Initialize the TVT sets."""
+        if (tvt_sets is None) or (len(tvt_sets) == 0):
+            return None
+        ret = []
+        for tvt_set in tvt_sets:
+            if (tvt_set is None) or (len(tvt_set) == 0):
+                ret.append(None)
+                continue  # skip indexing into an empty set below
+            if tvt_set[0].type_name is None:
+                assert (
+                    len(tvt_set) == 1
+                ), "Only one TVT set is allowed if type_name is not specified."
+                data = read_data(
+                    tvt_set[0].path, tvt_set[0].format, tvt_set[0].in_memory
+                )
+                ret.append(ItemSet(tensor_to_tuple(data)))
+            else:
+                data = {}
+                for tvt in tvt_set:
+                    data[tvt.type_name] = ItemSet(
+                        tensor_to_tuple(
+                            read_data(tvt.path, tvt.format, tvt.in_memory)
+                        )
+                    )
+                ret.append(ItemSetDict(data))
+        return ret
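The branching in `_init_tvt_sets` is the heart of the change: entries whose `type_name` is null collapse to a single `ItemSet`, while typed entries are grouped into an `ItemSetDict` keyed by `type_name`. A condensed sketch of the two resulting objects (editor's illustration; the constructor arguments follow the calls above, the iteration behavior follows the tests below):

    import torch
    from dgl import graphbolt as gb

    ids = torch.arange(5)
    labels = torch.randint(0, 10, (5,))

    # type_name is null in the YAML -> one ItemSet over column tensors.
    homo = gb.ItemSet((ids, labels))
    next(iter(homo))    # an (id, label) pair

    # one entry per type_name -> ItemSets grouped in an ItemSetDict.
    hetero = gb.ItemSetDict({"paper": gb.ItemSet((ids, labels))})
    next(iter(hetero))  # {"paper": (id, label)}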
python/dgl/graphbolt/utils.py (new file, 0 → 100644)

"""Utility functions for GraphBolt."""
import numpy as np
import torch


def _read_torch_data(path):
    return torch.load(path)


def _read_numpy_data(path, in_memory=True):
    if in_memory:
        return torch.from_numpy(np.load(path))
    return torch.as_tensor(np.load(path, mmap_mode="r+"))


def read_data(path, fmt, in_memory=True):
    """Read data from disk."""
    if fmt == "torch":
        return _read_torch_data(path)
    elif fmt == "numpy":
        return _read_numpy_data(path, in_memory=in_memory)
    else:
        raise RuntimeError(f"Unsupported format: {fmt}")


def tensor_to_tuple(data):
    """Split a torch.Tensor column-wise into a tuple of tensors."""
    assert isinstance(data, torch.Tensor), "data must be a torch.Tensor"
    return tuple(data.t())
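A quick illustration of what `tuple(data.t())` does: an (N, 2) tensor of [id, label] rows is transposed to (2, N), and iterating the transpose yields one tensor per original column (editor's sketch, not part of the commit; the `dgl.graphbolt.utils` import path is assumed from the file location):

    import torch
    from dgl.graphbolt.utils import tensor_to_tuple

    data = torch.tensor([[0, 7], [1, 3], [2, 5]])  # three (id, label) rows
    ids, labels = tensor_to_tuple(data)            # i.e. tuple(data.t())
    assert ids.tolist() == [0, 1, 2]
    assert labels.tolist() == [7, 3, 5]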
tests/python/pytorch/graphbolt/test_dataset.py

 import os
 import tempfile

+import numpy as np
+import pydantic
 import pytest

 from dgl import graphbolt as gb

@@ -9,45 +11,452 @@ from dgl import graphbolt as gb
 def test_Dataset():
     dataset = gb.Dataset()
     with pytest.raises(NotImplementedError):
-        _ = dataset.train_set()
+        _ = dataset.train_sets()
     with pytest.raises(NotImplementedError):
-        _ = dataset.validation_set()
+        _ = dataset.validation_sets()
     with pytest.raises(NotImplementedError):
-        _ = dataset.test_set()
+        _ = dataset.test_sets()
     with pytest.raises(NotImplementedError):
         _ = dataset.graph()
     with pytest.raises(NotImplementedError):
         _ = dataset.feature()
-def test_OnDiskDataset_TVTSet():
-    """Test OnDiskDataset with TVTSet."""
+def test_OnDiskDataset_TVTSet_exceptions():
+    """Test exceptions thrown when parsing TVTSet."""
     with tempfile.TemporaryDirectory() as test_dir:
+        yaml_file = os.path.join(test_dir, "test.yaml")
+
         # Case 1: ``format`` is invalid.
         yaml_content = """
-            train_set:
+            train_sets:
               - - type_name: paper
-                  format: torch
+                  format: torch_invalid
                   path: set/paper-train.pt
                 - type_name: 'paper:cites:paper'
                   format: numpy
                   path: set/cites-train.pt
         """
-        yaml_file = os.path.join(test_dir, "test.yaml")
         with open(yaml_file, "w") as f:
             f.write(yaml_content)
         with pytest.raises(pydantic.ValidationError):
             _ = gb.OnDiskDataset(yaml_file)

-        # Invalid format.
+        # Case 2: ``type_name`` is not specified while multiple TVT sets
+        # are specified.
         yaml_content = """
-            train_set:
-              - - type_name: paper
-                  format: torch_invalid
-                  path: set/paper-train.pt
-                - type_name: 'paper:cites:paper'
-                  format: numpy_invalid
-                  path: set/cites-train.pt
+            train_sets:
+              - - type_name: null
+                  format: numpy
+                  path: set/train.npy
+                - type_name: null
+                  format: numpy
+                  path: set/train.npy
         """
         with open(yaml_file, "w") as f:
             f.write(yaml_content)
-        with pytest.raises(pydantic.ValidationError):
+        with pytest.raises(
+            AssertionError,
+            match=r"Only one TVT set is allowed if type_name is not specified.",
+        ):
             _ = gb.OnDiskDataset(yaml_file)
def test_OnDiskDataset_TVTSet_ItemSet_id_label():
    """Test TVTSet which returns ItemSet with IDs and labels."""
    with tempfile.TemporaryDirectory() as test_dir:
        train_ids = np.arange(1000)
        train_labels = np.random.randint(0, 10, size=1000)
        train_data = np.vstack([train_ids, train_labels]).T
        train_path = os.path.join(test_dir, "train.npy")
        np.save(train_path, train_data)

        validation_ids = np.arange(1000, 2000)
        validation_labels = np.random.randint(0, 10, size=1000)
        validation_data = np.vstack([validation_ids, validation_labels]).T
        validation_path = os.path.join(test_dir, "validation.npy")
        np.save(validation_path, validation_data)

        test_ids = np.arange(2000, 3000)
        test_labels = np.random.randint(0, 10, size=1000)
        test_data = np.vstack([test_ids, test_labels]).T
        test_path = os.path.join(test_dir, "test.npy")
        np.save(test_path, test_data)

        # Case 1:
        #   all TVT sets are specified.
        #   ``type_name`` is not specified or specified as ``null``.
        #   ``in_memory`` could be ``true`` and ``false``.
        yaml_content = f"""
            train_sets:
              - - type_name: null
                  format: numpy
                  in_memory: true
                  path: {train_path}
              - - type_name: null
                  format: numpy
                  path: {train_path}
            validation_sets:
              - - format: numpy
                  path: {validation_path}
              - - type_name: null
                  format: numpy
                  path: {validation_path}
            test_sets:
              - - type_name: null
                  format: numpy
                  in_memory: false
                  path: {test_path}
              - - type_name: null
                  format: numpy
                  path: {test_path}
        """
        yaml_file = os.path.join(test_dir, "test.yaml")
        with open(yaml_file, "w") as f:
            f.write(yaml_content)
        dataset = gb.OnDiskDataset(yaml_file)

        # Verify train set.
        train_sets = dataset.train_sets()
        assert len(train_sets) == 2
        for train_set in train_sets:
            assert len(train_set) == 1000
            assert isinstance(train_set, gb.ItemSet)
            for i, (id, label) in enumerate(train_set):
                assert id == train_ids[i]
                assert label == train_labels[i]
        train_sets = None

        # Verify validation set.
        validation_sets = dataset.validation_sets()
        assert len(validation_sets) == 2
        for validation_set in validation_sets:
            assert len(validation_set) == 1000
            assert isinstance(validation_set, gb.ItemSet)
            for i, (id, label) in enumerate(validation_set):
                assert id == validation_ids[i]
                assert label == validation_labels[i]
        validation_sets = None

        # Verify test set.
        test_sets = dataset.test_sets()
        assert len(test_sets) == 2
        for test_set in test_sets:
            assert len(test_set) == 1000
            assert isinstance(test_set, gb.ItemSet)
            for i, (id, label) in enumerate(test_set):
                assert id == test_ids[i]
                assert label == test_labels[i]
        test_sets = None
        dataset = None

        # Case 2: Some TVT sets are None.
        yaml_content = f"""
            train_sets:
              - - type_name: null
                  format: numpy
                  path: {train_path}
        """
        yaml_file = os.path.join(test_dir, "test.yaml")
        with open(yaml_file, "w") as f:
            f.write(yaml_content)
        dataset = gb.OnDiskDataset(yaml_file)
        assert dataset.train_sets() is not None
        assert dataset.validation_sets() is None
        assert dataset.test_sets() is None
        dataset = None
def test_OnDiskDataset_TVTSet_ItemSet_node_pair_label():
    """Test TVTSet which returns ItemSet with node pairs and labels."""
    with tempfile.TemporaryDirectory() as test_dir:
        train_pairs = (np.arange(1000), np.arange(1000, 2000))
        train_labels = np.random.randint(0, 10, size=1000)
        train_data = np.vstack([train_pairs, train_labels]).T
        train_path = os.path.join(test_dir, "train.npy")
        np.save(train_path, train_data)

        validation_pairs = (np.arange(1000, 2000), np.arange(2000, 3000))
        validation_labels = np.random.randint(0, 10, size=1000)
        validation_data = np.vstack([validation_pairs, validation_labels]).T
        validation_path = os.path.join(test_dir, "validation.npy")
        np.save(validation_path, validation_data)

        test_pairs = (np.arange(2000, 3000), np.arange(3000, 4000))
        test_labels = np.random.randint(0, 10, size=1000)
        test_data = np.vstack([test_pairs, test_labels]).T
        test_path = os.path.join(test_dir, "test.npy")
        np.save(test_path, test_data)

        yaml_content = f"""
            train_sets:
              - - type_name: null
                  format: numpy
                  in_memory: true
                  path: {train_path}
              - - type_name: null
                  format: numpy
                  path: {train_path}
            validation_sets:
              - - format: numpy
                  path: {validation_path}
              - - type_name: null
                  format: numpy
                  path: {validation_path}
            test_sets:
              - - type_name: null
                  format: numpy
                  in_memory: false
                  path: {test_path}
              - - type_name: null
                  format: numpy
                  path: {test_path}
        """
        yaml_file = os.path.join(test_dir, "test.yaml")
        with open(yaml_file, "w") as f:
            f.write(yaml_content)
        dataset = gb.OnDiskDataset(yaml_file)

        # Verify train set.
        train_sets = dataset.train_sets()
        assert len(train_sets) == 2
        for train_set in train_sets:
            assert len(train_set) == 1000
            assert isinstance(train_set, gb.ItemSet)
            for i, (src, dst, label) in enumerate(train_set):
                assert src == train_pairs[0][i]
                assert dst == train_pairs[1][i]
                assert label == train_labels[i]
        train_sets = None

        # Verify validation set.
        validation_sets = dataset.validation_sets()
        assert len(validation_sets) == 2
        for validation_set in validation_sets:
            assert len(validation_set) == 1000
            assert isinstance(validation_set, gb.ItemSet)
            for i, (src, dst, label) in enumerate(validation_set):
                assert src == validation_pairs[0][i]
                assert dst == validation_pairs[1][i]
                assert label == validation_labels[i]
        validation_sets = None

        # Verify test set.
        test_sets = dataset.test_sets()
        assert len(test_sets) == 2
        for test_set in test_sets:
            assert len(test_set) == 1000
            assert isinstance(test_set, gb.ItemSet)
            for i, (src, dst, label) in enumerate(test_set):
                assert src == test_pairs[0][i]
                assert dst == test_pairs[1][i]
                assert label == test_labels[i]
        test_sets = None
        dataset = None
def test_OnDiskDataset_TVTSet_ItemSetDict_id_label():
    """Test TVTSet which returns ItemSetDict with IDs and labels."""
    with tempfile.TemporaryDirectory() as test_dir:
        train_ids = np.arange(1000)
        train_labels = np.random.randint(0, 10, size=1000)
        train_data = np.vstack([train_ids, train_labels]).T
        train_path = os.path.join(test_dir, "train.npy")
        np.save(train_path, train_data)

        validation_ids = np.arange(1000, 2000)
        validation_labels = np.random.randint(0, 10, size=1000)
        validation_data = np.vstack([validation_ids, validation_labels]).T
        validation_path = os.path.join(test_dir, "validation.npy")
        np.save(validation_path, validation_data)

        test_ids = np.arange(2000, 3000)
        test_labels = np.random.randint(0, 10, size=1000)
        test_data = np.vstack([test_ids, test_labels]).T
        test_path = os.path.join(test_dir, "test.npy")
        np.save(test_path, test_data)

        yaml_content = f"""
            train_sets:
              - - type_name: paper
                  format: numpy
                  in_memory: true
                  path: {train_path}
              - - type_name: author
                  format: numpy
                  path: {train_path}
            validation_sets:
              - - type_name: paper
                  format: numpy
                  path: {validation_path}
              - - type_name: author
                  format: numpy
                  path: {validation_path}
            test_sets:
              - - type_name: paper
                  format: numpy
                  in_memory: false
                  path: {test_path}
              - - type_name: author
                  format: numpy
                  path: {test_path}
        """
        yaml_file = os.path.join(test_dir, "test.yaml")
        with open(yaml_file, "w") as f:
            f.write(yaml_content)
        dataset = gb.OnDiskDataset(yaml_file)

        # Verify train set.
        train_sets = dataset.train_sets()
        assert len(train_sets) == 2
        for train_set in train_sets:
            assert len(train_set) == 1000
            assert isinstance(train_set, gb.ItemSetDict)
            for i, item in enumerate(train_set):
                assert isinstance(item, dict)
                assert len(item) == 1
                key = list(item.keys())[0]
                assert key in ["paper", "author"]
                id, label = item[key]
                assert id == train_ids[i]
                assert label == train_labels[i]
        train_sets = None

        # Verify validation set.
        validation_sets = dataset.validation_sets()
        assert len(validation_sets) == 2
        for validation_set in validation_sets:
            assert len(validation_set) == 1000
            assert isinstance(validation_set, gb.ItemSetDict)
            for i, item in enumerate(validation_set):
                assert isinstance(item, dict)
                assert len(item) == 1
                key = list(item.keys())[0]
                assert key in ["paper", "author"]
                id, label = item[key]
                assert id == validation_ids[i]
                assert label == validation_labels[i]
        validation_sets = None

        # Verify test set.
        test_sets = dataset.test_sets()
        assert len(test_sets) == 2
        for test_set in test_sets:
            assert len(test_set) == 1000
            assert isinstance(test_set, gb.ItemSetDict)
            for i, item in enumerate(test_set):
                assert isinstance(item, dict)
                assert len(item) == 1
                key = list(item.keys())[0]
                assert key in ["paper", "author"]
                id, label = item[key]
                assert id == test_ids[i]
                assert label == test_labels[i]
        test_sets = None
        dataset = None
def test_OnDiskDataset_TVTSet_ItemSetDict_node_pair_label():
    """Test TVTSet which returns ItemSetDict with node pairs and labels."""
    with tempfile.TemporaryDirectory() as test_dir:
        train_pairs = (np.arange(1000), np.arange(1000, 2000))
        train_labels = np.random.randint(0, 10, size=1000)
        train_data = np.vstack([train_pairs, train_labels]).T
        train_path = os.path.join(test_dir, "train.npy")
        np.save(train_path, train_data)

        validation_pairs = (np.arange(1000, 2000), np.arange(2000, 3000))
        validation_labels = np.random.randint(0, 10, size=1000)
        validation_data = np.vstack([validation_pairs, validation_labels]).T
        validation_path = os.path.join(test_dir, "validation.npy")
        np.save(validation_path, validation_data)

        test_pairs = (np.arange(2000, 3000), np.arange(3000, 4000))
        test_labels = np.random.randint(0, 10, size=1000)
        test_data = np.vstack([test_pairs, test_labels]).T
        test_path = os.path.join(test_dir, "test.npy")
        np.save(test_path, test_data)

        yaml_content = f"""
            train_sets:
              - - type_name: paper
                  format: numpy
                  in_memory: true
                  path: {train_path}
              - - type_name: author
                  format: numpy
                  path: {train_path}
            validation_sets:
              - - type_name: paper
                  format: numpy
                  path: {validation_path}
              - - type_name: author
                  format: numpy
                  path: {validation_path}
            test_sets:
              - - type_name: paper
                  format: numpy
                  in_memory: false
                  path: {test_path}
              - - type_name: author
                  format: numpy
                  path: {test_path}
        """
        yaml_file = os.path.join(test_dir, "test.yaml")
        with open(yaml_file, "w") as f:
            f.write(yaml_content)
        dataset = gb.OnDiskDataset(yaml_file)

        # Verify train set.
        train_sets = dataset.train_sets()
        assert len(train_sets) == 2
        for train_set in train_sets:
            assert len(train_set) == 1000
            assert isinstance(train_set, gb.ItemSetDict)
            for i, item in enumerate(train_set):
                assert isinstance(item, dict)
                assert len(item) == 1
                key = list(item.keys())[0]
                assert key in ["paper", "author"]
                src, dst, label = item[key]
                assert src == train_pairs[0][i]
                assert dst == train_pairs[1][i]
                assert label == train_labels[i]
        train_sets = None

        # Verify validation set.
        validation_sets = dataset.validation_sets()
        assert len(validation_sets) == 2
        for validation_set in validation_sets:
            assert len(validation_set) == 1000
            assert isinstance(validation_set, gb.ItemSetDict)
            for i, item in enumerate(validation_set):
                assert isinstance(item, dict)
                assert len(item) == 1
                key = list(item.keys())[0]
                assert key in ["paper", "author"]
                src, dst, label = item[key]
                assert src == validation_pairs[0][i]
                assert dst == validation_pairs[1][i]
                assert label == validation_labels[i]
        validation_sets = None

        # Verify test set.
        test_sets = dataset.test_sets()
        assert len(test_sets) == 2
        for test_set in test_sets:
            assert len(test_set) == 1000
            assert isinstance(test_set, gb.ItemSetDict)
            for i, item in enumerate(test_set):
                assert isinstance(item, dict)
                assert len(item) == 1
                key = list(item.keys())[0]
                assert key in ["paper", "author"]
                src, dst, label = item[key]
                assert src == test_pairs[0][i]
                assert dst == test_pairs[1][i]
                assert label == test_labels[i]
        test_sets = None
        dataset = None