Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
7815fe8a
Unverified
Commit
7815fe8a
authored
Mar 25, 2024
by
Muhammed Fatih BALIN
Committed by
GitHub
Mar 26, 2024
Browse files
[CUDA] Make sanity check optional for `dgl.create_block`. (#7240)
parent
3c391533
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
30 additions
and
8 deletions
+30
-8
python/dgl/convert.py
python/dgl/convert.py
+14
-3
python/dgl/graphbolt/minibatch.py
python/dgl/graphbolt/minibatch.py
+1
-0
python/dgl/utils/data.py
python/dgl/utils/data.py
+15
-5
No files found.
python/dgl/convert.py
View file @
7815fe8a
...
@@ -387,7 +387,12 @@ def heterograph(data_dict, num_nodes_dict=None, idtype=None, device=None):
...
@@ -387,7 +387,12 @@ def heterograph(data_dict, num_nodes_dict=None, idtype=None, device=None):
def
create_block
(
def
create_block
(
data_dict
,
num_src_nodes
=
None
,
num_dst_nodes
=
None
,
idtype
=
None
,
device
=
None
data_dict
,
num_src_nodes
=
None
,
num_dst_nodes
=
None
,
idtype
=
None
,
device
=
None
,
node_count_check
=
True
,
):
):
"""Create a message flow graph (MFG) as a :class:`DGLBlock` object.
"""Create a message flow graph (MFG) as a :class:`DGLBlock` object.
...
@@ -456,6 +461,9 @@ def create_block(
...
@@ -456,6 +461,9 @@ def create_block(
the :attr:`data` argument. If :attr:`data` is not a tuple of node-tensors, the
the :attr:`data` argument. If :attr:`data` is not a tuple of node-tensors, the
returned graph is on CPU. If the specified :attr:`device` differs from that of the
returned graph is on CPU. If the specified :attr:`device` differs from that of the
provided tensors, it casts the given tensors to the specified device first.
provided tensors, it casts the given tensors to the specified device first.
node_count_check : bool, optional
When num_src_nodes and num_dst_nodes are passed, whether we should perform
sanity checks to ensure they are valid.
Returns
Returns
-------
-------
...
@@ -540,13 +548,16 @@ def create_block(
...
@@ -540,13 +548,16 @@ def create_block(
node_tensor_dict
=
{}
node_tensor_dict
=
{}
for
(
sty
,
ety
,
dty
),
data
in
data_dict
.
items
():
for
(
sty
,
ety
,
dty
),
data
in
data_dict
.
items
():
(
sparse_fmt
,
arrays
),
urange
,
vrange
=
utils
.
graphdata2tensors
(
(
sparse_fmt
,
arrays
),
urange
,
vrange
=
utils
.
graphdata2tensors
(
data
,
idtype
,
bipartite
=
True
data
,
idtype
,
bipartite
=
True
,
infer_node_count
=
need_infer
or
node_count_check
,
)
)
node_tensor_dict
[(
sty
,
ety
,
dty
)]
=
(
sparse_fmt
,
arrays
)
node_tensor_dict
[(
sty
,
ety
,
dty
)]
=
(
sparse_fmt
,
arrays
)
if
need_infer
:
if
need_infer
:
num_src_nodes
[
sty
]
=
max
(
num_src_nodes
[
sty
],
urange
)
num_src_nodes
[
sty
]
=
max
(
num_src_nodes
[
sty
],
urange
)
num_dst_nodes
[
dty
]
=
max
(
num_dst_nodes
[
dty
],
vrange
)
num_dst_nodes
[
dty
]
=
max
(
num_dst_nodes
[
dty
],
vrange
)
el
se
:
# sanity check
el
if
node_count_check
:
# sanity check
if
num_src_nodes
[
sty
]
<
urange
:
if
num_src_nodes
[
sty
]
<
urange
:
raise
DGLError
(
raise
DGLError
(
"The given number of nodes of source node type {} must be larger"
"The given number of nodes of source node type {} must be larger"
...
...
python/dgl/graphbolt/minibatch.py
View file @
7815fe8a
...
@@ -303,6 +303,7 @@ class MiniBatch:
...
@@ -303,6 +303,7 @@ class MiniBatch:
sampled_csc
,
sampled_csc
,
num_src_nodes
=
num_src_nodes
,
num_src_nodes
=
num_src_nodes
,
num_dst_nodes
=
num_dst_nodes
,
num_dst_nodes
=
num_dst_nodes
,
node_count_check
=
False
,
)
)
)
)
...
...
python/dgl/utils/data.py
View file @
7815fe8a
...
@@ -116,7 +116,9 @@ def networkx2tensor(nx_graph, idtype, edge_id_attr_name=None):
...
@@ -116,7 +116,9 @@ def networkx2tensor(nx_graph, idtype, edge_id_attr_name=None):
SparseAdjTuple
=
namedtuple
(
"SparseAdjTuple"
,
[
"format"
,
"arrays"
])
SparseAdjTuple
=
namedtuple
(
"SparseAdjTuple"
,
[
"format"
,
"arrays"
])
def
graphdata2tensors
(
data
,
idtype
=
None
,
bipartite
=
False
,
**
kwargs
):
def
graphdata2tensors
(
data
,
idtype
=
None
,
bipartite
=
False
,
infer_node_count
=
True
,
**
kwargs
):
"""Function to convert various types of data to edge tensors and infer
"""Function to convert various types of data to edge tensors and infer
the number of nodes.
the number of nodes.
...
@@ -137,6 +139,9 @@ def graphdata2tensors(data, idtype=None, bipartite=False, **kwargs):
...
@@ -137,6 +139,9 @@ def graphdata2tensors(data, idtype=None, bipartite=False, **kwargs):
bipartite : bool, optional
bipartite : bool, optional
Whether infer number of nodes of a bipartite graph --
Whether infer number of nodes of a bipartite graph --
num_src and num_dst can be different.
num_src and num_dst can be different.
infer_node_count : bool, optional
Whether infer number of nodes at all. If False, num_src and num_dst
are returned as None.
kwargs
kwargs
- edge_id_attr_name : The name (str) of the edge attribute that stores the edge
- edge_id_attr_name : The name (str) of the edge attribute that stores the edge
...
@@ -186,23 +191,28 @@ def graphdata2tensors(data, idtype=None, bipartite=False, **kwargs):
...
@@ -186,23 +191,28 @@ def graphdata2tensors(data, idtype=None, bipartite=False, **kwargs):
data
.
format
,
tuple
(
F
.
tensor
(
a
)
for
a
in
data
.
arrays
)
data
.
format
,
tuple
(
F
.
tensor
(
a
)
for
a
in
data
.
arrays
)
)
)
num_src
,
num_dst
=
None
,
None
if
isinstance
(
data
,
SparseAdjTuple
):
if
isinstance
(
data
,
SparseAdjTuple
):
if
idtype
is
not
None
:
if
idtype
is
not
None
:
data
=
SparseAdjTuple
(
data
=
SparseAdjTuple
(
data
.
format
,
tuple
(
F
.
astype
(
a
,
idtype
)
for
a
in
data
.
arrays
)
data
.
format
,
tuple
(
F
.
astype
(
a
,
idtype
)
for
a
in
data
.
arrays
)
)
)
num_src
,
num_dst
=
infer_num_nodes
(
data
,
bipartite
=
bipartite
)
if
infer_node_count
:
num_src
,
num_dst
=
infer_num_nodes
(
data
,
bipartite
=
bipartite
)
elif
isinstance
(
data
,
list
):
elif
isinstance
(
data
,
list
):
src
,
dst
=
elist2tensor
(
data
,
idtype
)
src
,
dst
=
elist2tensor
(
data
,
idtype
)
data
=
SparseAdjTuple
(
"coo"
,
(
src
,
dst
))
data
=
SparseAdjTuple
(
"coo"
,
(
src
,
dst
))
num_src
,
num_dst
=
infer_num_nodes
(
data
,
bipartite
=
bipartite
)
if
infer_node_count
:
num_src
,
num_dst
=
infer_num_nodes
(
data
,
bipartite
=
bipartite
)
elif
isinstance
(
data
,
sp
.
sparse
.
spmatrix
):
elif
isinstance
(
data
,
sp
.
sparse
.
spmatrix
):
# We can get scipy matrix's number of rows and columns easily.
# We can get scipy matrix's number of rows and columns easily.
num_src
,
num_dst
=
infer_num_nodes
(
data
,
bipartite
=
bipartite
)
if
infer_node_count
:
num_src
,
num_dst
=
infer_num_nodes
(
data
,
bipartite
=
bipartite
)
data
=
scipy2tensor
(
data
,
idtype
)
data
=
scipy2tensor
(
data
,
idtype
)
elif
isinstance
(
data
,
nx
.
Graph
):
elif
isinstance
(
data
,
nx
.
Graph
):
# We can get networkx graph's number of sources and destinations easily.
# We can get networkx graph's number of sources and destinations easily.
num_src
,
num_dst
=
infer_num_nodes
(
data
,
bipartite
=
bipartite
)
if
infer_node_count
:
num_src
,
num_dst
=
infer_num_nodes
(
data
,
bipartite
=
bipartite
)
edge_id_attr_name
=
kwargs
.
get
(
"edge_id_attr_name"
,
None
)
edge_id_attr_name
=
kwargs
.
get
(
"edge_id_attr_name"
,
None
)
if
bipartite
:
if
bipartite
:
top_map
=
kwargs
.
get
(
"top_map"
)
top_map
=
kwargs
.
get
(
"top_map"
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment