Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
3df6e301
"vscode:/vscode.git/clone" did not exist on "579a41d5231990af604da9f26328052e55996b82"
Unverified
Commit
3df6e301
authored
Apr 02, 2024
by
Mingbang Wang
Committed by
GitHub
Apr 02, 2024
Browse files
[GraphBolt] Update `__len__()` and `__getitem__()` of `ItemSet` (#7253)
parent
8e3d8101
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
54 additions
and
37 deletions
+54
-37
python/dgl/graphbolt/itemset.py
python/dgl/graphbolt/itemset.py
+46
-29
tests/python/pytorch/graphbolt/test_subgraph_sampler.py
tests/python/pytorch/graphbolt/test_subgraph_sampler.py
+8
-8
No files found.
python/dgl/graphbolt/itemset.py
View file @
3df6e301
"""GraphBolt Itemset."""
import
textwrap
from
typing
import
Dict
,
Iterable
,
Iterator
,
Sized
,
Tuple
,
Union
from
typing
import
Dict
,
Iterable
,
Iterator
,
Tuple
,
Union
import
torch
...
...
@@ -119,20 +119,35 @@ class ItemSet:
items
:
Union
[
int
,
torch
.
Tensor
,
Iterable
,
Tuple
[
Iterable
]],
names
:
Union
[
str
,
Tuple
[
str
]]
=
None
,
)
->
None
:
if
isinstance
(
items
,
tuple
)
or
is_scalar
(
items
):
if
is_scalar
(
items
):
self
.
_length
=
int
(
items
)
self
.
_items
=
items
self
.
_num_items
=
1
elif
isinstance
(
items
,
tuple
):
try
:
self
.
_length
=
len
(
items
[
0
])
except
TypeError
:
self
.
_length
=
None
if
self
.
_length
is
not
None
:
if
any
(
self
.
_length
!=
len
(
item
)
for
item
in
items
):
raise
ValueError
(
"Size mismatch between items."
)
self
.
_items
=
items
self
.
_num_items
=
len
(
items
)
else
:
try
:
self
.
_length
=
len
(
items
)
except
TypeError
:
self
.
_length
=
None
self
.
_items
=
(
items
,)
self
.
_num_items
=
1
if
names
is
not
None
:
num_items
=
(
len
(
self
.
_items
)
if
isinstance
(
self
.
_items
,
tuple
)
else
1
)
if
isinstance
(
names
,
tuple
):
self
.
_names
=
names
else
:
self
.
_names
=
(
names
,)
assert
num_items
==
len
(
self
.
_names
),
(
f
"Number of items (
{
num_items
}
) and "
assert
self
.
_
num_items
==
len
(
self
.
_names
),
(
f
"Number of items (
{
self
.
_
num_items
}
) and "
f
"names (
{
len
(
self
.
_names
)
}
) must match."
)
else
:
...
...
@@ -144,12 +159,11 @@ class ItemSet:
yield
from
torch
.
arange
(
self
.
_items
,
dtype
=
dtype
)
return
if
len
(
self
.
_items
)
==
1
:
if
self
.
_
num_
items
==
1
:
yield
from
self
.
_items
[
0
]
return
if
isinstance
(
self
.
_items
[
0
],
Sized
):
items_len
=
len
(
self
.
_items
[
0
])
if
self
.
_length
is
not
None
:
# Use for-loop to iterate over the items. It can avoid a long
# waiting time when the items are torch tensors. Since torch
# tensors need to call self.unbind(0) to slice themselves.
...
...
@@ -157,7 +171,7 @@ class ItemSet:
# wait times during the loading phase, and the impact on overall
# performance during the training/testing stage is minimal.
# For more details, see https://github.com/dmlc/dgl/pull/6293.
for
i
in
range
(
items
_len
):
for
i
in
range
(
self
.
_len
gth
):
yield
tuple
(
item
[
i
]
for
item
in
self
.
_items
)
else
:
# If the items are not Sized, we use zip to iterate over them.
...
...
@@ -165,31 +179,20 @@ class ItemSet:
for
item
in
zip_items
:
yield
tuple
(
item
)
def
__len__
(
self
)
->
int
:
if
is_scalar
(
self
.
_items
):
return
int
(
self
.
_items
)
if
isinstance
(
self
.
_items
[
0
],
Sized
):
return
len
(
self
.
_items
[
0
])
raise
TypeError
(
f
"
{
type
(
self
).
__name__
}
instance doesn't have valid length."
)
def
__getitem__
(
self
,
idx
:
Union
[
int
,
slice
,
Iterable
])
->
Tuple
:
try
:
len
(
self
)
except
TypeError
:
if
self
.
_length
is
None
:
raise
TypeError
(
f
"
{
type
(
self
).
__name__
}
instance doesn't support indexing."
)
if
is_scalar
(
self
.
_items
):
if
isinstance
(
idx
,
slice
):
start
,
stop
,
step
=
idx
.
indices
(
int
(
self
.
_
items
)
)
start
,
stop
,
step
=
idx
.
indices
(
self
.
_
length
)
dtype
=
getattr
(
self
.
_items
,
"dtype"
,
torch
.
int64
)
return
torch
.
arange
(
start
,
stop
,
step
,
dtype
=
dtype
)
if
isinstance
(
idx
,
int
):
if
idx
<
0
:
idx
+=
self
.
_
items
if
idx
<
0
or
idx
>=
self
.
_
items
:
idx
+=
self
.
_
length
if
idx
<
0
or
idx
>=
self
.
_
length
:
raise
IndexError
(
f
"
{
type
(
self
).
__name__
}
index out of range."
)
...
...
@@ -201,7 +204,7 @@ class ItemSet:
raise
TypeError
(
f
"
{
type
(
self
).
__name__
}
indices must be integer or slice."
)
if
len
(
self
.
_items
)
==
1
:
if
self
.
_
num_
items
==
1
:
return
self
.
_items
[
0
][
idx
]
return
tuple
(
item
[
idx
]
for
item
in
self
.
_items
)
...
...
@@ -210,6 +213,18 @@ class ItemSet:
"""Return the names of the items."""
return
self
.
_names
@
property
def
num_items
(
self
)
->
int
:
"""Return the number of the items."""
return
self
.
_num_items
def
__len__
(
self
):
if
self
.
_length
is
None
:
raise
TypeError
(
f
"
{
type
(
self
).
__name__
}
instance doesn't have valid length."
)
return
self
.
_length
def
__repr__
(
self
)
->
str
:
ret
=
(
f
"
{
self
.
__class__
.
__name__
}
(
\n
"
...
...
@@ -364,8 +379,10 @@ class ItemSetDict:
if
stop
<=
self
.
_offsets
[
offset_idx
]:
break
return
data
raise
TypeError
(
f
"
{
type
(
self
).
__name__
}
indices must be int or slice."
)
else
:
raise
TypeError
(
f
"
{
type
(
self
).
__name__
}
indices must be int or slice."
)
@
property
def
names
(
self
)
->
Tuple
[
str
]:
...
...
tests/python/pytorch/graphbolt/test_subgraph_sampler.py
View file @
3df6e301
...
...
@@ -529,7 +529,7 @@ def test_SubgraphSampler_Random_Hetero_Graph_seed_ndoes(sampler_type, replace):
"sampler_type"
,
[
SamplerType
.
Normal
,
SamplerType
.
Layer
,
SamplerType
.
Temporal
],
)
def
test_SubgraphSampler_without_ded
p
ulication_Homo_seed_nodes
(
sampler_type
):
def
test_SubgraphSampler_without_dedu
p
lication_Homo_seed_nodes
(
sampler_type
):
_check_sampler_type
(
sampler_type
)
graph
=
dgl
.
graph
(
([
5
,
0
,
1
,
5
,
6
,
7
,
2
,
2
,
4
],
[
0
,
1
,
2
,
2
,
2
,
2
,
3
,
4
,
4
])
...
...
@@ -643,7 +643,7 @@ def _assert_homo_values(
"sampler_type"
,
[
SamplerType
.
Normal
,
SamplerType
.
Layer
,
SamplerType
.
Temporal
],
)
def
test_SubgraphSampler_without_ded
p
ulication_Hetero_seed_nodes
(
sampler_type
):
def
test_SubgraphSampler_without_dedu
p
lication_Hetero_seed_nodes
(
sampler_type
):
_check_sampler_type
(
sampler_type
)
graph
=
get_hetero_graph
().
to
(
F
.
ctx
())
items
=
torch
.
arange
(
2
)
...
...
@@ -1409,7 +1409,7 @@ def test_SubgraphSampler_Random_Hetero_Graph(sampler_type, replace):
"sampler_type"
,
[
SamplerType
.
Normal
,
SamplerType
.
Layer
,
SamplerType
.
Temporal
],
)
def
test_SubgraphSampler_without_ded
p
ulication_Homo_Node
(
sampler_type
):
def
test_SubgraphSampler_without_dedu
p
lication_Homo_Node
(
sampler_type
):
_check_sampler_type
(
sampler_type
)
graph
=
dgl
.
graph
(
([
5
,
0
,
1
,
5
,
6
,
7
,
2
,
2
,
4
],
[
0
,
1
,
2
,
2
,
2
,
2
,
3
,
4
,
4
])
...
...
@@ -1473,7 +1473,7 @@ def test_SubgraphSampler_without_dedpulication_Homo_Node(sampler_type):
"sampler_type"
,
[
SamplerType
.
Normal
,
SamplerType
.
Layer
,
SamplerType
.
Temporal
],
)
def
test_SubgraphSampler_without_ded
p
ulication_Hetero_Node
(
sampler_type
):
def
test_SubgraphSampler_without_dedu
p
lication_Hetero_Node
(
sampler_type
):
_check_sampler_type
(
sampler_type
)
graph
=
get_hetero_graph
().
to
(
F
.
ctx
())
items
=
torch
.
arange
(
2
)
...
...
@@ -1829,7 +1829,7 @@ def test_SubgraphSampler_Hetero_multifanout_per_layer(sampler_type):
"sampler_type"
,
[
SamplerType
.
Normal
,
SamplerType
.
Layer
,
SamplerType
.
Temporal
],
)
def
test_SubgraphSampler_without_ded
p
ulication_Homo_Link
(
sampler_type
):
def
test_SubgraphSampler_without_dedu
p
lication_Homo_Link
(
sampler_type
):
_check_sampler_type
(
sampler_type
)
graph
=
dgl
.
graph
(
([
5
,
0
,
1
,
5
,
6
,
7
,
2
,
2
,
4
],
[
0
,
1
,
2
,
2
,
2
,
2
,
3
,
4
,
4
])
...
...
@@ -1845,7 +1845,7 @@ def test_SubgraphSampler_without_dedpulication_Homo_Link(sampler_type):
graph
.
edge_attributes
=
{
"timestamp"
:
torch
.
zeros
(
graph
.
indices
.
numel
()).
to
(
F
.
ctx
())
}
items
=
(
items
,
torch
.
randint
(
1
,
10
,
(
3
,)))
items
=
(
items
,
torch
.
randint
(
1
,
10
,
(
2
,)))
names
=
(
names
,
"timestamp"
)
itemset
=
gb
.
ItemSet
(
items
,
names
=
names
)
...
...
@@ -1891,7 +1891,7 @@ def test_SubgraphSampler_without_dedpulication_Homo_Link(sampler_type):
"sampler_type"
,
[
SamplerType
.
Normal
,
SamplerType
.
Layer
,
SamplerType
.
Temporal
],
)
def
test_SubgraphSampler_without_ded
p
ulication_Hetero_Link
(
sampler_type
):
def
test_SubgraphSampler_without_dedu
p
lication_Hetero_Link
(
sampler_type
):
_check_sampler_type
(
sampler_type
)
graph
=
get_hetero_graph
().
to
(
F
.
ctx
())
items
=
torch
.
arange
(
2
).
view
(
1
,
2
)
...
...
@@ -1903,7 +1903,7 @@ def test_SubgraphSampler_without_dedpulication_Hetero_Link(sampler_type):
graph
.
edge_attributes
=
{
"timestamp"
:
torch
.
zeros
(
graph
.
indices
.
numel
()).
to
(
F
.
ctx
())
}
items
=
(
items
,
torch
.
randint
(
1
,
10
,
(
2
,)))
items
=
(
items
,
torch
.
randint
(
1
,
10
,
(
1
,)))
names
=
(
names
,
"timestamp"
)
itemset
=
gb
.
ItemSetDict
({
"n1:e1:n2"
:
gb
.
ItemSet
(
items
,
names
=
names
)})
item_sampler
=
gb
.
ItemSampler
(
itemset
,
batch_size
=
2
).
copy_to
(
F
.
ctx
())
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment