Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
4b456507
Unverified
Commit
4b456507
authored
Jul 11, 2023
by
peizhou001
Committed by
GitHub
Jul 11, 2023
Browse files
[Graphbolt] Add edge attributes (#5966)
parent
aa795b28
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
82 additions
and
18 deletions
+82
-18
graphbolt/include/graphbolt/csc_sampling_graph.h
graphbolt/include/graphbolt/csc_sampling_graph.h
+17
-2
graphbolt/src/csc_sampling_graph.cc
graphbolt/src/csc_sampling_graph.cc
+13
-6
graphbolt/src/python_binding.cc
graphbolt/src/python_binding.cc
+1
-0
python/dgl/graphbolt/impl/csc_sampling_graph.py
python/dgl/graphbolt/impl/csc_sampling_graph.py
+28
-3
tests/python/pytorch/graphbolt/impl/test_csc_sampling_graph.py
.../python/pytorch/graphbolt/impl/test_csc_sampling_graph.py
+22
-6
tests/python/pytorch/graphbolt/test_ondisk_dataset.py
tests/python/pytorch/graphbolt/test_ondisk_dataset.py
+1
-1
No files found.
graphbolt/include/graphbolt/csc_sampling_graph.h
View file @
4b456507
...
...
@@ -33,6 +33,7 @@ namespace sampling {
*/
class
CSCSamplingGraph
:
public
torch
::
CustomClassHolder
{
public:
using
EdgeAttrMap
=
torch
::
Dict
<
std
::
string
,
torch
::
Tensor
>
;
/** @brief Default constructor. */
CSCSamplingGraph
()
=
default
;
...
...
@@ -48,7 +49,8 @@ class CSCSamplingGraph : public torch::CustomClassHolder {
CSCSamplingGraph
(
const
torch
::
Tensor
&
indptr
,
const
torch
::
Tensor
&
indices
,
const
torch
::
optional
<
torch
::
Tensor
>&
node_type_offset
,
const
torch
::
optional
<
torch
::
Tensor
>&
type_per_edge
);
const
torch
::
optional
<
torch
::
Tensor
>&
type_per_edge
,
const
torch
::
optional
<
EdgeAttrMap
>&
edge_attributes
);
/**
* @brief Create a homogeneous CSC graph from tensors of CSC format.
...
...
@@ -64,7 +66,8 @@ class CSCSamplingGraph : public torch::CustomClassHolder {
static
c10
::
intrusive_ptr
<
CSCSamplingGraph
>
FromCSC
(
const
torch
::
Tensor
&
indptr
,
const
torch
::
Tensor
&
indices
,
const
torch
::
optional
<
torch
::
Tensor
>&
node_type_offset
,
const
torch
::
optional
<
torch
::
Tensor
>&
type_per_edge
);
const
torch
::
optional
<
torch
::
Tensor
>&
type_per_edge
,
const
torch
::
optional
<
EdgeAttrMap
>&
edge_attributes
);
/**
 * @brief Get the number of nodes.
 * @return The size of `indptr_` along dimension 0, minus one.
 */
int64_t
NumNodes
()
const
{
return
indptr_
.
size
(
0
)
-
1
;
}
...
...
@@ -88,6 +91,11 @@ class CSCSamplingGraph : public torch::CustomClassHolder {
return
type_per_edge_
;
}
/**
 * @brief Get the edge attributes dictionary.
 * @return The stored `edge_attributes_` optional, returned by value.
 */
inline
const
torch
::
optional
<
EdgeAttrMap
>
EdgeAttributes
()
const
{
return
edge_attributes_
;
}
/**
* @brief Magic number to indicate graph version in serialize/deserialize
* stage.
...
...
@@ -231,6 +239,13 @@ class CSCSamplingGraph : public torch::CustomClassHolder {
*/
torch
::
optional
<
torch
::
Tensor
>
type_per_edge_
;
/**
* @brief A dictionary of edge attributes. Each key represents the attribute's
* name, while the corresponding value holds the attribute's specific value.
* The length of each value should match the total number of edges.
*/
torch
::
optional
<
EdgeAttrMap
>
edge_attributes_
;
/**
* @brief Maximum number of bytes used to serialize the metadata of the
* member tensors, including tensor shape and dtype. The constant is estimated
...
...
graphbolt/src/csc_sampling_graph.cc
View file @
4b456507
...
...
@@ -20,11 +20,13 @@ namespace sampling {
CSCSamplingGraph
::
CSCSamplingGraph
(
const
torch
::
Tensor
&
indptr
,
const
torch
::
Tensor
&
indices
,
const
torch
::
optional
<
torch
::
Tensor
>&
node_type_offset
,
const
torch
::
optional
<
torch
::
Tensor
>&
type_per_edge
)
const
torch
::
optional
<
torch
::
Tensor
>&
type_per_edge
,
const
torch
::
optional
<
EdgeAttrMap
>&
edge_attributes
)
:
indptr_
(
indptr
),
indices_
(
indices
),
node_type_offset_
(
node_type_offset
),
type_per_edge_
(
type_per_edge
)
{
type_per_edge_
(
type_per_edge
),
edge_attributes_
(
edge_attributes
)
{
TORCH_CHECK
(
indptr
.
dim
()
==
1
);
TORCH_CHECK
(
indices
.
dim
()
==
1
);
TORCH_CHECK
(
indptr
.
device
()
==
indices
.
device
());
...
...
@@ -33,7 +35,8 @@ CSCSamplingGraph::CSCSamplingGraph(
c10
::
intrusive_ptr
<
CSCSamplingGraph
>
CSCSamplingGraph
::
FromCSC
(
const
torch
::
Tensor
&
indptr
,
const
torch
::
Tensor
&
indices
,
const
torch
::
optional
<
torch
::
Tensor
>&
node_type_offset
,
const
torch
::
optional
<
torch
::
Tensor
>&
type_per_edge
)
{
const
torch
::
optional
<
torch
::
Tensor
>&
type_per_edge
,
const
torch
::
optional
<
EdgeAttrMap
>&
edge_attributes
)
{
if
(
node_type_offset
.
has_value
())
{
auto
&
offset
=
node_type_offset
.
value
();
TORCH_CHECK
(
offset
.
dim
()
==
1
);
...
...
@@ -42,9 +45,13 @@ c10::intrusive_ptr<CSCSamplingGraph> CSCSamplingGraph::FromCSC(
TORCH_CHECK
(
type_per_edge
.
value
().
dim
()
==
1
);
TORCH_CHECK
(
type_per_edge
.
value
().
size
(
0
)
==
indices
.
size
(
0
));
}
if
(
edge_attributes
.
has_value
())
{
for
(
const
auto
&
pair
:
edge_attributes
.
value
())
{
TORCH_CHECK
(
pair
.
value
().
size
(
0
)
==
indices
.
size
(
0
));
}
}
return
c10
::
make_intrusive
<
CSCSamplingGraph
>
(
indptr
,
indices
,
node_type_offset
,
type_per_edge
);
indptr
,
indices
,
node_type_offset
,
type_per_edge
,
edge_attributes
);
}
void
CSCSamplingGraph
::
Load
(
torch
::
serialize
::
InputArchive
&
archive
)
{
...
...
@@ -217,7 +224,7 @@ CSCSamplingGraph::BuildGraphFromSharedMemoryTensors(
auto
&
optional_tensors
=
std
::
get
<
2
>
(
shared_memory_tensors
);
auto
graph
=
c10
::
make_intrusive
<
CSCSamplingGraph
>
(
optional_tensors
[
0
].
value
(),
optional_tensors
[
1
].
value
(),
optional_tensors
[
2
],
optional_tensors
[
3
]);
optional_tensors
[
2
],
optional_tensors
[
3
]
,
torch
::
nullopt
);
graph
->
tensor_meta_shm_
=
std
::
move
(
std
::
get
<
0
>
(
shared_memory_tensors
));
graph
->
tensor_data_shm_
=
std
::
move
(
std
::
get
<
1
>
(
shared_memory_tensors
));
return
graph
;
...
...
graphbolt/src/python_binding.cc
View file @
4b456507
...
...
@@ -28,6 +28,7 @@ TORCH_LIBRARY(graphbolt, m) {
.
def
(
"indices"
,
&
CSCSamplingGraph
::
Indices
)
.
def
(
"node_type_offset"
,
&
CSCSamplingGraph
::
NodeTypeOffset
)
.
def
(
"type_per_edge"
,
&
CSCSamplingGraph
::
TypePerEdge
)
.
def
(
"edge_attributes"
,
&
CSCSamplingGraph
::
EdgeAttributes
)
.
def
(
"in_subgraph"
,
&
CSCSamplingGraph
::
InSubgraph
)
.
def
(
"sample_neighbors"
,
&
CSCSamplingGraph
::
SampleNeighbors
)
.
def
(
...
...
python/dgl/graphbolt/impl/csc_sampling_graph.py
View file @
4b456507
...
...
@@ -163,6 +163,20 @@ class CSCSamplingGraph:
"""
return
self
.
_c_csc_graph
.
type_per_edge
()
@
property
def
edge_attributes
(
self
)
->
Optional
[
Dict
[
str
,
torch
.
Tensor
]]:
"""Returns the edge attributes dictionary.
Returns
-------
Dict[str, torch.Tensor] or None
If present, returns a dictionary of edge attributes. Each key
represents the attribute's name, while the corresponding value
holds the attribute's specific value. The length of each value
should match the total number of edges.
"""
return
self
.
_c_csc_graph
.
edge_attributes
()
@
property
def
metadata
(
self
)
->
Optional
[
GraphMetadata
]:
"""Returns the metadata of the graph.
...
...
@@ -383,6 +397,7 @@ def from_csc(
indices
:
torch
.
Tensor
,
node_type_offset
:
Optional
[
torch
.
tensor
]
=
None
,
type_per_edge
:
Optional
[
torch
.
tensor
]
=
None
,
edge_attributes
:
Optional
[
Dict
[
str
,
torch
.
tensor
]]
=
None
,
metadata
:
Optional
[
GraphMetadata
]
=
None
,
)
->
CSCSamplingGraph
:
"""Create a CSCSamplingGraph object from a CSC representation.
...
...
@@ -399,6 +414,8 @@ def from_csc(
Offset of node types in the graph, by default None.
type_per_edge : Optional[torch.tensor], optional
Type ids of each edge in the graph, by default None.
edge_attributes: Optional[Dict[str, torch.tensor]], optional
Edge attributes of the graph, by default None.
metadata: Optional[GraphMetadata], optional
Metadata of the graph, by default None.
Returns
...
...
@@ -416,7 +433,7 @@ def from_csc(
>>> node_type_offset = torch.tensor([0, 1, 2, 3])
>>> type_per_edge = torch.tensor([0, 1, 0, 1, 1, 0, 0])
>>> graph = graphbolt.from_csc(csc_indptr, indices, node_type_offset,
\
>>> type_per_edge, metadata)
>>> type_per_edge,
None,
metadata)
>>> print(graph)
CSCSamplingGraph(csc_indptr=tensor([0, 2, 5, 7]),
indices=tensor([1, 3, 0, 1, 2, 0, 3]),
...
...
@@ -428,7 +445,11 @@ def from_csc(
),
"node_type_offset length should be |ntypes| + 1."
return
CSCSamplingGraph
(
torch
.
ops
.
graphbolt
.
from_csc
(
csc_indptr
,
indices
,
node_type_offset
,
type_per_edge
csc_indptr
,
indices
,
node_type_offset
,
type_per_edge
,
edge_attributes
,
),
metadata
,
)
...
...
@@ -535,7 +556,11 @@ def from_dglgraph(g: DGLGraph) -> CSCSamplingGraph:
return
CSCSamplingGraph
(
torch
.
ops
.
graphbolt
.
from_csc
(
indptr
,
indices
,
node_type_offset
,
type_per_edge
indptr
,
indices
,
node_type_offset
,
type_per_edge
,
None
,
),
metadata
,
)
tests/python/pytorch/graphbolt/impl/test_csc_sampling_graph.py
View file @
4b456507
...
...
@@ -53,6 +53,7 @@ def test_hetero_empty_graph(num_nodes):
indices
,
node_type_offset
,
type_per_edge
,
None
,
metadata
,
)
assert
graph
.
num_edges
==
0
...
...
@@ -107,7 +108,11 @@ def test_metadata_with_etype_exception(etypes):
)
def
test_homo_graph
(
num_nodes
,
num_edges
):
csc_indptr
,
indices
=
gbt
.
random_homo_graph
(
num_nodes
,
num_edges
)
graph
=
gb
.
from_csc
(
csc_indptr
,
indices
)
edge_attributes
=
{
"A1"
:
torch
.
randn
(
num_edges
),
"A2"
:
torch
.
randn
(
num_edges
),
}
graph
=
gb
.
from_csc
(
csc_indptr
,
indices
,
edge_attributes
=
edge_attributes
)
assert
graph
.
num_nodes
==
num_nodes
assert
graph
.
num_edges
==
num_edges
...
...
@@ -115,6 +120,7 @@ def test_homo_graph(num_nodes, num_edges):
assert
torch
.
equal
(
csc_indptr
,
graph
.
csc_indptr
)
assert
torch
.
equal
(
indices
,
graph
.
indices
)
assert
graph
.
edge_attributes
==
edge_attributes
assert
graph
.
metadata
is
None
assert
graph
.
node_type_offset
is
None
assert
graph
.
type_per_edge
is
None
...
...
@@ -136,8 +142,17 @@ def test_hetero_graph(num_nodes, num_edges, num_ntypes, num_etypes):
type_per_edge
,
metadata
,
)
=
gbt
.
random_hetero_graph
(
num_nodes
,
num_edges
,
num_ntypes
,
num_etypes
)
edge_attributes
=
{
"A1"
:
torch
.
randn
(
num_edges
),
"A2"
:
torch
.
randn
(
num_edges
),
}
graph
=
gb
.
from_csc
(
csc_indptr
,
indices
,
node_type_offset
,
type_per_edge
,
metadata
csc_indptr
,
indices
,
node_type_offset
,
type_per_edge
,
edge_attributes
,
metadata
,
)
assert
graph
.
num_nodes
==
num_nodes
...
...
@@ -147,6 +162,7 @@ def test_hetero_graph(num_nodes, num_edges, num_ntypes, num_etypes):
assert
torch
.
equal
(
indices
,
graph
.
indices
)
assert
torch
.
equal
(
node_type_offset
,
graph
.
node_type_offset
)
assert
torch
.
equal
(
type_per_edge
,
graph
.
type_per_edge
)
assert
graph
.
edge_attributes
==
edge_attributes
assert
metadata
.
node_type_to_id
==
graph
.
metadata
.
node_type_to_id
assert
metadata
.
edge_type_to_id
==
graph
.
metadata
.
edge_type_to_id
...
...
@@ -170,7 +186,7 @@ def test_node_type_offset_wrong_legnth(node_type_offset):
)
with
pytest
.
raises
(
Exception
):
gb
.
from_csc
(
csc_indptr
,
indices
,
node_type_offset
,
type_per_edge
,
metadata
csc_indptr
,
indices
,
node_type_offset
,
type_per_edge
,
None
,
metadata
)
...
...
@@ -218,7 +234,7 @@ def test_load_save_hetero_graph(num_nodes, num_edges, num_ntypes, num_etypes):
metadata
,
)
=
gbt
.
random_hetero_graph
(
num_nodes
,
num_edges
,
num_ntypes
,
num_etypes
)
graph
=
gb
.
from_csc
(
csc_indptr
,
indices
,
node_type_offset
,
type_per_edge
,
metadata
csc_indptr
,
indices
,
node_type_offset
,
type_per_edge
,
None
,
metadata
)
with
tempfile
.
TemporaryDirectory
()
as
test_dir
:
...
...
@@ -323,7 +339,7 @@ def test_in_subgraph_heterogeneous():
# Construct CSCSamplingGraph.
metadata
=
gb
.
GraphMetadata
(
ntypes
,
etypes
)
graph
=
gb
.
from_csc
(
indptr
,
indices
,
node_type_offset
,
type_per_edge
,
metadata
indptr
,
indices
,
node_type_offset
,
type_per_edge
,
None
,
metadata
)
# Extract in subgraph.
...
...
@@ -662,7 +678,7 @@ def test_hetero_graph_on_shared_memory(
metadata
,
)
=
gbt
.
random_hetero_graph
(
num_nodes
,
num_edges
,
num_ntypes
,
num_etypes
)
graph
=
gb
.
from_csc
(
csc_indptr
,
indices
,
node_type_offset
,
type_per_edge
,
metadata
csc_indptr
,
indices
,
node_type_offset
,
type_per_edge
,
None
,
metadata
)
shm_name
=
"test_hetero_g"
...
...
tests/python/pytorch/graphbolt/test_ondisk_dataset.py
View file @
4b456507
...
...
@@ -684,7 +684,7 @@ def test_OnDiskDataset_Graph_heterogeneous():
metadata
,
)
=
gbt
.
random_hetero_graph
(
1000
,
10
*
1000
,
3
,
4
)
graph
=
gb
.
from_csc
(
csc_indptr
,
indices
,
node_type_offset
,
type_per_edge
,
metadata
csc_indptr
,
indices
,
node_type_offset
,
type_per_edge
,
None
,
metadata
)
with
tempfile
.
TemporaryDirectory
()
as
test_dir
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment