"git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "7e1b202d5ebf36475a8b7031ce3035af12883003"
Unverified commit e181ef15 authored by Rhett Ying, committed by GitHub

[GraphBolt] add node_attributes into FusedCSCSamplingGraph (#6757)

parent cad7caeb
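This change threads an optional node_attributes dictionary (name → tensor, one entry per node) through the C++ FusedCSCSamplingGraph constructor, Create(), serialization, the shared-memory copy path, the TorchScript bindings, and the Python fused_csc_sampling_graph factory. A minimal usage sketch of the resulting API, assuming the package is imported as dgl.graphbolt and using made-up attribute names:

import torch
import dgl.graphbolt as gb

# Toy homogeneous graph with 4 nodes and 6 edges in CSC form.
indptr = torch.tensor([0, 2, 4, 5, 6])
indices = torch.tensor([1, 3, 0, 2, 3, 0])

# Each node-attribute tensor must have one entry per node and each
# edge-attribute tensor one entry per edge (validated in Create()).
node_attributes = {"feat": torch.randn(4, 8), "label": torch.arange(4)}
edge_attributes = {"weight": torch.rand(6)}

graph = gb.fused_csc_sampling_graph(
    indptr,
    indices,
    node_attributes=node_attributes,
    edge_attributes=edge_attributes,
)
print(graph.node_attributes["label"])  # tensor([0, 1, 2, 3])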
......@@ -50,6 +50,7 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder {
public:
using NodeTypeToIDMap = torch::Dict<std::string, int64_t>;
using EdgeTypeToIDMap = torch::Dict<std::string, int64_t>;
using NodeAttrMap = torch::Dict<std::string, torch::Tensor>;
using EdgeAttrMap = torch::Dict<std::string, torch::Tensor>;
/** @brief Default constructor. */
FusedCSCSamplingGraph() = default;
......@@ -66,16 +67,18 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder {
* present.
* @param edge_type_to_id A dictionary mapping edge type names to type IDs, if
* present.
* @param node_attributes A dictionary of node attributes, if present.
* @param edge_attributes A dictionary of edge attributes, if present.
*
*/
FusedCSCSamplingGraph(
const torch::Tensor& indptr, const torch::Tensor& indices,
const torch::optional<torch::Tensor>& node_type_offset,
const torch::optional<torch::Tensor>& type_per_edge,
const torch::optional<NodeTypeToIDMap>& node_type_to_id,
const torch::optional<EdgeTypeToIDMap>& edge_type_to_id,
const torch::optional<EdgeAttrMap>& edge_attributes);
const torch::optional<torch::Tensor>& node_type_offset = torch::nullopt,
const torch::optional<torch::Tensor>& type_per_edge = torch::nullopt,
const torch::optional<NodeTypeToIDMap>& node_type_to_id = torch::nullopt,
const torch::optional<EdgeTypeToIDMap>& edge_type_to_id = torch::nullopt,
const torch::optional<NodeAttrMap>& node_attributes = torch::nullopt,
const torch::optional<EdgeAttrMap>& edge_attributes = torch::nullopt);
/**
* @brief Create a fused CSC graph from tensors of CSC format.
......@@ -89,6 +92,7 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder {
* present.
* @param edge_type_to_id A dictionary mapping edge type names to type IDs, if
* present.
* @param node_attributes A dictionary of node attributes, if present.
* @param edge_attributes A dictionary of edge attributes, if present.
*
* @return FusedCSCSamplingGraph
......@@ -99,6 +103,7 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder {
const torch::optional<torch::Tensor>& type_per_edge,
const torch::optional<NodeTypeToIDMap>& node_type_to_id,
const torch::optional<EdgeTypeToIDMap>& edge_type_to_id,
const torch::optional<NodeAttrMap>& node_attributes,
const torch::optional<EdgeAttrMap>& edge_attributes);
/** @brief Get the number of nodes. */
......@@ -139,6 +144,11 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder {
return edge_type_to_id_;
}
/** @brief Get the node attributes dictionary. */
inline const torch::optional<NodeAttrMap> NodeAttributes() const {
return node_attributes_;
}
/** @brief Get the edge attributes dictionary. */
inline const torch::optional<EdgeAttrMap> EdgeAttributes() const {
return edge_attributes_;
......@@ -180,6 +190,12 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder {
edge_type_to_id_ = edge_type_to_id;
}
/** @brief Set the node attributes dictionary. */
inline void SetNodeAttributes(
const torch::optional<NodeAttrMap>& node_attributes) {
node_attributes_ = node_attributes;
}
/** @brief Set the edge attributes dictionary. */
inline void SetEdgeAttributes(
const torch::optional<EdgeAttrMap>& edge_attributes) {
......@@ -367,6 +383,13 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder {
*/
torch::optional<EdgeTypeToIDMap> edge_type_to_id_;
/**
* @brief A dictionary of node attributes. Each key represents the attribute's
* name, while the corresponding value holds the attribute's specific value.
* The length of each value should match the total number of nodes.
*/
torch::optional<NodeAttrMap> node_attributes_;
/**
* @brief A dictionary of edge attributes. Each key represents the attribute's
* name, while the corresponding value holds the attribute's specific value.
......
......@@ -56,6 +56,7 @@ FusedCSCSamplingGraph::FusedCSCSamplingGraph(
const torch::optional<torch::Tensor>& type_per_edge,
const torch::optional<NodeTypeToIDMap>& node_type_to_id,
const torch::optional<EdgeTypeToIDMap>& edge_type_to_id,
const torch::optional<NodeAttrMap>& node_attributes,
const torch::optional<EdgeAttrMap>& edge_attributes)
: indptr_(indptr),
indices_(indices),
......@@ -63,6 +64,7 @@ FusedCSCSamplingGraph::FusedCSCSamplingGraph(
type_per_edge_(type_per_edge),
node_type_to_id_(node_type_to_id),
edge_type_to_id_(edge_type_to_id),
node_attributes_(node_attributes),
edge_attributes_(edge_attributes) {
TORCH_CHECK(indptr.dim() == 1);
TORCH_CHECK(indices.dim() == 1);
......@@ -75,6 +77,7 @@ c10::intrusive_ptr<FusedCSCSamplingGraph> FusedCSCSamplingGraph::Create(
const torch::optional<torch::Tensor>& type_per_edge,
const torch::optional<NodeTypeToIDMap>& node_type_to_id,
const torch::optional<EdgeTypeToIDMap>& edge_type_to_id,
const torch::optional<NodeAttrMap>& node_attributes,
const torch::optional<EdgeAttrMap>& edge_attributes) {
if (node_type_offset.has_value()) {
auto& offset = node_type_offset.value();
......@@ -89,6 +92,11 @@ c10::intrusive_ptr<FusedCSCSamplingGraph> FusedCSCSamplingGraph::Create(
TORCH_CHECK(type_per_edge.value().size(0) == indices.size(0));
TORCH_CHECK(edge_type_to_id.has_value());
}
if (node_attributes.has_value()) {
for (const auto& pair : node_attributes.value()) {
TORCH_CHECK(pair.value().size(0) == indptr.size(0) - 1);
}
}
if (edge_attributes.has_value()) {
for (const auto& pair : edge_attributes.value()) {
TORCH_CHECK(pair.value().size(0) == indices.size(0));
......@@ -96,7 +104,7 @@ c10::intrusive_ptr<FusedCSCSamplingGraph> FusedCSCSamplingGraph::Create(
}
return c10::make_intrusive<FusedCSCSamplingGraph>(
indptr, indices, node_type_offset, type_per_edge, node_type_to_id,
edge_type_to_id, edge_attributes);
edge_type_to_id, node_attributes, edge_attributes);
}
void FusedCSCSamplingGraph::Load(torch::serialize::InputArchive& archive) {
......@@ -150,6 +158,25 @@ void FusedCSCSamplingGraph::Load(torch::serialize::InputArchive& archive) {
edge_type_to_id_ = std::move(edge_type_to_id);
}
// Optional node attributes.
torch::IValue has_node_attributes;
if (archive.try_read(
"FusedCSCSamplingGraph/has_node_attributes", has_node_attributes) &&
has_node_attributes.toBool()) {
torch::Dict<torch::IValue, torch::IValue> generic_dict =
read_from_archive(archive, "FusedCSCSamplingGraph/node_attributes")
.toGenericDict();
NodeAttrMap target_dict;
for (const auto& pair : generic_dict) {
std::string key = pair.key().toStringRef();
torch::Tensor value = pair.value().toTensor();
// Use move to avoid copy.
target_dict.insert(std::move(key), std::move(value));
}
// Same as above.
node_attributes_ = std::move(target_dict);
}
// Optional edge attributes.
torch::IValue has_edge_attributes;
if (archive.try_read(
......@@ -203,6 +230,13 @@ void FusedCSCSamplingGraph::Save(
archive.write(
"FusedCSCSamplingGraph/edge_type_to_id", edge_type_to_id_.value());
}
archive.write(
"FusedCSCSamplingGraph/has_node_attributes",
node_attributes_.has_value());
if (node_attributes_) {
archive.write(
"FusedCSCSamplingGraph/node_attributes", node_attributes_.value());
}
archive.write(
"FusedCSCSamplingGraph/has_edge_attributes",
edge_attributes_.has_value());
......@@ -238,6 +272,9 @@ void FusedCSCSamplingGraph::SetState(
if (state.find("edge_type_to_id") != state.end()) {
edge_type_to_id_ = DetensorizeDict(state.at("edge_type_to_id"));
}
if (state.find("node_attributes") != state.end()) {
node_attributes_ = state.at("node_attributes");
}
if (state.find("edge_attributes") != state.end()) {
edge_attributes_ = state.at("edge_attributes");
}
......@@ -268,6 +305,9 @@ FusedCSCSamplingGraph::GetState() const {
if (edge_type_to_id_.has_value()) {
state.insert("edge_type_to_id", TensorizeDict(edge_type_to_id_).value());
}
if (node_attributes_.has_value()) {
state.insert("node_attributes", node_attributes_.value());
}
if (edge_attributes_.has_value()) {
state.insert("edge_attributes", edge_attributes_.value());
}
......@@ -596,10 +636,11 @@ BuildGraphFromSharedMemoryHelper(SharedMemoryHelper&& helper) {
auto type_per_edge = helper.ReadTorchTensor();
auto node_type_to_id = DetensorizeDict(helper.ReadTorchTensorDict());
auto edge_type_to_id = DetensorizeDict(helper.ReadTorchTensorDict());
auto node_attributes = helper.ReadTorchTensorDict();
auto edge_attributes = helper.ReadTorchTensorDict();
auto graph = c10::make_intrusive<FusedCSCSamplingGraph>(
indptr.value(), indices.value(), node_type_offset, type_per_edge,
node_type_to_id, edge_type_to_id, edge_attributes);
node_type_to_id, edge_type_to_id, node_attributes, edge_attributes);
auto shared_memory = helper.ReleaseSharedMemory();
graph->HoldSharedMemoryObject(
std::move(shared_memory.first), std::move(shared_memory.second));
......@@ -616,6 +657,7 @@ FusedCSCSamplingGraph::CopyToSharedMemory(
helper.WriteTorchTensor(type_per_edge_);
helper.WriteTorchTensorDict(TensorizeDict(node_type_to_id_));
helper.WriteTorchTensorDict(TensorizeDict(edge_type_to_id_));
helper.WriteTorchTensorDict(node_attributes_);
helper.WriteTorchTensorDict(edge_attributes_);
helper.Flush();
return BuildGraphFromSharedMemoryHelper(std::move(helper));
......
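Because node_attributes_ now participates in GetState()/SetState() and the shared-memory helpers above, the attribute map survives state round-trips such as pickling (exercised by the tests further below). A minimal round-trip sketch under the same dgl.graphbolt assumption as the earlier example:

import pickle

import torch
import dgl.graphbolt as gb

graph = gb.fused_csc_sampling_graph(
    torch.tensor([0, 1, 2]),  # indptr: 2 nodes
    torch.tensor([1, 0]),     # indices: 2 edges
    node_attributes={"feat": torch.arange(2)},
)

# The pickled copy restores node_attributes via the new GetState()/SetState() entries.
graph2 = pickle.loads(pickle.dumps(graph))
assert torch.equal(graph.node_attributes["feat"], graph2.node_attributes["feat"])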
......@@ -43,12 +43,7 @@ std::tuple<torch::Tensor, torch::Tensor> IndexSelectCSC(
TORCH_CHECK(
c10::isIntegralType(indices.scalar_type(), false),
"IndexSelectCSC is not implemented to slice noninteger types yet.");
torch::optional<torch::Tensor> temp;
torch::optional<sampling::FusedCSCSamplingGraph::NodeTypeToIDMap> temp2;
torch::optional<sampling::FusedCSCSamplingGraph::EdgeTypeToIDMap> temp3;
torch::optional<sampling::FusedCSCSamplingGraph::EdgeAttrMap> temp4;
sampling::FusedCSCSamplingGraph g(
indptr, indices, temp, temp, temp2, temp3, temp4);
sampling::FusedCSCSamplingGraph g(indptr, indices);
const auto res = g.InSubgraph(nodes);
return std::make_tuple(res->indptr, res->indices);
}
......
......@@ -37,6 +37,7 @@ TORCH_LIBRARY(graphbolt, m) {
.def("type_per_edge", &FusedCSCSamplingGraph::TypePerEdge)
.def("node_type_to_id", &FusedCSCSamplingGraph::NodeTypeToID)
.def("edge_type_to_id", &FusedCSCSamplingGraph::EdgeTypeToID)
.def("node_attributes", &FusedCSCSamplingGraph::NodeAttributes)
.def("edge_attributes", &FusedCSCSamplingGraph::EdgeAttributes)
.def("set_csc_indptr", &FusedCSCSamplingGraph::SetCSCIndptr)
.def("set_indices", &FusedCSCSamplingGraph::SetIndices)
......@@ -44,6 +45,7 @@ TORCH_LIBRARY(graphbolt, m) {
.def("set_type_per_edge", &FusedCSCSamplingGraph::SetTypePerEdge)
.def("set_node_type_to_id", &FusedCSCSamplingGraph::SetNodeTypeToID)
.def("set_edge_type_to_id", &FusedCSCSamplingGraph::SetEdgeTypeToID)
.def("set_node_attributes", &FusedCSCSamplingGraph::SetNodeAttributes)
.def("set_edge_attributes", &FusedCSCSamplingGraph::SetEdgeAttributes)
.def("in_subgraph", &FusedCSCSamplingGraph::InSubgraph)
.def("sample_neighbors", &FusedCSCSamplingGraph::SampleNeighbors)
......
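The node_attributes/set_node_attributes methods registered above surface in the Python wrapper below as a read/write node_attributes property, so attributes can also be attached after construction. A small sketch with a hypothetical attribute name:

import torch
import dgl.graphbolt as gb

# Two nodes, three edges; no attributes supplied at construction time.
graph = gb.fused_csc_sampling_graph(
    torch.tensor([0, 2, 3]),  # indptr
    torch.tensor([0, 1, 1]),  # indices
)
assert graph.node_attributes is None

# The property setter forwards to the bound set_node_attributes() method;
# "score" is an arbitrary name with one value per node.
graph.node_attributes = {"score": torch.tensor([0.5, 1.5])}
print(graph.node_attributes["score"])  # tensor([0.5000, 1.5000])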
......@@ -279,6 +279,27 @@ class FusedCSCSamplingGraph(SamplingGraph):
"""Sets the edge type to id dictionary if present."""
self._c_csc_graph.set_edge_type_to_id(edge_type_to_id)
@property
def node_attributes(self) -> Optional[Dict[str, torch.Tensor]]:
"""Returns the node attributes dictionary.
Returns
-------
Dict[str, torch.Tensor] or None
If present, returns a dictionary of node attributes. Each key
represents the attribute's name, while the corresponding value
holds the attribute's specific value. The length of each value
should match the total number of nodes.
"""
return self._c_csc_graph.node_attributes()
@node_attributes.setter
def node_attributes(
self, node_attributes: Optional[Dict[str, torch.Tensor]]
) -> None:
"""Sets the node attributes dictionary."""
self._c_csc_graph.set_node_attributes(node_attributes)
@property
def edge_attributes(self) -> Optional[Dict[str, torch.Tensor]]:
"""Returns the edge attributes dictionary.
......@@ -892,6 +913,9 @@ class FusedCSCSamplingGraph(SamplingGraph):
self.type_per_edge = recursive_apply(
self.type_per_edge, lambda x: _to(x, device)
)
self.node_attributes = recursive_apply(
self.node_attributes, lambda x: _to(x, device)
)
self.edge_attributes = recursive_apply(
self.edge_attributes, lambda x: _to(x, device)
)
......@@ -906,6 +930,7 @@ def fused_csc_sampling_graph(
type_per_edge: Optional[torch.tensor] = None,
node_type_to_id: Optional[Dict[str, int]] = None,
edge_type_to_id: Optional[Dict[str, int]] = None,
node_attributes: Optional[Dict[str, torch.tensor]] = None,
edge_attributes: Optional[Dict[str, torch.tensor]] = None,
) -> FusedCSCSamplingGraph:
"""Create a FusedCSCSamplingGraph object from a CSC representation.
......@@ -926,6 +951,8 @@ def fused_csc_sampling_graph(
Map node types to ids, by default None.
edge_type_to_id : Optional[Dict[str, int]], optional
Map edge types to ids, by default None.
node_attributes: Optional[Dict[str, torch.tensor]], optional
Node attributes of the graph, by default None.
edge_attributes: Optional[Dict[str, torch.tensor]], optional
Edge attributes of the graph, by default None.
......@@ -946,7 +973,7 @@ def fused_csc_sampling_graph(
... node_type_offset=node_type_offset,
... type_per_edge=type_per_edge,
... node_type_to_id=ntypes, edge_type_to_id=etypes,
... edge_attributes=None,)
... node_attributes=None, edge_attributes=None,)
>>> print(graph)
FusedCSCSamplingGraph(csc_indptr=tensor([0, 2, 5, 7]),
indices=tensor([1, 3, 0, 1, 2, 0, 3]),
......@@ -997,6 +1024,7 @@ def fused_csc_sampling_graph(
type_per_edge,
node_type_to_id,
edge_type_to_id,
node_attributes,
edge_attributes,
),
)
......@@ -1037,6 +1065,8 @@ def _csc_sampling_graph_str(graph: FusedCSCSamplingGraph) -> str:
meta_str += f", node_type_to_id={graph.node_type_to_id}"
if graph.edge_type_to_id is not None:
meta_str += f", edge_type_to_id={graph.edge_type_to_id}"
if graph.node_attributes is not None:
meta_str += f", node_attributes={graph.node_attributes}"
if graph.edge_attributes is not None:
meta_str += f", edge_attributes={graph.edge_attributes}"
......@@ -1094,6 +1124,8 @@ def from_dglgraph(
# Assign edge type according to the order of CSC matrix.
type_per_edge = None if is_homogeneous else homo_g.edata[ETYPE][edge_ids]
node_attributes = {}
edge_attributes = {}
if include_original_edge_id:
# Assign edge attributes according to the original eids mapping.
......@@ -1107,6 +1139,7 @@ def from_dglgraph(
type_per_edge,
node_type_to_id,
edge_type_to_id,
node_attributes,
edge_attributes,
),
)
......@@ -126,12 +126,19 @@ def test_homo_graph(total_num_nodes, total_num_edges):
csc_indptr, indices = gbt.random_homo_graph(
total_num_nodes, total_num_edges
)
node_attributes = {
"A1": torch.arange(total_num_nodes),
"A2": torch.arange(total_num_nodes),
}
edge_attributes = {
"A1": torch.randn(total_num_edges),
"A2": torch.randn(total_num_edges),
}
graph = gb.fused_csc_sampling_graph(
csc_indptr, indices, edge_attributes=edge_attributes
csc_indptr,
indices,
node_attributes=node_attributes,
edge_attributes=edge_attributes,
)
assert graph.total_num_nodes == total_num_nodes
......@@ -140,6 +147,7 @@ def test_homo_graph(total_num_nodes, total_num_edges):
assert torch.equal(csc_indptr, graph.csc_indptr)
assert torch.equal(indices, graph.indices)
assert graph.node_attributes == node_attributes
assert graph.edge_attributes == edge_attributes
assert graph.node_type_offset is None
assert graph.type_per_edge is None
......@@ -167,6 +175,10 @@ def test_hetero_graph(total_num_nodes, total_num_edges, num_ntypes, num_etypes):
) = gbt.random_hetero_graph(
total_num_nodes, total_num_edges, num_ntypes, num_etypes
)
node_attributes = {
"A1": torch.arange(total_num_nodes),
"A2": torch.arange(total_num_nodes),
}
edge_attributes = {
"A1": torch.randn(total_num_edges),
"A2": torch.randn(total_num_edges),
......@@ -178,6 +190,7 @@ def test_hetero_graph(total_num_nodes, total_num_edges, num_ntypes, num_etypes):
type_per_edge=type_per_edge,
node_type_to_id=node_type_to_id,
edge_type_to_id=edge_type_to_id,
node_attributes=node_attributes,
edge_attributes=edge_attributes,
)
......@@ -188,6 +201,7 @@ def test_hetero_graph(total_num_nodes, total_num_edges, num_ntypes, num_etypes):
assert torch.equal(indices, graph.indices)
assert torch.equal(node_type_offset, graph.node_type_offset)
assert torch.equal(type_per_edge, graph.type_per_edge)
assert graph.node_attributes == node_attributes
assert graph.edge_attributes == edge_attributes
assert node_type_to_id == graph.node_type_to_id
assert edge_type_to_id == graph.edge_type_to_id
......@@ -327,11 +341,32 @@ def test_node_type_offset_wrong_legnth(node_type_offset):
"total_num_nodes, total_num_edges",
[(1, 1), (100, 1), (10, 50), (1000, 50000)],
)
def test_load_save_homo_graph(total_num_nodes, total_num_edges):
@pytest.mark.parametrize("has_node_attrs", [True, False])
@pytest.mark.parametrize("has_edge_attrs", [True, False])
def test_load_save_homo_graph(
total_num_nodes, total_num_edges, has_node_attrs, has_edge_attrs
):
csc_indptr, indices = gbt.random_homo_graph(
total_num_nodes, total_num_edges
)
graph = gb.fused_csc_sampling_graph(csc_indptr, indices)
node_attributes = None
if has_node_attrs:
node_attributes = {
"A": torch.arange(total_num_nodes),
"B": torch.arange(total_num_nodes),
}
edge_attributes = None
if has_edge_attrs:
edge_attributes = {
"A": torch.arange(total_num_edges),
"B": torch.arange(total_num_edges),
}
graph = gb.fused_csc_sampling_graph(
csc_indptr,
indices,
node_attributes=node_attributes,
edge_attributes=edge_attributes,
)
with tempfile.TemporaryDirectory() as test_dir:
filename = os.path.join(test_dir, "fused_csc_sampling_graph.pt")
......@@ -348,7 +383,22 @@ def test_load_save_homo_graph(total_num_nodes, total_num_edges):
assert graph.type_per_edge is None and graph2.type_per_edge is None
assert graph.node_type_to_id is None and graph2.node_type_to_id is None
assert graph.edge_type_to_id is None and graph2.edge_type_to_id is None
assert graph.edge_attributes is None and graph2.edge_attributes is None
if has_node_attrs:
assert graph.node_attributes.keys() == graph2.node_attributes.keys()
for key in graph.node_attributes.keys():
assert torch.equal(
graph.node_attributes[key], graph2.node_attributes[key]
)
else:
assert graph.node_attributes is None and graph2.node_attributes is None
if has_edge_attrs:
assert graph.edge_attributes.keys() == graph2.edge_attributes.keys()
for key in graph.edge_attributes.keys():
assert torch.equal(
graph.edge_attributes[key], graph2.edge_attributes[key]
)
else:
assert graph.edge_attributes is None and graph2.edge_attributes is None
@unittest.skipIf(
......@@ -360,8 +410,15 @@ def test_load_save_homo_graph(total_num_nodes, total_num_edges):
[(1, 1), (100, 1), (10, 50), (1000, 50000)],
)
@pytest.mark.parametrize("num_ntypes, num_etypes", [(1, 1), (3, 5), (100, 1)])
@pytest.mark.parametrize("has_node_attrs", [True, False])
@pytest.mark.parametrize("has_edge_attrs", [True, False])
def test_load_save_hetero_graph(
total_num_nodes, total_num_edges, num_ntypes, num_etypes
total_num_nodes,
total_num_edges,
num_ntypes,
num_etypes,
has_node_attrs,
has_edge_attrs,
):
(
csc_indptr,
......@@ -373,6 +430,18 @@ def test_load_save_hetero_graph(
) = gbt.random_hetero_graph(
total_num_nodes, total_num_edges, num_ntypes, num_etypes
)
node_attributes = None
if has_node_attrs:
node_attributes = {
"A": torch.arange(total_num_nodes),
"B": torch.arange(total_num_nodes),
}
edge_attributes = None
if has_edge_attrs:
edge_attributes = {
"A": torch.arange(total_num_edges),
"B": torch.arange(total_num_edges),
}
graph = gb.fused_csc_sampling_graph(
csc_indptr,
indices,
......@@ -380,6 +449,8 @@ def test_load_save_hetero_graph(
type_per_edge=type_per_edge,
node_type_to_id=node_type_to_id,
edge_type_to_id=edge_type_to_id,
node_attributes=node_attributes,
edge_attributes=edge_attributes,
)
with tempfile.TemporaryDirectory() as test_dir:
......@@ -396,6 +467,22 @@ def test_load_save_hetero_graph(
assert torch.equal(graph.type_per_edge, graph2.type_per_edge)
assert graph.node_type_to_id == graph2.node_type_to_id
assert graph.edge_type_to_id == graph2.edge_type_to_id
if has_node_attrs:
assert graph.node_attributes.keys() == graph2.node_attributes.keys()
for key in graph.node_attributes.keys():
assert torch.equal(
graph.node_attributes[key], graph2.node_attributes[key]
)
else:
assert graph.node_attributes is None and graph2.node_attributes is None
if has_edge_attrs:
assert graph.edge_attributes.keys() == graph2.edge_attributes.keys()
for key in graph.edge_attributes.keys():
assert torch.equal(
graph.edge_attributes[key], graph2.edge_attributes[key]
)
else:
assert graph.edge_attributes is None and graph2.edge_attributes is None
@unittest.skipIf(
......@@ -406,11 +493,32 @@ def test_load_save_hetero_graph(
"total_num_nodes, total_num_edges",
[(1, 1), (100, 1), (10, 50), (1000, 50000)],
)
def test_pickle_homo_graph(total_num_nodes, total_num_edges):
@pytest.mark.parametrize("has_node_attrs", [True, False])
@pytest.mark.parametrize("has_edge_attrs", [True, False])
def test_pickle_homo_graph(
total_num_nodes, total_num_edges, has_node_attrs, has_edge_attrs
):
csc_indptr, indices = gbt.random_homo_graph(
total_num_nodes, total_num_edges
)
graph = gb.fused_csc_sampling_graph(csc_indptr, indices)
node_attributes = None
if has_node_attrs:
node_attributes = {
"A": torch.arange(total_num_nodes),
"B": torch.arange(total_num_nodes),
}
edge_attributes = None
if has_edge_attrs:
edge_attributes = {
"A": torch.arange(total_num_edges),
"B": torch.arange(total_num_edges),
}
graph = gb.fused_csc_sampling_graph(
csc_indptr,
indices,
node_attributes=node_attributes,
edge_attributes=edge_attributes,
)
serialized = pickle.dumps(graph)
graph2 = pickle.loads(serialized)
......@@ -425,7 +533,22 @@ def test_pickle_homo_graph(total_num_nodes, total_num_edges):
assert graph.type_per_edge is None and graph2.type_per_edge is None
assert graph.node_type_to_id is None and graph2.node_type_to_id is None
assert graph.edge_type_to_id is None and graph2.edge_type_to_id is None
assert graph.edge_attributes is None and graph2.edge_attributes is None
if has_node_attrs:
assert graph.node_attributes.keys() == graph2.node_attributes.keys()
for key in graph.node_attributes.keys():
assert torch.equal(
graph.node_attributes[key], graph2.node_attributes[key]
)
else:
assert graph.node_attributes is None and graph2.node_attributes is None
if has_edge_attrs:
assert graph.edge_attributes.keys() == graph2.edge_attributes.keys()
for key in graph.edge_attributes.keys():
assert torch.equal(
graph.edge_attributes[key], graph2.edge_attributes[key]
)
else:
assert graph.edge_attributes is None and graph2.edge_attributes is None
@unittest.skipIf(
......@@ -437,8 +560,15 @@ def test_pickle_homo_graph(total_num_nodes, total_num_edges):
[(1, 1), (100, 1), (10, 50), (1000, 50000)],
)
@pytest.mark.parametrize("num_ntypes, num_etypes", [(1, 1), (3, 5), (100, 1)])
@pytest.mark.parametrize("has_node_attrs", [True, False])
@pytest.mark.parametrize("has_edge_attrs", [True, False])
def test_pickle_hetero_graph(
total_num_nodes, total_num_edges, num_ntypes, num_etypes
total_num_nodes,
total_num_edges,
num_ntypes,
num_etypes,
has_node_attrs,
has_edge_attrs,
):
(
csc_indptr,
......@@ -450,10 +580,18 @@ def test_pickle_hetero_graph(
) = gbt.random_hetero_graph(
total_num_nodes, total_num_edges, num_ntypes, num_etypes
)
edge_attributes = {
"a": torch.randn((total_num_edges,)),
"b": torch.randint(1, 10, (total_num_edges,)),
}
node_attributes = None
if has_node_attrs:
node_attributes = {
"A": torch.arange(total_num_nodes),
"B": torch.arange(total_num_nodes),
}
edge_attributes = None
if has_edge_attrs:
edge_attributes = {
"A": torch.arange(total_num_edges),
"B": torch.arange(total_num_edges),
}
graph = gb.fused_csc_sampling_graph(
csc_indptr,
indices,
......@@ -461,6 +599,7 @@ def test_pickle_hetero_graph(
type_per_edge=type_per_edge,
node_type_to_id=node_type_to_id,
edge_type_to_id=edge_type_to_id,
node_attributes=node_attributes,
edge_attributes=edge_attributes,
)
......@@ -480,9 +619,22 @@ def test_pickle_hetero_graph(
assert graph.edge_type_to_id.keys() == graph2.edge_type_to_id.keys()
for i in graph.edge_type_to_id.keys():
assert graph.edge_type_to_id[i] == graph2.edge_type_to_id[i]
assert graph.edge_attributes.keys() == graph2.edge_attributes.keys()
for i in graph.edge_attributes.keys():
assert torch.equal(graph.edge_attributes[i], graph2.edge_attributes[i])
if has_node_attrs:
assert graph.node_attributes.keys() == graph2.node_attributes.keys()
for key in graph.node_attributes.keys():
assert torch.equal(
graph.node_attributes[key], graph2.node_attributes[key]
)
else:
assert graph.node_attributes is None and graph2.node_attributes is None
if has_edge_attrs:
assert graph.edge_attributes.keys() == graph2.edge_attributes.keys()
for key in graph.edge_attributes.keys():
assert torch.equal(
graph.edge_attributes[key], graph2.edge_attributes[key]
)
else:
assert graph.edge_attributes is None and graph2.edge_attributes is None
def process_csc_sampling_graph_multiprocessing(graph):
......@@ -1258,6 +1410,18 @@ def check_tensors_on_the_same_shared_memory(t1: torch.Tensor, t2: torch.Tensor):
t1[:] = old_t1
def check_node_edge_attributes(graph1, graph2, attributes, attr_name):
for name, attr in attributes.items():
attributes_1 = getattr(graph1, attr_name)
attributes_2 = getattr(graph2, attr_name)
assert name in attributes_1
assert name in attributes_2
assert torch.equal(attributes_1[name], attr)
check_tensors_on_the_same_shared_memory(
attributes_1[name], attributes_2[name]
)
@unittest.skipIf(
F._default_context_str == "gpu",
reason="FusedCSCSamplingGraph is only supported on CPU.",
......@@ -1266,22 +1430,31 @@ def check_tensors_on_the_same_shared_memory(t1: torch.Tensor, t2: torch.Tensor):
"total_num_nodes, total_num_edges",
[(1, 1), (100, 1), (10, 50), (1000, 50000)],
)
@pytest.mark.parametrize("test_node_attrs", [True, False])
@pytest.mark.parametrize("test_edge_attrs", [True, False])
def test_homo_graph_on_shared_memory(
total_num_nodes, total_num_edges, test_edge_attrs
total_num_nodes, total_num_edges, test_node_attrs, test_edge_attrs
):
csc_indptr, indices = gbt.random_homo_graph(
total_num_nodes, total_num_edges
)
node_attributes = None
if test_node_attrs:
node_attributes = {
"A1": torch.arange(total_num_nodes),
"A2": torch.arange(total_num_nodes),
}
edge_attributes = None
if test_edge_attrs:
edge_attributes = {
"A1": torch.randn(total_num_edges),
"A2": torch.randn(total_num_edges),
}
else:
edge_attributes = None
graph = gb.fused_csc_sampling_graph(
csc_indptr, indices, edge_attributes=edge_attributes
csc_indptr,
indices,
node_attributes=node_attributes,
edge_attributes=edge_attributes,
)
shm_name = "test_homo_g"
......@@ -1307,14 +1480,14 @@ def test_homo_graph_on_shared_memory(
)
check_tensors_on_the_same_shared_memory(graph1.indices, graph2.indices)
if test_node_attrs:
check_node_edge_attributes(
graph1, graph2, node_attributes, "node_attributes"
)
if test_edge_attrs:
for name, edge_attr in edge_attributes.items():
assert name in graph1.edge_attributes
assert name in graph2.edge_attributes
assert torch.equal(graph1.edge_attributes[name], edge_attr)
check_tensors_on_the_same_shared_memory(
graph1.edge_attributes[name], graph2.edge_attributes[name]
)
check_node_edge_attributes(
graph1, graph2, edge_attributes, "edge_attributes"
)
assert graph1.node_type_offset is None and graph2.node_type_offset is None
assert graph1.type_per_edge is None and graph2.type_per_edge is None
......@@ -1333,9 +1506,15 @@ def test_homo_graph_on_shared_memory(
@pytest.mark.parametrize(
"num_ntypes, num_etypes", [(1, 1), (3, 5), (100, 1), (1000, 1000)]
)
@pytest.mark.parametrize("test_node_attrs", [True, False])
@pytest.mark.parametrize("test_edge_attrs", [True, False])
def test_hetero_graph_on_shared_memory(
total_num_nodes, total_num_edges, num_ntypes, num_etypes, test_edge_attrs
total_num_nodes,
total_num_edges,
num_ntypes,
num_etypes,
test_node_attrs,
test_edge_attrs,
):
(
csc_indptr,
......@@ -1348,13 +1527,20 @@ def test_hetero_graph_on_shared_memory(
total_num_nodes, total_num_edges, num_ntypes, num_etypes
)
node_attributes = None
if test_node_attrs:
node_attributes = {
"A1": torch.arange(total_num_nodes),
"A2": torch.arange(total_num_nodes),
}
edge_attributes = None
if test_edge_attrs:
edge_attributes = {
"A1": torch.randn(total_num_edges),
"A2": torch.randn(total_num_edges),
}
else:
edge_attributes = None
graph = gb.fused_csc_sampling_graph(
csc_indptr,
indices,
......@@ -1362,6 +1548,7 @@ def test_hetero_graph_on_shared_memory(
type_per_edge=type_per_edge,
node_type_to_id=node_type_to_id,
edge_type_to_id=edge_type_to_id,
node_attributes=node_attributes,
edge_attributes=edge_attributes,
)
......@@ -1398,14 +1585,14 @@ def test_hetero_graph_on_shared_memory(
graph1.type_per_edge, graph2.type_per_edge
)
if test_node_attrs:
check_node_edge_attributes(
graph1, graph2, node_attributes, "node_attributes"
)
if test_edge_attrs:
for name, edge_attr in edge_attributes.items():
assert name in graph1.edge_attributes
assert name in graph2.edge_attributes
assert torch.equal(graph1.edge_attributes[name], edge_attr)
check_tensors_on_the_same_shared_memory(
graph1.edge_attributes[name], graph2.edge_attributes[name]
)
check_node_edge_attributes(
graph1, graph2, edge_attributes, "edge_attributes"
)
assert node_type_to_id == graph1.node_type_to_id
assert edge_type_to_id == graph1.edge_type_to_id
......