OpenDAS / dgl
Commit 00edb949 (unverified)

[Performance] Accelerate batching (#2363)

* speed up batching
* more fix
* lint
* fix

Authored Nov 24, 2020 by Quan (Andy) Gan; committed via GitHub on Nov 24, 2020.
Parent: 58775ada
Changes: 2 files changed, 56 additions and 56 deletions.

    python/dgl/batch.py          +25  -15
    python/dgl/utils/checks.py   +31  -41
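For context, dgl.batch merges a list of graphs into one graph whose structure is the disjoint union of the inputs and whose features are the per-type concatenation of the inputs' features; this commit speeds up the feature-gathering half of that. A minimal usage sketch, assuming a PyTorch backend (the feature name 'h' and the sizes are illustrative):

    import dgl
    import torch

    # Two small graphs carrying a node feature 'h' (name and sizes illustrative).
    g1 = dgl.graph(([0, 1], [1, 2]), num_nodes=3)
    g1.ndata['h'] = torch.zeros(3, 4)
    g2 = dgl.graph(([0], [1]), num_nodes=2)
    g2.ndata['h'] = torch.ones(2, 4)

    bg = dgl.batch([g1, g2])      # structure: disjoint union of g1 and g2
    print(bg.num_nodes())         # 5
    print(bg.ndata['h'].shape)    # torch.Size([5, 4]): concatenated along dim 0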
python/dgl/batch.py

...
@@ -168,7 +168,9 @@ def batch(graphs, ndata=ALL, edata=ALL, *, node_attrs=None, edge_attrs=None):
         raise DGLError("Batching a block is not supported.")
     relations = list(sorted(graphs[0].canonical_etypes))
+    relation_ids = [graphs[0].get_etype_id(r) for r in relations]
     ntypes = list(sorted(graphs[0].ntypes))
+    ntype_ids = [graphs[0].get_ntype_id(n) for n in ntypes]
     etypes = [etype for _, etype, _ in relations]
     gidx = disjoint_union(graphs[0]._graph.metagraph, [g._graph for g in graphs])
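The two added lines are the first half of the optimization: the name-to-ID translations (get_etype_id, get_ntype_id) are computed once up front, so the per-type loops in the next hunk can work with integer IDs instead of re-resolving type names for every graph. This is ordinary loop-invariant hoisting; a self-contained sketch of the pattern (TinyGraph is a hypothetical stand-in, not DGL code):

    class TinyGraph:
        """Hypothetical stand-in for a heterograph's name -> id table."""
        def __init__(self, ntypes):
            self._ids = {n: i for i, n in enumerate(ntypes)}

        def get_ntype_id(self, ntype):
            return self._ids[ntype]

    graphs = [TinyGraph(['item', 'user']) for _ in range(1000)]
    ntypes = ['item', 'user']

    # Before: the name -> id lookup happens for every (type, graph) pair.
    for ntype in ntypes:
        ids = [g.get_ntype_id(ntype) for g in graphs]

    # After: all graphs in a batch share one type vocabulary, so the lookup
    # is loop-invariant and can be hoisted out, as in the hunk above.
    ntype_ids = [graphs[0].get_ntype_id(n) for n in ntypes]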
...
@@ -188,27 +190,35 @@ def batch(graphs, ndata=ALL, edata=ALL, *, node_attrs=None, edge_attrs=None):
     # Batch node feature
     if ndata is not None:
-        for ntype in ntypes:
-            feat_dicts = [g.nodes[ntype].data for g in graphs if g.number_of_nodes(ntype) > 0]
-            ret_feat = _batch_feat_dicts(feat_dicts, ndata, 'nodes["{}"].data'.format(ntype))
+        for ntype_id, ntype in zip(ntype_ids, ntypes):
+            frames = [
+                g._node_frames[ntype_id] for g in graphs
+                if g._graph.number_of_nodes(ntype_id) > 0]
+            # TODO: do we require graphs with no nodes/edges to have the same schema? Currently
+            #   we allow empty graphs to have no features during batching.
+            ret_feat = _batch_feat_dicts(frames, ndata, 'nodes["{}"].data'.format(ntype))
             retg.nodes[ntype].data.update(ret_feat)
     # Batch edge feature
     if edata is not None:
-        for etype in relations:
-            feat_dicts = [g.edges[etype].data for g in graphs if g.number_of_edges(etype) > 0]
-            ret_feat = _batch_feat_dicts(feat_dicts, edata, 'edges[{}].data'.format(etype))
+        for etype_id, etype in zip(relation_ids, relations):
+            frames = [
+                g._edge_frames[etype_id] for g in graphs
+                if g._graph.number_of_edges(etype_id) > 0]
+            # TODO: do we require graphs with no nodes/edges to have the same schema? Currently
+            #   we allow empty graphs to have no features during batching.
+            ret_feat = _batch_feat_dicts(frames, edata, 'edges[{}].data'.format(etype))
             retg.edges[etype].data.update(ret_feat)
     return retg

-def _batch_feat_dicts(feat_dicts, keys, feat_dict_name):
+def _batch_feat_dicts(frames, keys, feat_dict_name):
     """Internal function to batch feature dictionaries.

     Parameters
     ----------
-    feat_dicts : list[dict[str, Tensor]]
-        Feature dictionary list.
+    frames : list[Frame]
+        List of frames
     keys : list[str]
         Feature keys. Can be '__ALL__', meaning batching all features.
     feat_dict_name : str
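The loop bodies are the second half of the change: instead of going through the public g.nodes[ntype].data / g.edges[etype].data views, which resolve the type name and construct a view object on every access, the patched code indexes the private frame lists directly with the precomputed integer IDs (both _node_frames and get_ntype_id appear in the diff itself). A hedged micro-benchmark sketch of the difference; it prints timings rather than asserting any particular speedup:

    import timeit
    import dgl
    import torch

    g = dgl.heterograph({('user', 'follows', 'user'): ([0, 1], [1, 2])})
    g.nodes['user'].data['h'] = torch.zeros(3, 4)

    # Public view: resolves the type name and builds a view per call.
    via_view = lambda: g.nodes['user'].data

    # Direct frame access by precomputed id, as in the patched loops.
    ntype_id = g.get_ntype_id('user')
    via_frame = lambda: g._node_frames[ntype_id]

    print('view :', timeit.timeit(via_view, number=100000))
    print('frame:', timeit.timeit(via_frame, number=100000))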
...
@@ -219,17 +229,17 @@ def _batch_feat_dicts(feat_dicts, keys, feat_dict_name):
     dict[str, Tensor]
         New feature dict.
     """
-    if len(feat_dicts) == 0:
+    if len(frames) == 0:
         return {}
+    schemas = [frame.schemes for frame in frames]
     # sanity checks
     if is_all(keys):
-        utils.check_all_same_keys(feat_dicts, feat_dict_name)
-        keys = feat_dicts[0].keys()
+        utils.check_all_same_schema(schemas, feat_dict_name)
+        keys = schemas[0].keys()
     else:
-        utils.check_all_have_keys(feat_dicts, keys, feat_dict_name)
-        utils.check_all_same_schema(feat_dicts, keys, feat_dict_name)
+        utils.check_all_same_schema_for_keys(schemas, keys, feat_dict_name)
     # concat features
-    ret_feat = {k: F.cat([fd[k] for fd in feat_dicts], 0) for k in keys}
+    ret_feat = {k: F.cat([fd[k] for fd in frames], 0) for k in keys}
     return ret_feat

 def unbatch(g, node_split=None, edge_split=None):
...
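Once the schema checks pass, the batching itself is unchanged: one concatenation per feature key along the first axis, as the final dict comprehension shows (F.cat is DGL's backend-agnostic concat). A plain-PyTorch sketch of the same operation, with illustrative dict-based frames:

    import torch

    # Per-graph feature dicts with matching keys and matching per-row schema.
    frames = [
        {'h': torch.zeros(3, 4), 'label': torch.zeros(3, dtype=torch.long)},
        {'h': torch.ones(2, 4),  'label': torch.ones(2, dtype=torch.long)},
    ]
    keys = frames[0].keys()

    # Mirrors: ret_feat = {k: F.cat([fd[k] for fd in frames], 0) for k in keys}
    ret_feat = {k: torch.cat([fd[k] for fd in frames], dim=0) for k in keys}
    assert ret_feat['h'].shape == (5, 4)
    assert ret_feat['label'].tolist() == [0, 0, 0, 1, 1]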
python/dgl/utils/checks.py

...
@@ -119,51 +119,41 @@ def check_all_same_device(glist, name):
             raise DGLError('Expect {}[{}] to be on device {}, but got {}.'.format(
                 name, i, device, g.device))

-def check_all_same_keys(dict_list, name):
-    """Check all the dictionaries have the same set of keys."""
-    if len(dict_list) == 0:
+def check_all_same_schema(schemas, name):
+    """Check the list of schemas are the same."""
+    if len(schemas) == 0:
         return
-    keys = dict_list[0].keys()
-    for dct in dict_list:
-        if keys != dct.keys():
-            raise DGLError('Expect all {} to have the same set of keys, but got'
-                           ' {} and {}.'.format(name, keys, dct.keys()))
-
-def check_all_have_keys(dict_list, keys, name):
-    """Check the dictionaries all have the given keys."""
-    if len(dict_list) == 0:
-        return
-    keys = set(keys)
-    for dct in dict_list:
-        if not keys.issubset(dct.keys()):
-            raise DGLError('Expect all {} to include keys {}, but got {}.'.format(
-                name, keys, dct.keys()))
-
-def check_all_same_schema(feat_dict_list, keys, name):
-    """Check the features of the given keys all have the same schema.
-
-    Suggest calling ``check_all_have_keys`` first.
+    for i, schema in enumerate(schemas):
+        if schema != schemas[0]:
+            raise DGLError(
+                'Expect all graphs to have the same schema on {}, '
+                'but graph {} got\n\t{}\nwhich is different from\n\t{}.'.format(
+                    name, i, schema, schemas[0]))

-    Parameters
-    ----------
-    feat_dict_list : list[dict[str, Tensor]]
-        Feature dictionaries.
-    keys : list[str]
-        Keys
-    name : str
-        Name of this feature dict.
-    """
-    if len(feat_dict_list) == 0:
+def check_all_same_schema_for_keys(schemas, keys, name):
+    """Check the list of schemas are the same on the given keys."""
+    if len(schemas) == 0:
         return
-    for fdict in feat_dict_list:
-        for k in keys:
-            t1 = feat_dict_list[0][k]
-            t2 = fdict[k]
-            if F.dtype(t1) != F.dtype(t2) or F.shape(t1)[1:] != F.shape(t2)[1:]:
-                raise DGLError('Expect all features {}["{}"] to have the same data type'
-                               ' and feature size, but got\n\t{} {}\nand\n\t{} {}.'.format(
-                                   name, k, F.dtype(t1), F.shape(t1)[1:],
-                                   F.dtype(t2), F.shape(t2)[1:]))
+    head = None
+    keys = set(keys)
+    for i, schema in enumerate(schemas):
+        if not keys.issubset(schema.keys()):
+            raise DGLError(
+                'Expect all graphs to have keys {} on {}, '
+                'but graph {} got keys {}.'.format(keys, name, i, schema.keys()))
+        if head is None:
+            head = {k: schema[k] for k in keys}
+        else:
+            target = {k: schema[k] for k in keys}
+            if target != head:
+                raise DGLError(
+                    'Expect all graphs to have the same schema for keys {} on {}, '
+                    'but graph {} got\n\t{}\nwhich is different from\n\t{}.'.format(
+                        keys, name, i, target, head))

 def check_valid_idtype(idtype):
     """Check whether the value of the idtype argument is valid (int32/int64)
...
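The rewritten checks operate on frame schemes rather than on feature tensors: a scheme records only a column's per-row shape and dtype, so equality tests compare small metadata dicts instead of calling F.dtype/F.shape in a doubly nested loop over tensors, and the error messages can now name the offending graph index. A simplified model of why scheme comparison suffices (this Scheme namedtuple is a stand-in for DGL's column scheme, not its actual class):

    from collections import namedtuple

    import torch

    # Stand-in for DGL's per-column scheme: one row's shape plus the dtype.
    Scheme = namedtuple('Scheme', ['shape', 'dtype'])

    def infer_scheme(tensor):
        # Depends only on metadata, so the comparison cost is independent
        # of how many rows (nodes/edges) each graph contributes.
        return Scheme(tuple(tensor.shape[1:]), tensor.dtype)

    t1 = torch.zeros(3, 4)
    t2 = torch.ones(100000, 4)
    assert infer_scheme(t1) == infer_scheme(t2)           # same schema, different sizes
    assert infer_scheme(t1) != infer_scheme(t2.double())  # dtype mismatch is caught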