Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
7815fe8a
Unverified
Commit
7815fe8a
authored
Mar 25, 2024
by
Muhammed Fatih BALIN
Committed by
GitHub
Mar 26, 2024
Browse files
[CUDA] Make sanity check optional for `dgl.create_block`. (#7240)
parent
3c391533
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
30 additions
and
8 deletions
+30
-8
python/dgl/convert.py
python/dgl/convert.py
+14
-3
python/dgl/graphbolt/minibatch.py
python/dgl/graphbolt/minibatch.py
+1
-0
python/dgl/utils/data.py
python/dgl/utils/data.py
+15
-5
No files found.
python/dgl/convert.py
View file @
7815fe8a
...
...
@@ -387,7 +387,12 @@ def heterograph(data_dict, num_nodes_dict=None, idtype=None, device=None):
def
create_block
(
data_dict
,
num_src_nodes
=
None
,
num_dst_nodes
=
None
,
idtype
=
None
,
device
=
None
data_dict
,
num_src_nodes
=
None
,
num_dst_nodes
=
None
,
idtype
=
None
,
device
=
None
,
node_count_check
=
True
,
):
"""Create a message flow graph (MFG) as a :class:`DGLBlock` object.
...
...
@@ -456,6 +461,9 @@ def create_block(
the :attr:`data` argument. If :attr:`data` is not a tuple of node-tensors, the
returned graph is on CPU. If the specified :attr:`device` differs from that of the
provided tensors, it casts the given tensors to the specified device first.
node_count_check : bool, optional
When num_src_nodes and num_dst_nodes are passed, whether we should perform
sanity checks to ensure they are valid.
Returns
-------
...
...
@@ -540,13 +548,16 @@ def create_block(
node_tensor_dict
=
{}
for
(
sty
,
ety
,
dty
),
data
in
data_dict
.
items
():
(
sparse_fmt
,
arrays
),
urange
,
vrange
=
utils
.
graphdata2tensors
(
data
,
idtype
,
bipartite
=
True
data
,
idtype
,
bipartite
=
True
,
infer_node_count
=
need_infer
or
node_count_check
,
)
node_tensor_dict
[(
sty
,
ety
,
dty
)]
=
(
sparse_fmt
,
arrays
)
if
need_infer
:
num_src_nodes
[
sty
]
=
max
(
num_src_nodes
[
sty
],
urange
)
num_dst_nodes
[
dty
]
=
max
(
num_dst_nodes
[
dty
],
vrange
)
el
se
:
# sanity check
el
if
node_count_check
:
# sanity check
if
num_src_nodes
[
sty
]
<
urange
:
raise
DGLError
(
"The given number of nodes of source node type {} must be larger"
...
...
python/dgl/graphbolt/minibatch.py
View file @
7815fe8a
...
...
@@ -303,6 +303,7 @@ class MiniBatch:
sampled_csc
,
num_src_nodes
=
num_src_nodes
,
num_dst_nodes
=
num_dst_nodes
,
node_count_check
=
False
,
)
)
...
...
python/dgl/utils/data.py
View file @
7815fe8a
...
...
@@ -116,7 +116,9 @@ def networkx2tensor(nx_graph, idtype, edge_id_attr_name=None):
SparseAdjTuple
=
namedtuple
(
"SparseAdjTuple"
,
[
"format"
,
"arrays"
])
def
graphdata2tensors
(
data
,
idtype
=
None
,
bipartite
=
False
,
**
kwargs
):
def
graphdata2tensors
(
data
,
idtype
=
None
,
bipartite
=
False
,
infer_node_count
=
True
,
**
kwargs
):
"""Function to convert various types of data to edge tensors and infer
the number of nodes.
...
...
@@ -137,6 +139,9 @@ def graphdata2tensors(data, idtype=None, bipartite=False, **kwargs):
bipartite : bool, optional
Whether infer number of nodes of a bipartite graph --
num_src and num_dst can be different.
infer_node_count : bool, optional
Whether infer number of nodes at all. If False, num_src and num_dst
are returned as None.
kwargs
- edge_id_attr_name : The name (str) of the edge attribute that stores the edge
...
...
@@ -186,23 +191,28 @@ def graphdata2tensors(data, idtype=None, bipartite=False, **kwargs):
data
.
format
,
tuple
(
F
.
tensor
(
a
)
for
a
in
data
.
arrays
)
)
num_src
,
num_dst
=
None
,
None
if
isinstance
(
data
,
SparseAdjTuple
):
if
idtype
is
not
None
:
data
=
SparseAdjTuple
(
data
.
format
,
tuple
(
F
.
astype
(
a
,
idtype
)
for
a
in
data
.
arrays
)
)
num_src
,
num_dst
=
infer_num_nodes
(
data
,
bipartite
=
bipartite
)
if
infer_node_count
:
num_src
,
num_dst
=
infer_num_nodes
(
data
,
bipartite
=
bipartite
)
elif
isinstance
(
data
,
list
):
src
,
dst
=
elist2tensor
(
data
,
idtype
)
data
=
SparseAdjTuple
(
"coo"
,
(
src
,
dst
))
num_src
,
num_dst
=
infer_num_nodes
(
data
,
bipartite
=
bipartite
)
if
infer_node_count
:
num_src
,
num_dst
=
infer_num_nodes
(
data
,
bipartite
=
bipartite
)
elif
isinstance
(
data
,
sp
.
sparse
.
spmatrix
):
# We can get scipy matrix's number of rows and columns easily.
num_src
,
num_dst
=
infer_num_nodes
(
data
,
bipartite
=
bipartite
)
if
infer_node_count
:
num_src
,
num_dst
=
infer_num_nodes
(
data
,
bipartite
=
bipartite
)
data
=
scipy2tensor
(
data
,
idtype
)
elif
isinstance
(
data
,
nx
.
Graph
):
# We can get networkx graph's number of sources and destinations easily.
num_src
,
num_dst
=
infer_num_nodes
(
data
,
bipartite
=
bipartite
)
if
infer_node_count
:
num_src
,
num_dst
=
infer_num_nodes
(
data
,
bipartite
=
bipartite
)
edge_id_attr_name
=
kwargs
.
get
(
"edge_id_attr_name"
,
None
)
if
bipartite
:
top_map
=
kwargs
.
get
(
"top_map"
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment