Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
e181ef15
Unverified
Commit
e181ef15
authored
Dec 15, 2023
by
Rhett Ying
Committed by
GitHub
Dec 15, 2023
Browse files
[GraphBolt] add node_attributes into FusedCSCSamplingGraph (#6757)
parent
cad7caeb
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
333 additions
and
51 deletions
+333
-51
graphbolt/include/graphbolt/fused_csc_sampling_graph.h
graphbolt/include/graphbolt/fused_csc_sampling_graph.h
+28
-5
graphbolt/src/fused_csc_sampling_graph.cc
graphbolt/src/fused_csc_sampling_graph.cc
+44
-2
graphbolt/src/index_select.cc
graphbolt/src/index_select.cc
+1
-6
graphbolt/src/python_binding.cc
graphbolt/src/python_binding.cc
+2
-0
python/dgl/graphbolt/impl/fused_csc_sampling_graph.py
python/dgl/graphbolt/impl/fused_csc_sampling_graph.py
+34
-1
tests/python/pytorch/graphbolt/impl/test_fused_csc_sampling_graph.py
...n/pytorch/graphbolt/impl/test_fused_csc_sampling_graph.py
+224
-37
No files found.
graphbolt/include/graphbolt/fused_csc_sampling_graph.h
View file @
e181ef15
...
...
@@ -50,6 +50,7 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder {
public:
using
NodeTypeToIDMap
=
torch
::
Dict
<
std
::
string
,
int64_t
>
;
using
EdgeTypeToIDMap
=
torch
::
Dict
<
std
::
string
,
int64_t
>
;
using
NodeAttrMap
=
torch
::
Dict
<
std
::
string
,
torch
::
Tensor
>
;
using
EdgeAttrMap
=
torch
::
Dict
<
std
::
string
,
torch
::
Tensor
>
;
/** @brief Default constructor. */
FusedCSCSamplingGraph
()
=
default
;
...
...
@@ -66,16 +67,18 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder {
* present.
* @param edge_type_to_id A dictionary mapping edge type names to type IDs, if
* present.
* @param node_attributes A dictionary of node attributes, if present.
* @param edge_attributes A dictionary of edge attributes, if present.
*
*/
FusedCSCSamplingGraph
(
const
torch
::
Tensor
&
indptr
,
const
torch
::
Tensor
&
indices
,
const
torch
::
optional
<
torch
::
Tensor
>&
node_type_offset
,
const
torch
::
optional
<
torch
::
Tensor
>&
type_per_edge
,
const
torch
::
optional
<
NodeTypeToIDMap
>&
node_type_to_id
,
const
torch
::
optional
<
EdgeTypeToIDMap
>&
edge_type_to_id
,
const
torch
::
optional
<
EdgeAttrMap
>&
edge_attributes
);
const
torch
::
optional
<
torch
::
Tensor
>&
node_type_offset
=
torch
::
nullopt
,
const
torch
::
optional
<
torch
::
Tensor
>&
type_per_edge
=
torch
::
nullopt
,
const
torch
::
optional
<
NodeTypeToIDMap
>&
node_type_to_id
=
torch
::
nullopt
,
const
torch
::
optional
<
EdgeTypeToIDMap
>&
edge_type_to_id
=
torch
::
nullopt
,
const
torch
::
optional
<
NodeAttrMap
>&
node_attributes
=
torch
::
nullopt
,
const
torch
::
optional
<
EdgeAttrMap
>&
edge_attributes
=
torch
::
nullopt
);
/**
* @brief Create a fused CSC graph from tensors of CSC format.
...
...
@@ -89,6 +92,7 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder {
* present.
* @param edge_type_to_id A dictionary mapping edge type names to type IDs, if
* present.
* @param node_attributes A dictionary of node attributes, if present.
* @param edge_attributes A dictionary of edge attributes, if present.
*
* @return FusedCSCSamplingGraph
...
...
@@ -99,6 +103,7 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder {
const
torch
::
optional
<
torch
::
Tensor
>&
type_per_edge
,
const
torch
::
optional
<
NodeTypeToIDMap
>&
node_type_to_id
,
const
torch
::
optional
<
EdgeTypeToIDMap
>&
edge_type_to_id
,
const
torch
::
optional
<
NodeAttrMap
>&
node_attributes
,
const
torch
::
optional
<
EdgeAttrMap
>&
edge_attributes
);
/** @brief Get the number of nodes. */
...
...
@@ -139,6 +144,11 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder {
return
edge_type_to_id_
;
}
/**
 * @brief Get the node attributes dictionary.
 *
 * The return type is spelled with NodeAttrMap so that it matches the declared
 * type of `node_attributes_`. NodeAttrMap and EdgeAttrMap are both aliases of
 * `torch::Dict<std::string, torch::Tensor>`, so this spelling is
 * type-identical to the previous EdgeAttrMap one.
 *
 * @return The stored optional node attribute dictionary, returned by value.
 */
inline const torch::optional<NodeAttrMap> NodeAttributes() const {
  return node_attributes_;
}
/** @brief Get the edge attributes dictionary. */
inline
const
torch
::
optional
<
EdgeAttrMap
>
EdgeAttributes
()
const
{
return
edge_attributes_
;
...
...
@@ -180,6 +190,12 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder {
edge_type_to_id_
=
edge_type_to_id
;
}
/**
 * @brief Set the node attributes dictionary.
 *
 * The parameter is typed with NodeAttrMap so that it matches the declared
 * type of `node_attributes_`. NodeAttrMap and EdgeAttrMap are both aliases of
 * `torch::Dict<std::string, torch::Tensor>`, so existing callers are
 * unaffected.
 *
 * @param node_attributes The optional node attribute dictionary to store; it
 * replaces the current value of `node_attributes_`.
 */
inline void SetNodeAttributes(
    const torch::optional<NodeAttrMap>& node_attributes) {
  node_attributes_ = node_attributes;
}
/** @brief Set the edge attributes dictionary. */
inline
void
SetEdgeAttributes
(
const
torch
::
optional
<
EdgeAttrMap
>&
edge_attributes
)
{
...
...
@@ -367,6 +383,13 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder {
*/
torch
::
optional
<
EdgeTypeToIDMap
>
edge_type_to_id_
;
/**
* @brief A dictionary of node attributes. Each key represents the attribute's
* name, while the corresponding value holds the attribute's specific value.
* The length of each value should match the total number of nodes.
*/
torch
::
optional
<
NodeAttrMap
>
node_attributes_
;
/**
* @brief A dictionary of edge attributes. Each key represents the attribute's
* name, while the corresponding value holds the attribute's specific value.
...
...
graphbolt/src/fused_csc_sampling_graph.cc
View file @
e181ef15
...
...
@@ -56,6 +56,7 @@ FusedCSCSamplingGraph::FusedCSCSamplingGraph(
const
torch
::
optional
<
torch
::
Tensor
>&
type_per_edge
,
const
torch
::
optional
<
NodeTypeToIDMap
>&
node_type_to_id
,
const
torch
::
optional
<
EdgeTypeToIDMap
>&
edge_type_to_id
,
const
torch
::
optional
<
NodeAttrMap
>&
node_attributes
,
const
torch
::
optional
<
EdgeAttrMap
>&
edge_attributes
)
:
indptr_
(
indptr
),
indices_
(
indices
),
...
...
@@ -63,6 +64,7 @@ FusedCSCSamplingGraph::FusedCSCSamplingGraph(
type_per_edge_
(
type_per_edge
),
node_type_to_id_
(
node_type_to_id
),
edge_type_to_id_
(
edge_type_to_id
),
node_attributes_
(
node_attributes
),
edge_attributes_
(
edge_attributes
)
{
TORCH_CHECK
(
indptr
.
dim
()
==
1
);
TORCH_CHECK
(
indices
.
dim
()
==
1
);
...
...
@@ -75,6 +77,7 @@ c10::intrusive_ptr<FusedCSCSamplingGraph> FusedCSCSamplingGraph::Create(
const
torch
::
optional
<
torch
::
Tensor
>&
type_per_edge
,
const
torch
::
optional
<
NodeTypeToIDMap
>&
node_type_to_id
,
const
torch
::
optional
<
EdgeTypeToIDMap
>&
edge_type_to_id
,
const
torch
::
optional
<
NodeAttrMap
>&
node_attributes
,
const
torch
::
optional
<
EdgeAttrMap
>&
edge_attributes
)
{
if
(
node_type_offset
.
has_value
())
{
auto
&
offset
=
node_type_offset
.
value
();
...
...
@@ -89,6 +92,11 @@ c10::intrusive_ptr<FusedCSCSamplingGraph> FusedCSCSamplingGraph::Create(
TORCH_CHECK
(
type_per_edge
.
value
().
size
(
0
)
==
indices
.
size
(
0
));
TORCH_CHECK
(
edge_type_to_id
.
has_value
());
}
if
(
node_attributes
.
has_value
())
{
for
(
const
auto
&
pair
:
node_attributes
.
value
())
{
TORCH_CHECK
(
pair
.
value
().
size
(
0
)
==
indptr
.
size
(
0
)
-
1
);
}
}
if
(
edge_attributes
.
has_value
())
{
for
(
const
auto
&
pair
:
edge_attributes
.
value
())
{
TORCH_CHECK
(
pair
.
value
().
size
(
0
)
==
indices
.
size
(
0
));
...
...
@@ -96,7 +104,7 @@ c10::intrusive_ptr<FusedCSCSamplingGraph> FusedCSCSamplingGraph::Create(
}
return
c10
::
make_intrusive
<
FusedCSCSamplingGraph
>
(
indptr
,
indices
,
node_type_offset
,
type_per_edge
,
node_type_to_id
,
edge_type_to_id
,
edge_attributes
);
edge_type_to_id
,
node_attributes
,
edge_attributes
);
}
void
FusedCSCSamplingGraph
::
Load
(
torch
::
serialize
::
InputArchive
&
archive
)
{
...
...
@@ -150,6 +158,25 @@ void FusedCSCSamplingGraph::Load(torch::serialize::InputArchive& archive) {
edge_type_to_id_
=
std
::
move
(
edge_type_to_id
);
}
// Optional node attributes.
torch
::
IValue
has_node_attributes
;
if
(
archive
.
try_read
(
"FusedCSCSamplingGraph/has_node_attributes"
,
has_node_attributes
)
&&
has_node_attributes
.
toBool
())
{
torch
::
Dict
<
torch
::
IValue
,
torch
::
IValue
>
generic_dict
=
read_from_archive
(
archive
,
"FusedCSCSamplingGraph/node_attributes"
)
.
toGenericDict
();
NodeAttrMap
target_dict
;
for
(
const
auto
&
pair
:
generic_dict
)
{
std
::
string
key
=
pair
.
key
().
toStringRef
();
torch
::
Tensor
value
=
pair
.
value
().
toTensor
();
// Use move to avoid copy.
target_dict
.
insert
(
std
::
move
(
key
),
std
::
move
(
value
));
}
// Same as above.
node_attributes_
=
std
::
move
(
target_dict
);
}
// Optional edge attributes.
torch
::
IValue
has_edge_attributes
;
if
(
archive
.
try_read
(
...
...
@@ -203,6 +230,13 @@ void FusedCSCSamplingGraph::Save(
archive
.
write
(
"FusedCSCSamplingGraph/edge_type_to_id"
,
edge_type_to_id_
.
value
());
}
archive
.
write
(
"FusedCSCSamplingGraph/has_node_attributes"
,
node_attributes_
.
has_value
());
if
(
node_attributes_
)
{
archive
.
write
(
"FusedCSCSamplingGraph/node_attributes"
,
node_attributes_
.
value
());
}
archive
.
write
(
"FusedCSCSamplingGraph/has_edge_attributes"
,
edge_attributes_
.
has_value
());
...
...
@@ -238,6 +272,9 @@ void FusedCSCSamplingGraph::SetState(
if
(
state
.
find
(
"edge_type_to_id"
)
!=
state
.
end
())
{
edge_type_to_id_
=
DetensorizeDict
(
state
.
at
(
"edge_type_to_id"
));
}
if
(
state
.
find
(
"node_attributes"
)
!=
state
.
end
())
{
node_attributes_
=
state
.
at
(
"node_attributes"
);
}
if
(
state
.
find
(
"edge_attributes"
)
!=
state
.
end
())
{
edge_attributes_
=
state
.
at
(
"edge_attributes"
);
}
...
...
@@ -268,6 +305,9 @@ FusedCSCSamplingGraph::GetState() const {
if
(
edge_type_to_id_
.
has_value
())
{
state
.
insert
(
"edge_type_to_id"
,
TensorizeDict
(
edge_type_to_id_
).
value
());
}
if
(
node_attributes_
.
has_value
())
{
state
.
insert
(
"node_attributes"
,
node_attributes_
.
value
());
}
if
(
edge_attributes_
.
has_value
())
{
state
.
insert
(
"edge_attributes"
,
edge_attributes_
.
value
());
}
...
...
@@ -596,10 +636,11 @@ BuildGraphFromSharedMemoryHelper(SharedMemoryHelper&& helper) {
auto
type_per_edge
=
helper
.
ReadTorchTensor
();
auto
node_type_to_id
=
DetensorizeDict
(
helper
.
ReadTorchTensorDict
());
auto
edge_type_to_id
=
DetensorizeDict
(
helper
.
ReadTorchTensorDict
());
auto
node_attributes
=
helper
.
ReadTorchTensorDict
();
auto
edge_attributes
=
helper
.
ReadTorchTensorDict
();
auto
graph
=
c10
::
make_intrusive
<
FusedCSCSamplingGraph
>
(
indptr
.
value
(),
indices
.
value
(),
node_type_offset
,
type_per_edge
,
node_type_to_id
,
edge_type_to_id
,
edge_attributes
);
node_type_to_id
,
edge_type_to_id
,
node_attributes
,
edge_attributes
);
auto
shared_memory
=
helper
.
ReleaseSharedMemory
();
graph
->
HoldSharedMemoryObject
(
std
::
move
(
shared_memory
.
first
),
std
::
move
(
shared_memory
.
second
));
...
...
@@ -616,6 +657,7 @@ FusedCSCSamplingGraph::CopyToSharedMemory(
helper
.
WriteTorchTensor
(
type_per_edge_
);
helper
.
WriteTorchTensorDict
(
TensorizeDict
(
node_type_to_id_
));
helper
.
WriteTorchTensorDict
(
TensorizeDict
(
edge_type_to_id_
));
helper
.
WriteTorchTensorDict
(
node_attributes_
);
helper
.
WriteTorchTensorDict
(
edge_attributes_
);
helper
.
Flush
();
return
BuildGraphFromSharedMemoryHelper
(
std
::
move
(
helper
));
...
...
graphbolt/src/index_select.cc
View file @
e181ef15
...
...
@@ -43,12 +43,7 @@ std::tuple<torch::Tensor, torch::Tensor> IndexSelectCSC(
TORCH_CHECK
(
c10
::
isIntegralType
(
indices
.
scalar_type
(),
false
),
"IndexSelectCSC is not implemented to slice noninteger types yet."
);
torch
::
optional
<
torch
::
Tensor
>
temp
;
torch
::
optional
<
sampling
::
FusedCSCSamplingGraph
::
NodeTypeToIDMap
>
temp2
;
torch
::
optional
<
sampling
::
FusedCSCSamplingGraph
::
EdgeTypeToIDMap
>
temp3
;
torch
::
optional
<
sampling
::
FusedCSCSamplingGraph
::
EdgeAttrMap
>
temp4
;
sampling
::
FusedCSCSamplingGraph
g
(
indptr
,
indices
,
temp
,
temp
,
temp2
,
temp3
,
temp4
);
sampling
::
FusedCSCSamplingGraph
g
(
indptr
,
indices
);
const
auto
res
=
g
.
InSubgraph
(
nodes
);
return
std
::
make_tuple
(
res
->
indptr
,
res
->
indices
);
}
...
...
graphbolt/src/python_binding.cc
View file @
e181ef15
...
...
@@ -37,6 +37,7 @@ TORCH_LIBRARY(graphbolt, m) {
.
def
(
"type_per_edge"
,
&
FusedCSCSamplingGraph
::
TypePerEdge
)
.
def
(
"node_type_to_id"
,
&
FusedCSCSamplingGraph
::
NodeTypeToID
)
.
def
(
"edge_type_to_id"
,
&
FusedCSCSamplingGraph
::
EdgeTypeToID
)
.
def
(
"node_attributes"
,
&
FusedCSCSamplingGraph
::
NodeAttributes
)
.
def
(
"edge_attributes"
,
&
FusedCSCSamplingGraph
::
EdgeAttributes
)
.
def
(
"set_csc_indptr"
,
&
FusedCSCSamplingGraph
::
SetCSCIndptr
)
.
def
(
"set_indices"
,
&
FusedCSCSamplingGraph
::
SetIndices
)
...
...
@@ -44,6 +45,7 @@ TORCH_LIBRARY(graphbolt, m) {
.
def
(
"set_type_per_edge"
,
&
FusedCSCSamplingGraph
::
SetTypePerEdge
)
.
def
(
"set_node_type_to_id"
,
&
FusedCSCSamplingGraph
::
SetNodeTypeToID
)
.
def
(
"set_edge_type_to_id"
,
&
FusedCSCSamplingGraph
::
SetEdgeTypeToID
)
.
def
(
"set_node_attributes"
,
&
FusedCSCSamplingGraph
::
SetNodeAttributes
)
.
def
(
"set_edge_attributes"
,
&
FusedCSCSamplingGraph
::
SetEdgeAttributes
)
.
def
(
"in_subgraph"
,
&
FusedCSCSamplingGraph
::
InSubgraph
)
.
def
(
"sample_neighbors"
,
&
FusedCSCSamplingGraph
::
SampleNeighbors
)
...
...
python/dgl/graphbolt/impl/fused_csc_sampling_graph.py
View file @
e181ef15
...
...
@@ -279,6 +279,27 @@ class FusedCSCSamplingGraph(SamplingGraph):
"""Sets the edge type to id dictionary if present."""
self
.
_c_csc_graph
.
set_edge_type_to_id
(
edge_type_to_id
)
@
property
def
node_attributes
(
self
)
->
Optional
[
Dict
[
str
,
torch
.
Tensor
]]:
"""Returns the node attributes dictionary.
Returns
-------
Dict[str, torch.Tensor] or None
If present, returns a dictionary of node attributes. Each key
represents the attribute's name, while the corresponding value
holds the attribute's specific value. The length of each value
should match the total number of nodes.
"""
return
self
.
_c_csc_graph
.
node_attributes
()
@
node_attributes
.
setter
def
node_attributes
(
self
,
node_attributes
:
Optional
[
Dict
[
str
,
torch
.
Tensor
]]
)
->
None
:
"""Sets the node attributes dictionary."""
self
.
_c_csc_graph
.
set_node_attributes
(
node_attributes
)
@
property
def
edge_attributes
(
self
)
->
Optional
[
Dict
[
str
,
torch
.
Tensor
]]:
"""Returns the edge attributes dictionary.
...
...
@@ -892,6 +913,9 @@ class FusedCSCSamplingGraph(SamplingGraph):
self
.
type_per_edge
=
recursive_apply
(
self
.
type_per_edge
,
lambda
x
:
_to
(
x
,
device
)
)
self
.
node_attributes
=
recursive_apply
(
self
.
node_attributes
,
lambda
x
:
_to
(
x
,
device
)
)
self
.
edge_attributes
=
recursive_apply
(
self
.
edge_attributes
,
lambda
x
:
_to
(
x
,
device
)
)
...
...
@@ -906,6 +930,7 @@ def fused_csc_sampling_graph(
type_per_edge
:
Optional
[
torch
.
tensor
]
=
None
,
node_type_to_id
:
Optional
[
Dict
[
str
,
int
]]
=
None
,
edge_type_to_id
:
Optional
[
Dict
[
str
,
int
]]
=
None
,
node_attributes
:
Optional
[
Dict
[
str
,
torch
.
tensor
]]
=
None
,
edge_attributes
:
Optional
[
Dict
[
str
,
torch
.
tensor
]]
=
None
,
)
->
FusedCSCSamplingGraph
:
"""Create a FusedCSCSamplingGraph object from a CSC representation.
...
...
@@ -926,6 +951,8 @@ def fused_csc_sampling_graph(
Map node types to ids, by default None.
edge_type_to_id : Optional[Dict[str, int]], optional
Map edge types to ids, by default None.
node_attributes: Optional[Dict[str, torch.tensor]], optional
Node attributes of the graph, by default None.
edge_attributes: Optional[Dict[str, torch.tensor]], optional
Edge attributes of the graph, by default None.
...
...
@@ -946,7 +973,7 @@ def fused_csc_sampling_graph(
... node_type_offset=node_type_offset,
... type_per_edge=type_per_edge,
... node_type_to_id=ntypes, edge_type_to_id=etypes,
... edge_attributes=None,)
...
node_attributes=None,
edge_attributes=None,)
>>> print(graph)
FusedCSCSamplingGraph(csc_indptr=tensor([0, 2, 5, 7]),
indices=tensor([1, 3, 0, 1, 2, 0, 3]),
...
...
@@ -997,6 +1024,7 @@ def fused_csc_sampling_graph(
type_per_edge
,
node_type_to_id
,
edge_type_to_id
,
node_attributes
,
edge_attributes
,
),
)
...
...
@@ -1037,6 +1065,8 @@ def _csc_sampling_graph_str(graph: FusedCSCSamplingGraph) -> str:
meta_str
+=
f
", node_type_to_id=
{
graph
.
node_type_to_id
}
"
if
graph
.
edge_type_to_id
is
not
None
:
meta_str
+=
f
", edge_type_to_id=
{
graph
.
edge_type_to_id
}
"
if
graph
.
node_attributes
is
not
None
:
meta_str
+=
f
", node_attributes=
{
graph
.
node_attributes
}
"
if
graph
.
edge_attributes
is
not
None
:
meta_str
+=
f
", edge_attributes=
{
graph
.
edge_attributes
}
"
...
...
@@ -1094,6 +1124,8 @@ def from_dglgraph(
# Assign edge type according to the order of CSC matrix.
type_per_edge
=
None
if
is_homogeneous
else
homo_g
.
edata
[
ETYPE
][
edge_ids
]
node_attributes
=
{}
edge_attributes
=
{}
if
include_original_edge_id
:
# Assign edge attributes according to the original eids mapping.
...
...
@@ -1107,6 +1139,7 @@ def from_dglgraph(
type_per_edge
,
node_type_to_id
,
edge_type_to_id
,
node_attributes
,
edge_attributes
,
),
)
tests/python/pytorch/graphbolt/impl/test_fused_csc_sampling_graph.py
View file @
e181ef15
...
...
@@ -126,12 +126,19 @@ def test_homo_graph(total_num_nodes, total_num_edges):
csc_indptr
,
indices
=
gbt
.
random_homo_graph
(
total_num_nodes
,
total_num_edges
)
node_attributes
=
{
"A1"
:
torch
.
arange
(
total_num_nodes
),
"A2"
:
torch
.
arange
(
total_num_nodes
),
}
edge_attributes
=
{
"A1"
:
torch
.
randn
(
total_num_edges
),
"A2"
:
torch
.
randn
(
total_num_edges
),
}
graph
=
gb
.
fused_csc_sampling_graph
(
csc_indptr
,
indices
,
edge_attributes
=
edge_attributes
csc_indptr
,
indices
,
node_attributes
=
node_attributes
,
edge_attributes
=
edge_attributes
,
)
assert
graph
.
total_num_nodes
==
total_num_nodes
...
...
@@ -140,6 +147,7 @@ def test_homo_graph(total_num_nodes, total_num_edges):
assert
torch
.
equal
(
csc_indptr
,
graph
.
csc_indptr
)
assert
torch
.
equal
(
indices
,
graph
.
indices
)
assert
graph
.
node_attributes
==
node_attributes
assert
graph
.
edge_attributes
==
edge_attributes
assert
graph
.
node_type_offset
is
None
assert
graph
.
type_per_edge
is
None
...
...
@@ -167,6 +175,10 @@ def test_hetero_graph(total_num_nodes, total_num_edges, num_ntypes, num_etypes):
)
=
gbt
.
random_hetero_graph
(
total_num_nodes
,
total_num_edges
,
num_ntypes
,
num_etypes
)
node_attributes
=
{
"A1"
:
torch
.
arange
(
total_num_nodes
),
"A2"
:
torch
.
arange
(
total_num_nodes
),
}
edge_attributes
=
{
"A1"
:
torch
.
randn
(
total_num_edges
),
"A2"
:
torch
.
randn
(
total_num_edges
),
...
...
@@ -178,6 +190,7 @@ def test_hetero_graph(total_num_nodes, total_num_edges, num_ntypes, num_etypes):
type_per_edge
=
type_per_edge
,
node_type_to_id
=
node_type_to_id
,
edge_type_to_id
=
edge_type_to_id
,
node_attributes
=
node_attributes
,
edge_attributes
=
edge_attributes
,
)
...
...
@@ -188,6 +201,7 @@ def test_hetero_graph(total_num_nodes, total_num_edges, num_ntypes, num_etypes):
assert
torch
.
equal
(
indices
,
graph
.
indices
)
assert
torch
.
equal
(
node_type_offset
,
graph
.
node_type_offset
)
assert
torch
.
equal
(
type_per_edge
,
graph
.
type_per_edge
)
assert
graph
.
node_attributes
==
node_attributes
assert
graph
.
edge_attributes
==
edge_attributes
assert
node_type_to_id
==
graph
.
node_type_to_id
assert
edge_type_to_id
==
graph
.
edge_type_to_id
...
...
@@ -327,11 +341,32 @@ def test_node_type_offset_wrong_legnth(node_type_offset):
"total_num_nodes, total_num_edges"
,
[(
1
,
1
),
(
100
,
1
),
(
10
,
50
),
(
1000
,
50000
)],
)
def
test_load_save_homo_graph
(
total_num_nodes
,
total_num_edges
):
@
pytest
.
mark
.
parametrize
(
"has_node_attrs"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"has_edge_attrs"
,
[
True
,
False
])
def
test_load_save_homo_graph
(
total_num_nodes
,
total_num_edges
,
has_node_attrs
,
has_edge_attrs
):
csc_indptr
,
indices
=
gbt
.
random_homo_graph
(
total_num_nodes
,
total_num_edges
)
graph
=
gb
.
fused_csc_sampling_graph
(
csc_indptr
,
indices
)
node_attributes
=
None
if
has_node_attrs
:
node_attributes
=
{
"A"
:
torch
.
arange
(
total_num_nodes
),
"B"
:
torch
.
arange
(
total_num_nodes
),
}
edge_attributes
=
None
if
has_edge_attrs
:
edge_attributes
=
{
"A"
:
torch
.
arange
(
total_num_edges
),
"B"
:
torch
.
arange
(
total_num_edges
),
}
graph
=
gb
.
fused_csc_sampling_graph
(
csc_indptr
,
indices
,
node_attributes
=
node_attributes
,
edge_attributes
=
edge_attributes
,
)
with
tempfile
.
TemporaryDirectory
()
as
test_dir
:
filename
=
os
.
path
.
join
(
test_dir
,
"fused_csc_sampling_graph.pt"
)
...
...
@@ -348,6 +383,21 @@ def test_load_save_homo_graph(total_num_nodes, total_num_edges):
assert
graph
.
type_per_edge
is
None
and
graph2
.
type_per_edge
is
None
assert
graph
.
node_type_to_id
is
None
and
graph2
.
node_type_to_id
is
None
assert
graph
.
edge_type_to_id
is
None
and
graph2
.
edge_type_to_id
is
None
if
has_node_attrs
:
assert
graph
.
node_attributes
.
keys
()
==
graph2
.
node_attributes
.
keys
()
for
key
in
graph
.
node_attributes
.
keys
():
assert
torch
.
equal
(
graph
.
node_attributes
[
key
],
graph2
.
node_attributes
[
key
]
)
else
:
assert
graph
.
node_attributes
is
None
and
graph2
.
node_attributes
is
None
if
has_edge_attrs
:
assert
graph
.
edge_attributes
.
keys
()
==
graph2
.
edge_attributes
.
keys
()
for
key
in
graph
.
edge_attributes
.
keys
():
assert
torch
.
equal
(
graph
.
edge_attributes
[
key
],
graph2
.
edge_attributes
[
key
]
)
else
:
assert
graph
.
edge_attributes
is
None
and
graph2
.
edge_attributes
is
None
...
...
@@ -360,8 +410,15 @@ def test_load_save_homo_graph(total_num_nodes, total_num_edges):
[(
1
,
1
),
(
100
,
1
),
(
10
,
50
),
(
1000
,
50000
)],
)
@
pytest
.
mark
.
parametrize
(
"num_ntypes, num_etypes"
,
[(
1
,
1
),
(
3
,
5
),
(
100
,
1
)])
@
pytest
.
mark
.
parametrize
(
"has_node_attrs"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"has_edge_attrs"
,
[
True
,
False
])
def
test_load_save_hetero_graph
(
total_num_nodes
,
total_num_edges
,
num_ntypes
,
num_etypes
total_num_nodes
,
total_num_edges
,
num_ntypes
,
num_etypes
,
has_node_attrs
,
has_edge_attrs
,
):
(
csc_indptr
,
...
...
@@ -373,6 +430,18 @@ def test_load_save_hetero_graph(
)
=
gbt
.
random_hetero_graph
(
total_num_nodes
,
total_num_edges
,
num_ntypes
,
num_etypes
)
node_attributes
=
None
if
has_node_attrs
:
node_attributes
=
{
"A"
:
torch
.
arange
(
total_num_nodes
),
"B"
:
torch
.
arange
(
total_num_nodes
),
}
edge_attributes
=
None
if
has_edge_attrs
:
edge_attributes
=
{
"A"
:
torch
.
arange
(
total_num_edges
),
"B"
:
torch
.
arange
(
total_num_edges
),
}
graph
=
gb
.
fused_csc_sampling_graph
(
csc_indptr
,
indices
,
...
...
@@ -380,6 +449,8 @@ def test_load_save_hetero_graph(
type_per_edge
=
type_per_edge
,
node_type_to_id
=
node_type_to_id
,
edge_type_to_id
=
edge_type_to_id
,
node_attributes
=
node_attributes
,
edge_attributes
=
edge_attributes
,
)
with
tempfile
.
TemporaryDirectory
()
as
test_dir
:
...
...
@@ -396,6 +467,22 @@ def test_load_save_hetero_graph(
assert
torch
.
equal
(
graph
.
type_per_edge
,
graph2
.
type_per_edge
)
assert
graph
.
node_type_to_id
==
graph2
.
node_type_to_id
assert
graph
.
edge_type_to_id
==
graph2
.
edge_type_to_id
if
has_node_attrs
:
assert
graph
.
node_attributes
.
keys
()
==
graph2
.
node_attributes
.
keys
()
for
key
in
graph
.
node_attributes
.
keys
():
assert
torch
.
equal
(
graph
.
node_attributes
[
key
],
graph2
.
node_attributes
[
key
]
)
else
:
assert
graph
.
node_attributes
is
None
and
graph2
.
node_attributes
is
None
if
has_edge_attrs
:
assert
graph
.
edge_attributes
.
keys
()
==
graph2
.
edge_attributes
.
keys
()
for
key
in
graph
.
edge_attributes
.
keys
():
assert
torch
.
equal
(
graph
.
edge_attributes
[
key
],
graph2
.
edge_attributes
[
key
]
)
else
:
assert
graph
.
edge_attributes
is
None
and
graph2
.
edge_attributes
is
None
@
unittest
.
skipIf
(
...
...
@@ -406,11 +493,32 @@ def test_load_save_hetero_graph(
"total_num_nodes, total_num_edges"
,
[(
1
,
1
),
(
100
,
1
),
(
10
,
50
),
(
1000
,
50000
)],
)
def
test_pickle_homo_graph
(
total_num_nodes
,
total_num_edges
):
@
pytest
.
mark
.
parametrize
(
"has_node_attrs"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"has_edge_attrs"
,
[
True
,
False
])
def
test_pickle_homo_graph
(
total_num_nodes
,
total_num_edges
,
has_node_attrs
,
has_edge_attrs
):
csc_indptr
,
indices
=
gbt
.
random_homo_graph
(
total_num_nodes
,
total_num_edges
)
graph
=
gb
.
fused_csc_sampling_graph
(
csc_indptr
,
indices
)
node_attributes
=
None
if
has_node_attrs
:
node_attributes
=
{
"A"
:
torch
.
arange
(
total_num_nodes
),
"B"
:
torch
.
arange
(
total_num_nodes
),
}
edge_attributes
=
None
if
has_edge_attrs
:
edge_attributes
=
{
"A"
:
torch
.
arange
(
total_num_edges
),
"B"
:
torch
.
arange
(
total_num_edges
),
}
graph
=
gb
.
fused_csc_sampling_graph
(
csc_indptr
,
indices
,
node_attributes
=
node_attributes
,
edge_attributes
=
edge_attributes
,
)
serialized
=
pickle
.
dumps
(
graph
)
graph2
=
pickle
.
loads
(
serialized
)
...
...
@@ -425,6 +533,21 @@ def test_pickle_homo_graph(total_num_nodes, total_num_edges):
assert
graph
.
type_per_edge
is
None
and
graph2
.
type_per_edge
is
None
assert
graph
.
node_type_to_id
is
None
and
graph2
.
node_type_to_id
is
None
assert
graph
.
edge_type_to_id
is
None
and
graph2
.
edge_type_to_id
is
None
if
has_node_attrs
:
assert
graph
.
node_attributes
.
keys
()
==
graph2
.
node_attributes
.
keys
()
for
key
in
graph
.
node_attributes
.
keys
():
assert
torch
.
equal
(
graph
.
node_attributes
[
key
],
graph2
.
node_attributes
[
key
]
)
else
:
assert
graph
.
node_attributes
is
None
and
graph2
.
node_attributes
is
None
if
has_edge_attrs
:
assert
graph
.
edge_attributes
.
keys
()
==
graph2
.
edge_attributes
.
keys
()
for
key
in
graph
.
edge_attributes
.
keys
():
assert
torch
.
equal
(
graph
.
edge_attributes
[
key
],
graph2
.
edge_attributes
[
key
]
)
else
:
assert
graph
.
edge_attributes
is
None
and
graph2
.
edge_attributes
is
None
...
...
@@ -437,8 +560,15 @@ def test_pickle_homo_graph(total_num_nodes, total_num_edges):
[(
1
,
1
),
(
100
,
1
),
(
10
,
50
),
(
1000
,
50000
)],
)
@
pytest
.
mark
.
parametrize
(
"num_ntypes, num_etypes"
,
[(
1
,
1
),
(
3
,
5
),
(
100
,
1
)])
@
pytest
.
mark
.
parametrize
(
"has_node_attrs"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"has_edge_attrs"
,
[
True
,
False
])
def
test_pickle_hetero_graph
(
total_num_nodes
,
total_num_edges
,
num_ntypes
,
num_etypes
total_num_nodes
,
total_num_edges
,
num_ntypes
,
num_etypes
,
has_node_attrs
,
has_edge_attrs
,
):
(
csc_indptr
,
...
...
@@ -450,9 +580,17 @@ def test_pickle_hetero_graph(
)
=
gbt
.
random_hetero_graph
(
total_num_nodes
,
total_num_edges
,
num_ntypes
,
num_etypes
)
node_attributes
=
None
if
has_node_attrs
:
node_attributes
=
{
"A"
:
torch
.
arange
(
total_num_nodes
),
"B"
:
torch
.
arange
(
total_num_nodes
),
}
edge_attributes
=
None
if
has_edge_attrs
:
edge_attributes
=
{
"a
"
:
torch
.
ran
dn
(
(
total_num_edges
,)
),
"b
"
:
torch
.
ran
dint
(
1
,
10
,
(
total_num_edges
,)
),
"A
"
:
torch
.
a
ran
ge
(
total_num_edges
),
"B
"
:
torch
.
a
ran
ge
(
total_num_edges
),
}
graph
=
gb
.
fused_csc_sampling_graph
(
csc_indptr
,
...
...
@@ -461,6 +599,7 @@ def test_pickle_hetero_graph(
type_per_edge
=
type_per_edge
,
node_type_to_id
=
node_type_to_id
,
edge_type_to_id
=
edge_type_to_id
,
node_attributes
=
node_attributes
,
edge_attributes
=
edge_attributes
,
)
...
...
@@ -480,9 +619,22 @@ def test_pickle_hetero_graph(
assert
graph
.
edge_type_to_id
.
keys
()
==
graph2
.
edge_type_to_id
.
keys
()
for
i
in
graph
.
edge_type_to_id
.
keys
():
assert
graph
.
edge_type_to_id
[
i
]
==
graph2
.
edge_type_to_id
[
i
]
if
has_node_attrs
:
assert
graph
.
node_attributes
.
keys
()
==
graph2
.
node_attributes
.
keys
()
for
key
in
graph
.
node_attributes
.
keys
():
assert
torch
.
equal
(
graph
.
node_attributes
[
key
],
graph2
.
node_attributes
[
key
]
)
else
:
assert
graph
.
node_attributes
is
None
and
graph2
.
node_attributes
is
None
if
has_edge_attrs
:
assert
graph
.
edge_attributes
.
keys
()
==
graph2
.
edge_attributes
.
keys
()
for
i
in
graph
.
edge_attributes
.
keys
():
assert
torch
.
equal
(
graph
.
edge_attributes
[
i
],
graph2
.
edge_attributes
[
i
])
for
key
in
graph
.
edge_attributes
.
keys
():
assert
torch
.
equal
(
graph
.
edge_attributes
[
key
],
graph2
.
edge_attributes
[
key
]
)
else
:
assert
graph
.
edge_attributes
is
None
and
graph2
.
edge_attributes
is
None
def
process_csc_sampling_graph_multiprocessing
(
graph
):
...
...
@@ -1258,6 +1410,18 @@ def check_tensors_on_the_same_shared_memory(t1: torch.Tensor, t2: torch.Tensor):
t1
[:]
=
old_t1
def check_node_edge_attributes(graph1, graph2, attributes, attr_name):
    """Check one attribute dictionary (node or edge) on two graphs.

    Parameters
    ----------
    graph1, graph2
        Graph objects exposing an attribute dictionary under ``attr_name``.
    attributes : dict
        Expected attributes, mapping attribute name to tensor.
    attr_name : str
        Name of the graph property to read, e.g. ``"node_attributes"`` or
        ``"edge_attributes"``.

    For every expected entry this asserts that the name is present on both
    graphs, that ``graph1``'s tensor equals the expected tensor, and then
    passes both graphs' tensors to
    ``check_tensors_on_the_same_shared_memory``.
    """
    # The lookups do not depend on the loop variable, so do them once
    # instead of once per attribute.
    attrs_1 = getattr(graph1, attr_name)
    attrs_2 = getattr(graph2, attr_name)
    for name, attr in attributes.items():
        assert name in attrs_1
        assert name in attrs_2
        assert torch.equal(attrs_1[name], attr)
        check_tensors_on_the_same_shared_memory(attrs_1[name], attrs_2[name])
@
unittest
.
skipIf
(
F
.
_default_context_str
==
"gpu"
,
reason
=
"FusedCSCSamplingGraph is only supported on CPU."
,
...
...
@@ -1266,22 +1430,31 @@ def check_tensors_on_the_same_shared_memory(t1: torch.Tensor, t2: torch.Tensor):
"total_num_nodes, total_num_edges"
,
[(
1
,
1
),
(
100
,
1
),
(
10
,
50
),
(
1000
,
50000
)],
)
@
pytest
.
mark
.
parametrize
(
"test_node_attrs"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"test_edge_attrs"
,
[
True
,
False
])
def
test_homo_graph_on_shared_memory
(
total_num_nodes
,
total_num_edges
,
test_edge_attrs
total_num_nodes
,
total_num_edges
,
test_node_attrs
,
test_edge_attrs
):
csc_indptr
,
indices
=
gbt
.
random_homo_graph
(
total_num_nodes
,
total_num_edges
)
node_attributes
=
None
if
test_node_attrs
:
node_attributes
=
{
"A1"
:
torch
.
arange
(
total_num_nodes
),
"A2"
:
torch
.
arange
(
total_num_nodes
),
}
edge_attributes
=
None
if
test_edge_attrs
:
edge_attributes
=
{
"A1"
:
torch
.
randn
(
total_num_edges
),
"A2"
:
torch
.
randn
(
total_num_edges
),
}
else
:
edge_attributes
=
None
graph
=
gb
.
fused_csc_sampling_graph
(
csc_indptr
,
indices
,
edge_attributes
=
edge_attributes
csc_indptr
,
indices
,
node_attributes
=
node_attributes
,
edge_attributes
=
edge_attributes
,
)
shm_name
=
"test_homo_g"
...
...
@@ -1307,13 +1480,13 @@ def test_homo_graph_on_shared_memory(
)
check_tensors_on_the_same_shared_memory
(
graph1
.
indices
,
graph2
.
indices
)
if
test_node_attrs
:
check_node_edge_attributes
(
graph1
,
graph2
,
node_attributes
,
"node_attributes"
)
if
test_edge_attrs
:
for
name
,
edge_attr
in
edge_attributes
.
items
():
assert
name
in
graph1
.
edge_attributes
assert
name
in
graph2
.
edge_attributes
assert
torch
.
equal
(
graph1
.
edge_attributes
[
name
],
edge_attr
)
check_tensors_on_the_same_shared_memory
(
graph1
.
edge_attributes
[
name
],
graph2
.
edge_attributes
[
name
]
check_node_edge_attributes
(
graph1
,
graph2
,
edge_attributes
,
"edge_attributes"
)
assert
graph1
.
node_type_offset
is
None
and
graph2
.
node_type_offset
is
None
...
...
@@ -1333,9 +1506,15 @@ def test_homo_graph_on_shared_memory(
@
pytest
.
mark
.
parametrize
(
"num_ntypes, num_etypes"
,
[(
1
,
1
),
(
3
,
5
),
(
100
,
1
),
(
1000
,
1000
)]
)
@
pytest
.
mark
.
parametrize
(
"test_node_attrs"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"test_edge_attrs"
,
[
True
,
False
])
def
test_hetero_graph_on_shared_memory
(
total_num_nodes
,
total_num_edges
,
num_ntypes
,
num_etypes
,
test_edge_attrs
total_num_nodes
,
total_num_edges
,
num_ntypes
,
num_etypes
,
test_node_attrs
,
test_edge_attrs
,
):
(
csc_indptr
,
...
...
@@ -1348,13 +1527,20 @@ def test_hetero_graph_on_shared_memory(
total_num_nodes
,
total_num_edges
,
num_ntypes
,
num_etypes
)
node_attributes
=
None
if
test_node_attrs
:
node_attributes
=
{
"A1"
:
torch
.
arange
(
total_num_nodes
),
"A2"
:
torch
.
arange
(
total_num_nodes
),
}
edge_attributes
=
None
if
test_edge_attrs
:
edge_attributes
=
{
"A1"
:
torch
.
randn
(
total_num_edges
),
"A2"
:
torch
.
randn
(
total_num_edges
),
}
else
:
edge_attributes
=
None
graph
=
gb
.
fused_csc_sampling_graph
(
csc_indptr
,
indices
,
...
...
@@ -1362,6 +1548,7 @@ def test_hetero_graph_on_shared_memory(
type_per_edge
=
type_per_edge
,
node_type_to_id
=
node_type_to_id
,
edge_type_to_id
=
edge_type_to_id
,
node_attributes
=
node_attributes
,
edge_attributes
=
edge_attributes
,
)
...
...
@@ -1398,13 +1585,13 @@ def test_hetero_graph_on_shared_memory(
graph1
.
type_per_edge
,
graph2
.
type_per_edge
)
if
test_node_attrs
:
check_node_edge_attributes
(
graph1
,
graph2
,
node_attributes
,
"node_attributes"
)
if
test_edge_attrs
:
for
name
,
edge_attr
in
edge_attributes
.
items
():
assert
name
in
graph1
.
edge_attributes
assert
name
in
graph2
.
edge_attributes
assert
torch
.
equal
(
graph1
.
edge_attributes
[
name
],
edge_attr
)
check_tensors_on_the_same_shared_memory
(
graph1
.
edge_attributes
[
name
],
graph2
.
edge_attributes
[
name
]
check_node_edge_attributes
(
graph1
,
graph2
,
edge_attributes
,
"edge_attributes"
)
assert
node_type_to_id
==
graph1
.
node_type_to_id
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment