Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
2be55fb5
Commit
2be55fb5
authored
Oct 03, 2018
by
Minjie Wang
Browse files
graph batch and unbatch
parent
7d04c8c9
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
261 additions
and
62 deletions
+261
-62
include/dgl/graph_op.h
include/dgl/graph_op.h
+17
-3
python/dgl/batch.py
python/dgl/batch.py
+40
-29
python/dgl/graph_index.py
python/dgl/graph_index.py
+34
-0
src/graph/graph_apis.cc
src/graph/graph_apis.cc
+36
-0
src/graph/graph_op.cc
src/graph/graph_op.cc
+88
-0
tests/pytorch/test_graph_batch.py
tests/pytorch/test_graph_batch.py
+46
-30
No files found.
include/dgl/graph_op.h
View file @
2be55fb5
...
@@ -31,14 +31,28 @@ class GraphOp {
...
@@ -31,14 +31,28 @@ class GraphOp {
/*!
/*!
* \brief Partition the graph into several subgraphs.
* \brief Partition the graph into several subgraphs.
*
*
* Th
e graph will be partitioned by the node ids. Edges between
partition
s
* Th
is is a reverse operation of DisjointUnion. The graph will be
partition
ed
*
will be ignored
. This requires the given number of partitions to evenly
*
into num graphs
. This requires the given number of partitions to evenly
* divides the number of nodes in the graph.
* divides the number of nodes in the graph.
*
*
* \param graph The graph to be partitioned.
* \param num The number of partitions.
* \param num The number of partitions.
* \return a list of partitioned graphs
* \return a list of partitioned graphs
*/
*/
static
std
::
vector
<
Graph
>
PartitionByNum
(
const
Graph
*
graph
,
size_t
num
);
static
std
::
vector
<
Graph
>
DisjointPartitionByNum
(
const
Graph
*
graph
,
int64_t
num
);
/*!
* \brief Partition the graph into several subgraphs.
*
* This is a reverse operation of DisjointUnion. The graph will be partitioned
* based on the given sizes. This requires the sum of the given sizes is equal
* to the number of nodes in the graph.
*
* \param graph The graph to be partitioned.
* \param sizes The number of partitions.
* \return a list of partitioned graphs
*/
static
std
::
vector
<
Graph
>
DisjointPartitionBySizes
(
const
Graph
*
graph
,
IdArray
sizes
);
};
};
}
// namespace dgl
}
// namespace dgl
...
...
python/dgl/batch.py
View file @
2be55fb5
...
@@ -8,6 +8,7 @@ from .frame import FrameRef
...
@@ -8,6 +8,7 @@ from .frame import FrameRef
from
.graph
import
DGLGraph
from
.graph
import
DGLGraph
from
.
import
graph_index
as
gi
from
.
import
graph_index
as
gi
from
.
import
backend
as
F
from
.
import
backend
as
F
from
.
import
utils
class
BatchedDGLGraph
(
DGLGraph
):
class
BatchedDGLGraph
(
DGLGraph
):
"""The batched DGL graph.
"""The batched DGL graph.
...
@@ -24,7 +25,6 @@ class BatchedDGLGraph(DGLGraph):
...
@@ -24,7 +25,6 @@ class BatchedDGLGraph(DGLGraph):
The edge attributes to also be batched.
The edge attributes to also be batched.
"""
"""
def
__init__
(
self
,
graph_list
,
node_attrs
,
edge_attrs
):
def
__init__
(
self
,
graph_list
,
node_attrs
,
edge_attrs
):
# TODO(minjie): handle the input is again a batched graph.
# create batched graph index
# create batched graph index
batched_index
=
gi
.
disjoint_union
([
g
.
_graph
for
g
in
graph_list
])
batched_index
=
gi
.
disjoint_union
([
g
.
_graph
for
g
in
graph_list
])
# create batched node and edge frames
# create batched node and edge frames
...
@@ -43,9 +43,19 @@ class BatchedDGLGraph(DGLGraph):
...
@@ -43,9 +43,19 @@ class BatchedDGLGraph(DGLGraph):
edge_frame
=
batched_edge_frame
)
edge_frame
=
batched_edge_frame
)
# extra members
# extra members
self
.
_batch_size
=
len
(
graph_list
)
self
.
_batch_size
=
0
self
.
_batch_num_nodes
=
[
gr
.
number_of_nodes
()
for
gr
in
graph_list
]
self
.
_batch_num_nodes
=
[]
self
.
_batch_num_edges
=
[
gr
.
number_of_edges
()
for
gr
in
graph_list
]
self
.
_batch_num_edges
=
[]
for
gr
in
graph_list
:
if
isinstance
(
gr
,
BatchedDGLGraph
):
# handle the input is again a batched graph.
self
.
_batch_size
+=
gr
.
_batch_size
self
.
_batch_num_nodes
+=
gr
.
_batch_num_nodes
self
.
_batch_num_edges
+=
gr
.
_batch_num_edges
else
:
self
.
_batch_size
+=
1
self
.
_batch_num_nodes
.
append
(
gr
.
number_of_nodes
())
self
.
_batch_num_edges
.
append
(
gr
.
number_of_edges
())
@
property
@
property
def
batch_size
(
self
):
def
batch_size
(
self
):
...
@@ -78,10 +88,12 @@ class BatchedDGLGraph(DGLGraph):
...
@@ -78,10 +88,12 @@ class BatchedDGLGraph(DGLGraph):
# new APIs
# new APIs
def
__getitem__
(
self
,
idx
):
def
__getitem__
(
self
,
idx
):
"""Slice the batch and return the batch of graphs specified by the idx."""
"""Slice the batch and return the batch of graphs specified by the idx."""
# TODO
pass
pass
def
__setitem__
(
self
,
idx
,
val
):
def
__setitem__
(
self
,
idx
,
val
):
"""Set the value of the slice. The graph size cannot be changed."""
"""Set the value of the slice. The graph size cannot be changed."""
# TODO
pass
pass
'''
'''
...
@@ -114,37 +126,36 @@ def split(graph_batch, num_or_size_splits):
...
@@ -114,37 +126,36 @@ def split(graph_batch, num_or_size_splits):
# TODO(minjie): could follow torch.split syntax
# TODO(minjie): could follow torch.split syntax
pass
pass
def
unbatch
(
graph
_batch
):
def
unbatch
(
graph
):
"""Unbatch the graph and return a list of subgraphs.
"""Unbatch the graph and return a list of subgraphs.
Parameters
Parameters
----------
----------
graph
_b
atch
:
DGLGraph
graph
: B
atch
ed
DGLGraph
The batched graph.
The batched graph.
"""
"""
assert
False
,
"disabled for now"
assert
isinstance
(
graph
,
BatchedDGLGraph
)
graph_list
=
graph_batch
.
graph_list
bsize
=
graph
.
batch_size
num_graphs
=
len
(
graph_list
)
bn
=
graph
.
batch_num_nodes
# split and set node attrs
be
=
graph
.
batch_num_edges
attrs
=
[{}
for
_
in
range
(
num_graphs
)]
# node attr dict for each graph
pttns
=
gi
.
disjoint_partition
(
graph
.
_graph
,
utils
.
toindex
(
bn
))
for
key
in
graph_batch
.
node_attr_schemes
():
# split the frames
vals
=
F
.
unpack
(
graph_batch
.
pop_n_repr
(
key
),
graph_batch
.
num_nodes
)
node_frames
=
[
FrameRef
()
for
i
in
range
(
bsize
)]
for
attr
,
val
in
zip
(
attrs
,
vals
):
edge_frames
=
[
FrameRef
()
for
i
in
range
(
bsize
)]
attr
[
key
]
=
val
for
attr
,
col
in
graph
.
_node_frame
.
items
():
for
attr
,
g
in
zip
(
attrs
,
graph_list
):
# TODO: device context
g
.
set_n_repr
(
attr
)
col_splits
=
F
.
unpack
(
col
,
bn
)
for
i
in
range
(
bsize
):
# split and set edge attrs
node_frames
[
i
][
attr
]
=
col_splits
[
i
]
attrs
=
[{}
for
_
in
range
(
num_graphs
)]
# edge attr dict for each graph
for
attr
,
col
in
graph
.
_edge_frame
.
items
():
for
key
in
graph_batch
.
edge_attr_schemes
():
# TODO: device context
vals
=
F
.
unpack
(
graph_batch
.
pop_e_repr
(
key
),
graph_batch
.
num_edges
)
col_splits
=
F
.
unpack
(
col
,
be
)
for
attr
,
val
in
zip
(
attrs
,
vals
):
for
i
in
range
(
bsize
):
attr
[
key
]
=
val
edge_frames
[
i
][
attr
]
=
col_splits
[
i
]
for
attr
,
g
in
zip
(
attrs
,
graph_list
):
return
[
DGLGraph
(
graph_data
=
pttns
[
i
],
g
.
set_e_repr
(
attr
)
node_frame
=
node_frames
[
i
],
edge_frame
=
edge_frames
[
i
])
for
i
in
range
(
bsize
)]
return
graph_list
def
batch
(
graph_list
,
node_attrs
=
ALL
,
edge_attrs
=
ALL
):
def
batch
(
graph_list
,
node_attrs
=
ALL
,
edge_attrs
=
ALL
):
"""Batch a list of DGLGraphs into one single graph.
"""Batch a list of DGLGraphs into one single graph.
...
...
python/dgl/graph_index.py
View file @
2be55fb5
...
@@ -483,6 +483,40 @@ def disjoint_union(graphs):
...
@@ -483,6 +483,40 @@ def disjoint_union(graphs):
handle
=
_CAPI_DGLDisjointUnion
(
inputs
,
len
(
graphs
))
handle
=
_CAPI_DGLDisjointUnion
(
inputs
,
len
(
graphs
))
return
GraphIndex
(
handle
)
return
GraphIndex
(
handle
)
def
disjoint_partition
(
graph
,
num_or_size_splits
):
"""Partition the graph disjointly.
This is a reverse operation of DisjointUnion. The graph will be partitioned
into num graphs. This requires the given number of partitions to evenly
divides the number of nodes in the graph. If the a size list is given,
the sum of the given sizes is equal.
Parameters
----------
graph : GraphIndex
The graph to be partitioned
num_or_size_splits : int or utils.Index
The partition number of size splits
Returns
-------
list of GraphIndex
The partitioned graphs
"""
if
isinstance
(
num_or_size_splits
,
utils
.
Index
):
rst
=
_CAPI_DGLDisjointPartitionBySizes
(
graph
.
_handle
,
num_or_size_splits
.
todgltensor
())
else
:
rst
=
_CAPI_DGLDisjointPartitionByNum
(
graph
.
_handle
,
int
(
num_or_size_splits
))
graphs
=
[]
for
val
in
rst
.
asnumpy
():
handle
=
ctypes
.
cast
(
int
(
val
),
ctypes
.
c_void_p
)
graphs
.
append
(
GraphIndex
(
handle
))
return
graphs
def
create_graph_index
(
graph_data
=
None
):
def
create_graph_index
(
graph_data
=
None
):
"""Create a graph index object.
"""Create a graph index object.
...
...
src/graph/graph_apis.cc
View file @
2be55fb5
...
@@ -7,6 +7,7 @@ using tvm::runtime::TVMArgs;
...
@@ -7,6 +7,7 @@ using tvm::runtime::TVMArgs;
using
tvm
::
runtime
::
TVMArgValue
;
using
tvm
::
runtime
::
TVMArgValue
;
using
tvm
::
runtime
::
TVMRetValue
;
using
tvm
::
runtime
::
TVMRetValue
;
using
tvm
::
runtime
::
PackedFunc
;
using
tvm
::
runtime
::
PackedFunc
;
using
tvm
::
runtime
::
NDArray
;
namespace
dgl
{
namespace
dgl
{
...
@@ -289,4 +290,39 @@ TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLDisjointUnion")
...
@@ -289,4 +290,39 @@ TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLDisjointUnion")
*
rv
=
ghandle
;
*
rv
=
ghandle
;
});
});
TVM_REGISTER_GLOBAL
(
"graph_index._CAPI_DGLDisjointPartitionByNum"
)
.
set_body
([]
(
TVMArgs
args
,
TVMRetValue
*
rv
)
{
GraphHandle
ghandle
=
args
[
0
];
const
Graph
*
gptr
=
static_cast
<
Graph
*>
(
ghandle
);
int64_t
num
=
args
[
1
];
std
::
vector
<
Graph
>&&
rst
=
GraphOp
::
DisjointPartitionByNum
(
gptr
,
num
);
// return the pointer array as an integer array
const
int64_t
len
=
rst
.
size
();
NDArray
ptr_array
=
NDArray
::
Empty
({
len
},
DLDataType
{
kDLInt
,
64
,
1
},
DLContext
{
kDLCPU
,
0
});
int64_t
*
ptr_array_data
=
static_cast
<
int64_t
*>
(
ptr_array
->
data
);
for
(
size_t
i
=
0
;
i
<
rst
.
size
();
++
i
)
{
Graph
*
ptr
=
new
Graph
();
*
ptr
=
std
::
move
(
rst
[
i
]);
ptr_array_data
[
i
]
=
reinterpret_cast
<
std
::
intptr_t
>
(
ptr
);
}
*
rv
=
ptr_array
;
});
TVM_REGISTER_GLOBAL
(
"graph_index._CAPI_DGLDisjointPartitionBySizes"
)
.
set_body
([]
(
TVMArgs
args
,
TVMRetValue
*
rv
)
{
GraphHandle
ghandle
=
args
[
0
];
const
Graph
*
gptr
=
static_cast
<
Graph
*>
(
ghandle
);
const
IdArray
sizes
=
IdArray
::
FromDLPack
(
CreateTmpDLManagedTensor
(
args
[
1
]));
std
::
vector
<
Graph
>&&
rst
=
GraphOp
::
DisjointPartitionBySizes
(
gptr
,
sizes
);
// return the pointer array as an integer array
const
int64_t
len
=
rst
.
size
();
NDArray
ptr_array
=
NDArray
::
Empty
({
len
},
DLDataType
{
kDLInt
,
64
,
1
},
DLContext
{
kDLCPU
,
0
});
int64_t
*
ptr_array_data
=
static_cast
<
int64_t
*>
(
ptr_array
->
data
);
for
(
size_t
i
=
0
;
i
<
rst
.
size
();
++
i
)
{
Graph
*
ptr
=
new
Graph
();
*
ptr
=
std
::
move
(
rst
[
i
]);
ptr_array_data
[
i
]
=
reinterpret_cast
<
std
::
intptr_t
>
(
ptr
);
}
*
rv
=
ptr_array
;
});
}
// namespace dgl
}
// namespace dgl
src/graph/graph_op.cc
View file @
2be55fb5
// Graph operation implementation
// Graph operation implementation
#include <dgl/graph_op.h>
#include <dgl/graph_op.h>
#include <algorithm>
namespace
dgl
{
namespace
dgl
{
...
@@ -16,4 +17,91 @@ Graph GraphOp::DisjointUnion(std::vector<const Graph*> graphs) {
...
@@ -16,4 +17,91 @@ Graph GraphOp::DisjointUnion(std::vector<const Graph*> graphs) {
return
rst
;
return
rst
;
}
}
std
::
vector
<
Graph
>
GraphOp
::
DisjointPartitionByNum
(
const
Graph
*
graph
,
int64_t
num
)
{
CHECK
(
num
!=
0
&&
graph
->
NumVertices
()
%
num
==
0
)
<<
"Number of partitions must evenly divide the number of nodes."
;
IdArray
sizes
=
IdArray
::
Empty
({
num
},
DLDataType
{
kDLInt
,
64
,
1
},
DLContext
{
kDLCPU
,
0
});
int64_t
*
sizes_data
=
static_cast
<
int64_t
*>
(
sizes
->
data
);
std
::
fill
(
sizes_data
,
sizes_data
+
num
,
graph
->
NumVertices
()
/
num
);
return
DisjointPartitionBySizes
(
graph
,
sizes
);
}
std
::
vector
<
Graph
>
GraphOp
::
DisjointPartitionBySizes
(
const
Graph
*
graph
,
IdArray
sizes
)
{
const
int64_t
len
=
sizes
->
shape
[
0
];
const
int64_t
*
sizes_data
=
static_cast
<
int64_t
*>
(
sizes
->
data
);
std
::
vector
<
int64_t
>
cumsum
;
cumsum
.
push_back
(
0
);
for
(
int64_t
i
=
0
;
i
<
len
;
++
i
)
{
cumsum
.
push_back
(
cumsum
[
i
]
+
sizes_data
[
i
]);
}
CHECK_EQ
(
cumsum
[
len
],
graph
->
NumVertices
())
<<
"Sum of the given sizes must equal to the number of nodes."
;
dgl_id_t
node_offset
=
0
,
edge_offset
=
0
;
std
::
vector
<
Graph
>
rst
(
len
);
for
(
int64_t
i
=
0
;
i
<
len
;
++
i
)
{
// copy adj
rst
[
i
].
adjlist_
.
insert
(
rst
[
i
].
adjlist_
.
end
(),
graph
->
adjlist_
.
begin
()
+
node_offset
,
graph
->
adjlist_
.
begin
()
+
node_offset
+
sizes_data
[
i
]);
rst
[
i
].
reverse_adjlist_
.
insert
(
rst
[
i
].
reverse_adjlist_
.
end
(),
graph
->
reverse_adjlist_
.
begin
()
+
node_offset
,
graph
->
reverse_adjlist_
.
begin
()
+
node_offset
+
sizes_data
[
i
]);
// relabel adjs
size_t
num_edges
=
0
;
for
(
auto
&
elist
:
rst
[
i
].
adjlist_
)
{
for
(
size_t
j
=
0
;
j
<
elist
.
succ
.
size
();
++
j
)
{
elist
.
succ
[
j
]
-=
node_offset
;
elist
.
edge_id
[
j
]
-=
edge_offset
;
}
num_edges
+=
elist
.
succ
.
size
();
}
for
(
auto
&
elist
:
rst
[
i
].
reverse_adjlist_
)
{
for
(
size_t
j
=
0
;
j
<
elist
.
succ
.
size
();
++
j
)
{
elist
.
succ
[
j
]
-=
node_offset
;
elist
.
edge_id
[
j
]
-=
edge_offset
;
}
}
// copy edges
rst
[
i
].
all_edges_src_
.
reserve
(
num_edges
);
rst
[
i
].
all_edges_dst_
.
reserve
(
num_edges
);
rst
[
i
].
num_edges_
=
num_edges
;
for
(
size_t
j
=
edge_offset
;
j
<
edge_offset
+
num_edges
;
++
j
)
{
rst
[
i
].
all_edges_src_
.
push_back
(
graph
->
all_edges_src_
[
j
]
-
node_offset
);
rst
[
i
].
all_edges_dst_
.
push_back
(
graph
->
all_edges_dst_
[
j
]
-
node_offset
);
}
// update offset
CHECK_EQ
(
rst
[
i
].
NumVertices
(),
sizes_data
[
i
]);
CHECK_EQ
(
rst
[
i
].
NumEdges
(),
num_edges
);
node_offset
+=
sizes_data
[
i
];
edge_offset
+=
num_edges
;
}
/*for (int64_t i = 0; i < len; ++i) {
rst[i].AddVertices(sizes_data[i]);
}
for (dgl_id_t eid = 0; eid < graph->num_edges_; ++eid) {
const dgl_id_t src = graph->all_edges_src_[eid];
const dgl_id_t dst = graph->all_edges_dst_[eid];
size_t src_select = 0, dst_select = 0;
for (size_t i = 1; i < cumsum.size(); ++i) { // TODO: replace with binary search
if (cumsum[i] > src) {
src_select = i;
break;
}
}
for (size_t i = 1; i < cumsum.size(); ++i) { // TODO: replace with binary search
if (cumsum[i] > dst) {
dst_select = i;
break;
}
}
if (src_select != dst_select) {
// the edge is ignored if across two partitions
continue;
}
const int64_t offset = cumsum[src_select - 1];
rst[src_select - 1].AddEdge(src - offset, dst - offset);
}*/
return
rst
;
}
}
// namespace dgl
}
// namespace dgl
tests/pytorch/test_graph_batch.py
View file @
2be55fb5
import
networkx
as
nx
import
networkx
as
nx
import
dgl
import
dgl
import
torch
import
torch
as
th
import
numpy
as
np
import
numpy
as
np
def
tree1
():
def
tree1
():
...
@@ -13,17 +13,13 @@ def tree1():
...
@@ -13,17 +13,13 @@ def tree1():
Edges are from leaves to root.
Edges are from leaves to root.
"""
"""
g
=
dgl
.
DGLGraph
()
g
=
dgl
.
DGLGraph
()
g
.
add_node
(
0
)
g
.
add_nodes
(
5
)
g
.
add_node
(
1
)
g
.
add_node
(
2
)
g
.
add_node
(
3
)
g
.
add_node
(
4
)
g
.
add_edge
(
3
,
1
)
g
.
add_edge
(
3
,
1
)
g
.
add_edge
(
4
,
1
)
g
.
add_edge
(
4
,
1
)
g
.
add_edge
(
1
,
0
)
g
.
add_edge
(
1
,
0
)
g
.
add_edge
(
2
,
0
)
g
.
add_edge
(
2
,
0
)
g
.
set_n_repr
(
t
orc
h
.
Tensor
([
0
,
1
,
2
,
3
,
4
]))
g
.
set_n_repr
(
th
.
Tensor
([
0
,
1
,
2
,
3
,
4
]))
g
.
set_e_repr
(
t
orc
h
.
randn
(
4
,
10
))
g
.
set_e_repr
(
th
.
randn
(
4
,
10
))
return
g
return
g
def
tree2
():
def
tree2
():
...
@@ -36,17 +32,13 @@ def tree2():
...
@@ -36,17 +32,13 @@ def tree2():
Edges are from leaves to root.
Edges are from leaves to root.
"""
"""
g
=
dgl
.
DGLGraph
()
g
=
dgl
.
DGLGraph
()
g
.
add_node
(
0
)
g
.
add_nodes
(
5
)
g
.
add_node
(
1
)
g
.
add_node
(
2
)
g
.
add_node
(
3
)
g
.
add_node
(
4
)
g
.
add_edge
(
2
,
4
)
g
.
add_edge
(
2
,
4
)
g
.
add_edge
(
0
,
4
)
g
.
add_edge
(
0
,
4
)
g
.
add_edge
(
4
,
1
)
g
.
add_edge
(
4
,
1
)
g
.
add_edge
(
3
,
1
)
g
.
add_edge
(
3
,
1
)
g
.
set_n_repr
(
t
orc
h
.
Tensor
([
0
,
1
,
2
,
3
,
4
]))
g
.
set_n_repr
(
th
.
Tensor
([
0
,
1
,
2
,
3
,
4
]))
g
.
set_e_repr
(
t
orc
h
.
randn
(
4
,
10
))
g
.
set_e_repr
(
th
.
randn
(
4
,
10
))
return
g
return
g
def
test_batch_unbatch
():
def
test_batch_unbatch
():
...
@@ -58,13 +50,36 @@ def test_batch_unbatch():
...
@@ -58,13 +50,36 @@ def test_batch_unbatch():
e2
=
t2
.
get_e_repr
()
e2
=
t2
.
get_e_repr
()
bg
=
dgl
.
batch
([
t1
,
t2
])
bg
=
dgl
.
batch
([
t1
,
t2
])
dgl
.
unbatch
(
bg
)
assert
bg
.
number_of_nodes
()
==
10
assert
bg
.
number_of_edges
()
==
8
assert
(
n1
.
equal
(
t1
.
get_n_repr
()))
assert
bg
.
batch_size
==
2
assert
(
n2
.
equal
(
t2
.
get_n_repr
()))
assert
bg
.
batch_num_nodes
==
[
5
,
5
]
assert
(
e1
.
equal
(
t1
.
get_e_repr
()))
assert
bg
.
batch_num_edges
==
[
4
,
4
]
assert
(
e2
.
equal
(
t2
.
get_e_repr
()))
tt1
,
tt2
=
dgl
.
unbatch
(
bg
)
assert
th
.
allclose
(
t1
.
get_n_repr
(),
tt1
.
get_n_repr
())
assert
th
.
allclose
(
t1
.
get_e_repr
(),
tt1
.
get_e_repr
())
assert
th
.
allclose
(
t2
.
get_n_repr
(),
tt2
.
get_n_repr
())
assert
th
.
allclose
(
t2
.
get_e_repr
(),
tt2
.
get_e_repr
())
def
test_batch_unbatch1
():
t1
=
tree1
()
t2
=
tree2
()
b1
=
dgl
.
batch
([
t1
,
t2
])
b2
=
dgl
.
batch
([
t2
,
b1
])
assert
b2
.
number_of_nodes
()
==
15
assert
b2
.
number_of_edges
()
==
12
assert
b2
.
batch_size
==
3
assert
b2
.
batch_num_nodes
==
[
5
,
5
,
5
]
assert
b2
.
batch_num_edges
==
[
4
,
4
,
4
]
s1
,
s2
,
s3
=
dgl
.
unbatch
(
b2
)
assert
th
.
allclose
(
t2
.
get_n_repr
(),
s1
.
get_n_repr
())
assert
th
.
allclose
(
t2
.
get_e_repr
(),
s1
.
get_e_repr
())
assert
th
.
allclose
(
t1
.
get_n_repr
(),
s2
.
get_n_repr
())
assert
th
.
allclose
(
t1
.
get_e_repr
(),
s2
.
get_e_repr
())
assert
th
.
allclose
(
t2
.
get_n_repr
(),
s3
.
get_n_repr
())
assert
th
.
allclose
(
t2
.
get_e_repr
(),
s3
.
get_e_repr
())
def
test_batch_sendrecv
():
def
test_batch_sendrecv
():
t1
=
tree1
()
t1
=
tree1
()
...
@@ -72,7 +87,7 @@ def test_batch_sendrecv():
...
@@ -72,7 +87,7 @@ def test_batch_sendrecv():
bg
=
dgl
.
batch
([
t1
,
t2
])
bg
=
dgl
.
batch
([
t1
,
t2
])
bg
.
register_message_func
(
lambda
src
,
edge
:
src
)
bg
.
register_message_func
(
lambda
src
,
edge
:
src
)
bg
.
register_reduce_func
(
lambda
node
,
msgs
:
t
orc
h
.
sum
(
msgs
,
1
))
bg
.
register_reduce_func
(
lambda
node
,
msgs
:
th
.
sum
(
msgs
,
1
))
e1
=
[(
3
,
1
),
(
4
,
1
)]
e1
=
[(
3
,
1
),
(
4
,
1
)]
e2
=
[(
2
,
4
),
(
0
,
4
)]
e2
=
[(
2
,
4
),
(
0
,
4
)]
...
@@ -95,7 +110,7 @@ def test_batch_propagate():
...
@@ -95,7 +110,7 @@ def test_batch_propagate():
bg
=
dgl
.
batch
([
t1
,
t2
])
bg
=
dgl
.
batch
([
t1
,
t2
])
bg
.
register_message_func
(
lambda
src
,
edge
:
src
)
bg
.
register_message_func
(
lambda
src
,
edge
:
src
)
bg
.
register_reduce_func
(
lambda
node
,
msgs
:
t
orc
h
.
sum
(
msgs
,
1
))
bg
.
register_reduce_func
(
lambda
node
,
msgs
:
th
.
sum
(
msgs
,
1
))
# get leaves.
# get leaves.
order
=
[]
order
=
[]
...
@@ -129,20 +144,21 @@ def test_batched_edge_ordering():
...
@@ -129,20 +144,21 @@ def test_batched_edge_ordering():
g1
.
add_nodes_from
([
0
,
1
,
2
,
3
,
4
,
5
])
g1
.
add_nodes_from
([
0
,
1
,
2
,
3
,
4
,
5
])
g1
.
add_edges_from
([(
4
,
5
),
(
4
,
3
),
(
2
,
3
),
(
2
,
1
),
(
0
,
1
)])
g1
.
add_edges_from
([(
4
,
5
),
(
4
,
3
),
(
2
,
3
),
(
2
,
1
),
(
0
,
1
)])
g1
.
edge_list
g1
.
edge_list
e1
=
t
orc
h
.
randn
(
5
,
10
)
e1
=
th
.
randn
(
5
,
10
)
g1
.
set_e_repr
(
e1
)
g1
.
set_e_repr
(
e1
)
g2
=
dgl
.
DGLGraph
()
g2
=
dgl
.
DGLGraph
()
g2
.
add_nodes_from
([
0
,
1
,
2
,
3
,
4
,
5
])
g2
.
add_nodes_from
([
0
,
1
,
2
,
3
,
4
,
5
])
g2
.
add_edges_from
([(
0
,
1
),
(
1
,
2
),
(
2
,
3
),
(
5
,
4
),
(
4
,
3
),
(
5
,
0
)])
g2
.
add_edges_from
([(
0
,
1
),
(
1
,
2
),
(
2
,
3
),
(
5
,
4
),
(
4
,
3
),
(
5
,
0
)])
e2
=
t
orc
h
.
randn
(
6
,
10
)
e2
=
th
.
randn
(
6
,
10
)
g2
.
set_e_repr
(
e2
)
g2
.
set_e_repr
(
e2
)
g
=
dgl
.
batch
([
g1
,
g2
])
g
=
dgl
.
batch
([
g1
,
g2
])
r1
=
g
.
get_e_repr
()[
g
.
get_edge_id
(
4
,
5
)]
r1
=
g
.
get_e_repr
()[
g
.
get_edge_id
(
4
,
5
)]
r2
=
g1
.
get_e_repr
()[
g1
.
get_edge_id
(
4
,
5
)]
r2
=
g1
.
get_e_repr
()[
g1
.
get_edge_id
(
4
,
5
)]
assert
t
orc
h
.
equal
(
r1
,
r2
)
assert
th
.
equal
(
r1
,
r2
)
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
test_batch_unbatch
()
test_batch_unbatch
()
test_batched_edge_ordering
()
test_batch_unbatch1
()
test_batch_sendrecv
()
#test_batched_edge_ordering()
test_batch_propagate
()
#test_batch_sendrecv()
#test_batch_propagate()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment