Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
2c325b2d
"src/git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "7ba9cbc63e337f207c0707d1018d0d4e1c83ca51"
Unverified
Commit
2c325b2d
authored
Dec 07, 2023
by
Rhett Ying
Committed by
GitHub
Dec 07, 2023
Browse files
[GraphBolt] add node/edge_type_to_id into pickle (#6701)
parent
cbb6f502
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
63 additions
and
3 deletions
+63
-3
graphbolt/src/fused_csc_sampling_graph.cc
graphbolt/src/fused_csc_sampling_graph.cc
+12
-0
graphbolt/src/python_binding.cc
graphbolt/src/python_binding.cc
+4
-0
python/dgl/graphbolt/impl/fused_csc_sampling_graph.py
python/dgl/graphbolt/impl/fused_csc_sampling_graph.py
+38
-0
tests/python/pytorch/graphbolt/impl/test_fused_csc_sampling_graph.py
...n/pytorch/graphbolt/impl/test_fused_csc_sampling_graph.py
+9
-3
No files found.
graphbolt/src/fused_csc_sampling_graph.cc
View file @
2c325b2d
...
@@ -232,6 +232,12 @@ void FusedCSCSamplingGraph::SetState(
...
@@ -232,6 +232,12 @@ void FusedCSCSamplingGraph::SetState(
if
(
independent_tensors
.
find
(
"type_per_edge"
)
!=
independent_tensors
.
end
())
{
if
(
independent_tensors
.
find
(
"type_per_edge"
)
!=
independent_tensors
.
end
())
{
type_per_edge_
=
independent_tensors
.
at
(
"type_per_edge"
);
type_per_edge_
=
independent_tensors
.
at
(
"type_per_edge"
);
}
}
if
(
state
.
find
(
"node_type_to_id"
)
!=
state
.
end
())
{
node_type_to_id_
=
DetensorizeDict
(
state
.
at
(
"node_type_to_id"
));
}
if
(
state
.
find
(
"edge_type_to_id"
)
!=
state
.
end
())
{
edge_type_to_id_
=
DetensorizeDict
(
state
.
at
(
"edge_type_to_id"
));
}
if
(
state
.
find
(
"edge_attributes"
)
!=
state
.
end
())
{
if
(
state
.
find
(
"edge_attributes"
)
!=
state
.
end
())
{
edge_attributes_
=
state
.
at
(
"edge_attributes"
);
edge_attributes_
=
state
.
at
(
"edge_attributes"
);
}
}
...
@@ -256,6 +262,12 @@ FusedCSCSamplingGraph::GetState() const {
...
@@ -256,6 +262,12 @@ FusedCSCSamplingGraph::GetState() const {
independent_tensors
.
insert
(
"type_per_edge"
,
type_per_edge_
.
value
());
independent_tensors
.
insert
(
"type_per_edge"
,
type_per_edge_
.
value
());
}
}
state
.
insert
(
"independent_tensors"
,
independent_tensors
);
state
.
insert
(
"independent_tensors"
,
independent_tensors
);
if
(
node_type_to_id_
.
has_value
())
{
state
.
insert
(
"node_type_to_id"
,
TensorizeDict
(
node_type_to_id_
).
value
());
}
if
(
edge_type_to_id_
.
has_value
())
{
state
.
insert
(
"edge_type_to_id"
,
TensorizeDict
(
edge_type_to_id_
).
value
());
}
if
(
edge_attributes_
.
has_value
())
{
if
(
edge_attributes_
.
has_value
())
{
state
.
insert
(
"edge_attributes"
,
edge_attributes_
.
value
());
state
.
insert
(
"edge_attributes"
,
edge_attributes_
.
value
());
}
}
...
...
graphbolt/src/python_binding.cc
View file @
2c325b2d
...
@@ -35,11 +35,15 @@ TORCH_LIBRARY(graphbolt, m) {
...
@@ -35,11 +35,15 @@ TORCH_LIBRARY(graphbolt, m) {
.
def
(
"indices"
,
&
FusedCSCSamplingGraph
::
Indices
)
.
def
(
"indices"
,
&
FusedCSCSamplingGraph
::
Indices
)
.
def
(
"node_type_offset"
,
&
FusedCSCSamplingGraph
::
NodeTypeOffset
)
.
def
(
"node_type_offset"
,
&
FusedCSCSamplingGraph
::
NodeTypeOffset
)
.
def
(
"type_per_edge"
,
&
FusedCSCSamplingGraph
::
TypePerEdge
)
.
def
(
"type_per_edge"
,
&
FusedCSCSamplingGraph
::
TypePerEdge
)
.
def
(
"node_type_to_id"
,
&
FusedCSCSamplingGraph
::
NodeTypeToID
)
.
def
(
"edge_type_to_id"
,
&
FusedCSCSamplingGraph
::
EdgeTypeToID
)
.
def
(
"edge_attributes"
,
&
FusedCSCSamplingGraph
::
EdgeAttributes
)
.
def
(
"edge_attributes"
,
&
FusedCSCSamplingGraph
::
EdgeAttributes
)
.
def
(
"set_csc_indptr"
,
&
FusedCSCSamplingGraph
::
SetCSCIndptr
)
.
def
(
"set_csc_indptr"
,
&
FusedCSCSamplingGraph
::
SetCSCIndptr
)
.
def
(
"set_indices"
,
&
FusedCSCSamplingGraph
::
SetIndices
)
.
def
(
"set_indices"
,
&
FusedCSCSamplingGraph
::
SetIndices
)
.
def
(
"set_node_type_offset"
,
&
FusedCSCSamplingGraph
::
SetNodeTypeOffset
)
.
def
(
"set_node_type_offset"
,
&
FusedCSCSamplingGraph
::
SetNodeTypeOffset
)
.
def
(
"set_type_per_edge"
,
&
FusedCSCSamplingGraph
::
SetTypePerEdge
)
.
def
(
"set_type_per_edge"
,
&
FusedCSCSamplingGraph
::
SetTypePerEdge
)
.
def
(
"set_node_type_to_id"
,
&
FusedCSCSamplingGraph
::
SetNodeTypeToID
)
.
def
(
"set_edge_type_to_id"
,
&
FusedCSCSamplingGraph
::
SetEdgeTypeToID
)
.
def
(
"set_edge_attributes"
,
&
FusedCSCSamplingGraph
::
SetEdgeAttributes
)
.
def
(
"set_edge_attributes"
,
&
FusedCSCSamplingGraph
::
SetEdgeAttributes
)
.
def
(
"in_subgraph"
,
&
FusedCSCSamplingGraph
::
InSubgraph
)
.
def
(
"in_subgraph"
,
&
FusedCSCSamplingGraph
::
InSubgraph
)
.
def
(
"sample_neighbors"
,
&
FusedCSCSamplingGraph
::
SampleNeighbors
)
.
def
(
"sample_neighbors"
,
&
FusedCSCSamplingGraph
::
SampleNeighbors
)
...
...
python/dgl/graphbolt/impl/fused_csc_sampling_graph.py
View file @
2c325b2d
...
@@ -255,6 +255,44 @@ class FusedCSCSamplingGraph(SamplingGraph):
...
@@ -255,6 +255,44 @@ class FusedCSCSamplingGraph(SamplingGraph):
"""Sets the edge type tensor if present."""
"""Sets the edge type tensor if present."""
self
.
_c_csc_graph
.
set_type_per_edge
(
type_per_edge
)
self
.
_c_csc_graph
.
set_type_per_edge
(
type_per_edge
)
@
property
def
node_type_to_id
(
self
)
->
Optional
[
Dict
[
str
,
int
]]:
"""Returns the node type to id dictionary if present.
Returns
-------
Dict[str, int] or None
If present, returns a dictionary mapping node type to node type
id.
"""
return
self
.
_c_csc_graph
.
node_type_to_id
()
@
node_type_to_id
.
setter
def
node_type_to_id
(
self
,
node_type_to_id
:
Optional
[
Dict
[
str
,
int
]]
)
->
None
:
"""Sets the node type to id dictionary if present."""
self
.
_c_csc_graph
.
set_node_type_to_id
(
node_type_to_id
)
@
property
def
edge_type_to_id
(
self
)
->
Optional
[
Dict
[
str
,
int
]]:
"""Returns the edge type to id dictionary if present.
Returns
-------
Dict[str, int] or None
If present, returns a dictionary mapping edge type to edge type
id.
"""
return
self
.
_c_csc_graph
.
edge_type_to_id
()
@
edge_type_to_id
.
setter
def
edge_type_to_id
(
self
,
edge_type_to_id
:
Optional
[
Dict
[
str
,
int
]]
)
->
None
:
"""Sets the edge type to id dictionary if present."""
self
.
_c_csc_graph
.
set_edge_type_to_id
(
edge_type_to_id
)
@
property
@
property
def
edge_attributes
(
self
)
->
Optional
[
Dict
[
str
,
torch
.
Tensor
]]:
def
edge_attributes
(
self
)
->
Optional
[
Dict
[
str
,
torch
.
Tensor
]]:
"""Returns the edge attributes dictionary.
"""Returns the edge attributes dictionary.
...
...
tests/python/pytorch/graphbolt/impl/test_fused_csc_sampling_graph.py
View file @
2c325b2d
...
@@ -376,9 +376,11 @@ def test_pickle_homo_graph(total_num_nodes, total_num_edges):
...
@@ -376,9 +376,11 @@ def test_pickle_homo_graph(total_num_nodes, total_num_edges):
assert
torch
.
equal
(
graph
.
csc_indptr
,
graph2
.
csc_indptr
)
assert
torch
.
equal
(
graph
.
csc_indptr
,
graph2
.
csc_indptr
)
assert
torch
.
equal
(
graph
.
indices
,
graph2
.
indices
)
assert
torch
.
equal
(
graph
.
indices
,
graph2
.
indices
)
assert
graph
.
metadata
is
None
and
graph2
.
metadata
is
None
assert
graph
.
node_type_offset
is
None
and
graph2
.
node_type_offset
is
None
assert
graph
.
node_type_offset
is
None
and
graph2
.
node_type_offset
is
None
assert
graph
.
type_per_edge
is
None
and
graph2
.
type_per_edge
is
None
assert
graph
.
type_per_edge
is
None
and
graph2
.
type_per_edge
is
None
assert
graph
.
node_type_to_id
is
None
and
graph2
.
node_type_to_id
is
None
assert
graph
.
edge_type_to_id
is
None
and
graph2
.
edge_type_to_id
is
None
assert
graph
.
edge_attributes
is
None
and
graph2
.
edge_attributes
is
None
@
unittest
.
skipIf
(
@
unittest
.
skipIf
(
...
@@ -425,8 +427,12 @@ def test_pickle_hetero_graph(
...
@@ -425,8 +427,12 @@ def test_pickle_hetero_graph(
assert
torch
.
equal
(
graph
.
indices
,
graph2
.
indices
)
assert
torch
.
equal
(
graph
.
indices
,
graph2
.
indices
)
assert
torch
.
equal
(
graph
.
node_type_offset
,
graph2
.
node_type_offset
)
assert
torch
.
equal
(
graph
.
node_type_offset
,
graph2
.
node_type_offset
)
assert
torch
.
equal
(
graph
.
type_per_edge
,
graph2
.
type_per_edge
)
assert
torch
.
equal
(
graph
.
type_per_edge
,
graph2
.
type_per_edge
)
assert
graph
.
metadata
.
node_type_to_id
==
graph2
.
metadata
.
node_type_to_id
assert
graph
.
node_type_to_id
.
keys
()
==
graph2
.
node_type_to_id
.
keys
()
assert
graph
.
metadata
.
edge_type_to_id
==
graph2
.
metadata
.
edge_type_to_id
for
i
in
graph
.
node_type_to_id
.
keys
():
assert
graph
.
node_type_to_id
[
i
]
==
graph2
.
node_type_to_id
[
i
]
assert
graph
.
edge_type_to_id
.
keys
()
==
graph2
.
edge_type_to_id
.
keys
()
for
i
in
graph
.
edge_type_to_id
.
keys
():
assert
graph
.
edge_type_to_id
[
i
]
==
graph2
.
edge_type_to_id
[
i
]
assert
graph
.
edge_attributes
.
keys
()
==
graph2
.
edge_attributes
.
keys
()
assert
graph
.
edge_attributes
.
keys
()
==
graph2
.
edge_attributes
.
keys
()
for
i
in
graph
.
edge_attributes
.
keys
():
for
i
in
graph
.
edge_attributes
.
keys
():
assert
torch
.
equal
(
graph
.
edge_attributes
[
i
],
graph2
.
edge_attributes
[
i
])
assert
torch
.
equal
(
graph
.
edge_attributes
[
i
],
graph2
.
edge_attributes
[
i
])
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment