Unverified Commit e181ef15 authored by Rhett Ying's avatar Rhett Ying Committed by GitHub
Browse files

[GraphBolt] add node_attributes into FusedCSCSamplingGraph (#6757)

parent cad7caeb
...@@ -50,6 +50,7 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder { ...@@ -50,6 +50,7 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder {
public: public:
using NodeTypeToIDMap = torch::Dict<std::string, int64_t>; using NodeTypeToIDMap = torch::Dict<std::string, int64_t>;
using EdgeTypeToIDMap = torch::Dict<std::string, int64_t>; using EdgeTypeToIDMap = torch::Dict<std::string, int64_t>;
using NodeAttrMap = torch::Dict<std::string, torch::Tensor>;
using EdgeAttrMap = torch::Dict<std::string, torch::Tensor>; using EdgeAttrMap = torch::Dict<std::string, torch::Tensor>;
/** @brief Default constructor. */ /** @brief Default constructor. */
FusedCSCSamplingGraph() = default; FusedCSCSamplingGraph() = default;
...@@ -66,16 +67,18 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder { ...@@ -66,16 +67,18 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder {
* present. * present.
* @param edge_type_to_id A dictionary mapping edge type names to type IDs, if * @param edge_type_to_id A dictionary mapping edge type names to type IDs, if
* present. * present.
* @param node_attributes A dictionary of node attributes, if present.
* @param edge_attributes A dictionary of edge attributes, if present. * @param edge_attributes A dictionary of edge attributes, if present.
* *
*/ */
FusedCSCSamplingGraph( FusedCSCSamplingGraph(
const torch::Tensor& indptr, const torch::Tensor& indices, const torch::Tensor& indptr, const torch::Tensor& indices,
const torch::optional<torch::Tensor>& node_type_offset, const torch::optional<torch::Tensor>& node_type_offset = torch::nullopt,
const torch::optional<torch::Tensor>& type_per_edge, const torch::optional<torch::Tensor>& type_per_edge = torch::nullopt,
const torch::optional<NodeTypeToIDMap>& node_type_to_id, const torch::optional<NodeTypeToIDMap>& node_type_to_id = torch::nullopt,
const torch::optional<EdgeTypeToIDMap>& edge_type_to_id, const torch::optional<EdgeTypeToIDMap>& edge_type_to_id = torch::nullopt,
const torch::optional<EdgeAttrMap>& edge_attributes); const torch::optional<NodeAttrMap>& node_attributes = torch::nullopt,
const torch::optional<EdgeAttrMap>& edge_attributes = torch::nullopt);
/** /**
* @brief Create a fused CSC graph from tensors of CSC format. * @brief Create a fused CSC graph from tensors of CSC format.
...@@ -89,6 +92,7 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder { ...@@ -89,6 +92,7 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder {
* present. * present.
* @param edge_type_to_id A dictionary mapping edge type names to type IDs, if * @param edge_type_to_id A dictionary mapping edge type names to type IDs, if
* present. * present.
* @param node_attributes A dictionary of node attributes, if present.
* @param edge_attributes A dictionary of edge attributes, if present. * @param edge_attributes A dictionary of edge attributes, if present.
* *
* @return FusedCSCSamplingGraph * @return FusedCSCSamplingGraph
...@@ -99,6 +103,7 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder { ...@@ -99,6 +103,7 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder {
const torch::optional<torch::Tensor>& type_per_edge, const torch::optional<torch::Tensor>& type_per_edge,
const torch::optional<NodeTypeToIDMap>& node_type_to_id, const torch::optional<NodeTypeToIDMap>& node_type_to_id,
const torch::optional<EdgeTypeToIDMap>& edge_type_to_id, const torch::optional<EdgeTypeToIDMap>& edge_type_to_id,
const torch::optional<NodeAttrMap>& node_attributes,
const torch::optional<EdgeAttrMap>& edge_attributes); const torch::optional<EdgeAttrMap>& edge_attributes);
/** @brief Get the number of nodes. */ /** @brief Get the number of nodes. */
...@@ -139,6 +144,11 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder { ...@@ -139,6 +144,11 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder {
return edge_type_to_id_; return edge_type_to_id_;
} }
/** @brief Get the node attributes dictionary. */
inline const torch::optional<EdgeAttrMap> NodeAttributes() const {
return node_attributes_;
}
/** @brief Get the edge attributes dictionary. */ /** @brief Get the edge attributes dictionary. */
inline const torch::optional<EdgeAttrMap> EdgeAttributes() const { inline const torch::optional<EdgeAttrMap> EdgeAttributes() const {
return edge_attributes_; return edge_attributes_;
...@@ -180,6 +190,12 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder { ...@@ -180,6 +190,12 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder {
edge_type_to_id_ = edge_type_to_id; edge_type_to_id_ = edge_type_to_id;
} }
/** @brief Set the node attributes dictionary. */
inline void SetNodeAttributes(
const torch::optional<EdgeAttrMap>& node_attributes) {
node_attributes_ = node_attributes;
}
/** @brief Set the edge attributes dictionary. */ /** @brief Set the edge attributes dictionary. */
inline void SetEdgeAttributes( inline void SetEdgeAttributes(
const torch::optional<EdgeAttrMap>& edge_attributes) { const torch::optional<EdgeAttrMap>& edge_attributes) {
...@@ -367,6 +383,13 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder { ...@@ -367,6 +383,13 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder {
*/ */
torch::optional<EdgeTypeToIDMap> edge_type_to_id_; torch::optional<EdgeTypeToIDMap> edge_type_to_id_;
/**
* @brief A dictionary of node attributes. Each key represents the attribute's
* name, while the corresponding value holds the attribute's specific value.
* The length of each value should match the total number of nodes."
*/
torch::optional<NodeAttrMap> node_attributes_;
/** /**
* @brief A dictionary of edge attributes. Each key represents the attribute's * @brief A dictionary of edge attributes. Each key represents the attribute's
* name, while the corresponding value holds the attribute's specific value. * name, while the corresponding value holds the attribute's specific value.
......
...@@ -56,6 +56,7 @@ FusedCSCSamplingGraph::FusedCSCSamplingGraph( ...@@ -56,6 +56,7 @@ FusedCSCSamplingGraph::FusedCSCSamplingGraph(
const torch::optional<torch::Tensor>& type_per_edge, const torch::optional<torch::Tensor>& type_per_edge,
const torch::optional<NodeTypeToIDMap>& node_type_to_id, const torch::optional<NodeTypeToIDMap>& node_type_to_id,
const torch::optional<EdgeTypeToIDMap>& edge_type_to_id, const torch::optional<EdgeTypeToIDMap>& edge_type_to_id,
const torch::optional<NodeAttrMap>& node_attributes,
const torch::optional<EdgeAttrMap>& edge_attributes) const torch::optional<EdgeAttrMap>& edge_attributes)
: indptr_(indptr), : indptr_(indptr),
indices_(indices), indices_(indices),
...@@ -63,6 +64,7 @@ FusedCSCSamplingGraph::FusedCSCSamplingGraph( ...@@ -63,6 +64,7 @@ FusedCSCSamplingGraph::FusedCSCSamplingGraph(
type_per_edge_(type_per_edge), type_per_edge_(type_per_edge),
node_type_to_id_(node_type_to_id), node_type_to_id_(node_type_to_id),
edge_type_to_id_(edge_type_to_id), edge_type_to_id_(edge_type_to_id),
node_attributes_(node_attributes),
edge_attributes_(edge_attributes) { edge_attributes_(edge_attributes) {
TORCH_CHECK(indptr.dim() == 1); TORCH_CHECK(indptr.dim() == 1);
TORCH_CHECK(indices.dim() == 1); TORCH_CHECK(indices.dim() == 1);
...@@ -75,6 +77,7 @@ c10::intrusive_ptr<FusedCSCSamplingGraph> FusedCSCSamplingGraph::Create( ...@@ -75,6 +77,7 @@ c10::intrusive_ptr<FusedCSCSamplingGraph> FusedCSCSamplingGraph::Create(
const torch::optional<torch::Tensor>& type_per_edge, const torch::optional<torch::Tensor>& type_per_edge,
const torch::optional<NodeTypeToIDMap>& node_type_to_id, const torch::optional<NodeTypeToIDMap>& node_type_to_id,
const torch::optional<EdgeTypeToIDMap>& edge_type_to_id, const torch::optional<EdgeTypeToIDMap>& edge_type_to_id,
const torch::optional<NodeAttrMap>& node_attributes,
const torch::optional<EdgeAttrMap>& edge_attributes) { const torch::optional<EdgeAttrMap>& edge_attributes) {
if (node_type_offset.has_value()) { if (node_type_offset.has_value()) {
auto& offset = node_type_offset.value(); auto& offset = node_type_offset.value();
...@@ -89,6 +92,11 @@ c10::intrusive_ptr<FusedCSCSamplingGraph> FusedCSCSamplingGraph::Create( ...@@ -89,6 +92,11 @@ c10::intrusive_ptr<FusedCSCSamplingGraph> FusedCSCSamplingGraph::Create(
TORCH_CHECK(type_per_edge.value().size(0) == indices.size(0)); TORCH_CHECK(type_per_edge.value().size(0) == indices.size(0));
TORCH_CHECK(edge_type_to_id.has_value()); TORCH_CHECK(edge_type_to_id.has_value());
} }
if (node_attributes.has_value()) {
for (const auto& pair : node_attributes.value()) {
TORCH_CHECK(pair.value().size(0) == indptr.size(0) - 1);
}
}
if (edge_attributes.has_value()) { if (edge_attributes.has_value()) {
for (const auto& pair : edge_attributes.value()) { for (const auto& pair : edge_attributes.value()) {
TORCH_CHECK(pair.value().size(0) == indices.size(0)); TORCH_CHECK(pair.value().size(0) == indices.size(0));
...@@ -96,7 +104,7 @@ c10::intrusive_ptr<FusedCSCSamplingGraph> FusedCSCSamplingGraph::Create( ...@@ -96,7 +104,7 @@ c10::intrusive_ptr<FusedCSCSamplingGraph> FusedCSCSamplingGraph::Create(
} }
return c10::make_intrusive<FusedCSCSamplingGraph>( return c10::make_intrusive<FusedCSCSamplingGraph>(
indptr, indices, node_type_offset, type_per_edge, node_type_to_id, indptr, indices, node_type_offset, type_per_edge, node_type_to_id,
edge_type_to_id, edge_attributes); edge_type_to_id, node_attributes, edge_attributes);
} }
void FusedCSCSamplingGraph::Load(torch::serialize::InputArchive& archive) { void FusedCSCSamplingGraph::Load(torch::serialize::InputArchive& archive) {
...@@ -150,6 +158,25 @@ void FusedCSCSamplingGraph::Load(torch::serialize::InputArchive& archive) { ...@@ -150,6 +158,25 @@ void FusedCSCSamplingGraph::Load(torch::serialize::InputArchive& archive) {
edge_type_to_id_ = std::move(edge_type_to_id); edge_type_to_id_ = std::move(edge_type_to_id);
} }
// Optional node attributes.
torch::IValue has_node_attributes;
if (archive.try_read(
"FusedCSCSamplingGraph/has_node_attributes", has_node_attributes) &&
has_node_attributes.toBool()) {
torch::Dict<torch::IValue, torch::IValue> generic_dict =
read_from_archive(archive, "FusedCSCSamplingGraph/node_attributes")
.toGenericDict();
NodeAttrMap target_dict;
for (const auto& pair : generic_dict) {
std::string key = pair.key().toStringRef();
torch::Tensor value = pair.value().toTensor();
// Use move to avoid copy.
target_dict.insert(std::move(key), std::move(value));
}
// Same as above.
node_attributes_ = std::move(target_dict);
}
// Optional edge attributes. // Optional edge attributes.
torch::IValue has_edge_attributes; torch::IValue has_edge_attributes;
if (archive.try_read( if (archive.try_read(
...@@ -203,6 +230,13 @@ void FusedCSCSamplingGraph::Save( ...@@ -203,6 +230,13 @@ void FusedCSCSamplingGraph::Save(
archive.write( archive.write(
"FusedCSCSamplingGraph/edge_type_to_id", edge_type_to_id_.value()); "FusedCSCSamplingGraph/edge_type_to_id", edge_type_to_id_.value());
} }
archive.write(
"FusedCSCSamplingGraph/has_node_attributes",
node_attributes_.has_value());
if (node_attributes_) {
archive.write(
"FusedCSCSamplingGraph/node_attributes", node_attributes_.value());
}
archive.write( archive.write(
"FusedCSCSamplingGraph/has_edge_attributes", "FusedCSCSamplingGraph/has_edge_attributes",
edge_attributes_.has_value()); edge_attributes_.has_value());
...@@ -238,6 +272,9 @@ void FusedCSCSamplingGraph::SetState( ...@@ -238,6 +272,9 @@ void FusedCSCSamplingGraph::SetState(
if (state.find("edge_type_to_id") != state.end()) { if (state.find("edge_type_to_id") != state.end()) {
edge_type_to_id_ = DetensorizeDict(state.at("edge_type_to_id")); edge_type_to_id_ = DetensorizeDict(state.at("edge_type_to_id"));
} }
if (state.find("node_attributes") != state.end()) {
node_attributes_ = state.at("node_attributes");
}
if (state.find("edge_attributes") != state.end()) { if (state.find("edge_attributes") != state.end()) {
edge_attributes_ = state.at("edge_attributes"); edge_attributes_ = state.at("edge_attributes");
} }
...@@ -268,6 +305,9 @@ FusedCSCSamplingGraph::GetState() const { ...@@ -268,6 +305,9 @@ FusedCSCSamplingGraph::GetState() const {
if (edge_type_to_id_.has_value()) { if (edge_type_to_id_.has_value()) {
state.insert("edge_type_to_id", TensorizeDict(edge_type_to_id_).value()); state.insert("edge_type_to_id", TensorizeDict(edge_type_to_id_).value());
} }
if (node_attributes_.has_value()) {
state.insert("node_attributes", node_attributes_.value());
}
if (edge_attributes_.has_value()) { if (edge_attributes_.has_value()) {
state.insert("edge_attributes", edge_attributes_.value()); state.insert("edge_attributes", edge_attributes_.value());
} }
...@@ -596,10 +636,11 @@ BuildGraphFromSharedMemoryHelper(SharedMemoryHelper&& helper) { ...@@ -596,10 +636,11 @@ BuildGraphFromSharedMemoryHelper(SharedMemoryHelper&& helper) {
auto type_per_edge = helper.ReadTorchTensor(); auto type_per_edge = helper.ReadTorchTensor();
auto node_type_to_id = DetensorizeDict(helper.ReadTorchTensorDict()); auto node_type_to_id = DetensorizeDict(helper.ReadTorchTensorDict());
auto edge_type_to_id = DetensorizeDict(helper.ReadTorchTensorDict()); auto edge_type_to_id = DetensorizeDict(helper.ReadTorchTensorDict());
auto node_attributes = helper.ReadTorchTensorDict();
auto edge_attributes = helper.ReadTorchTensorDict(); auto edge_attributes = helper.ReadTorchTensorDict();
auto graph = c10::make_intrusive<FusedCSCSamplingGraph>( auto graph = c10::make_intrusive<FusedCSCSamplingGraph>(
indptr.value(), indices.value(), node_type_offset, type_per_edge, indptr.value(), indices.value(), node_type_offset, type_per_edge,
node_type_to_id, edge_type_to_id, edge_attributes); node_type_to_id, edge_type_to_id, node_attributes, edge_attributes);
auto shared_memory = helper.ReleaseSharedMemory(); auto shared_memory = helper.ReleaseSharedMemory();
graph->HoldSharedMemoryObject( graph->HoldSharedMemoryObject(
std::move(shared_memory.first), std::move(shared_memory.second)); std::move(shared_memory.first), std::move(shared_memory.second));
...@@ -616,6 +657,7 @@ FusedCSCSamplingGraph::CopyToSharedMemory( ...@@ -616,6 +657,7 @@ FusedCSCSamplingGraph::CopyToSharedMemory(
helper.WriteTorchTensor(type_per_edge_); helper.WriteTorchTensor(type_per_edge_);
helper.WriteTorchTensorDict(TensorizeDict(node_type_to_id_)); helper.WriteTorchTensorDict(TensorizeDict(node_type_to_id_));
helper.WriteTorchTensorDict(TensorizeDict(edge_type_to_id_)); helper.WriteTorchTensorDict(TensorizeDict(edge_type_to_id_));
helper.WriteTorchTensorDict(node_attributes_);
helper.WriteTorchTensorDict(edge_attributes_); helper.WriteTorchTensorDict(edge_attributes_);
helper.Flush(); helper.Flush();
return BuildGraphFromSharedMemoryHelper(std::move(helper)); return BuildGraphFromSharedMemoryHelper(std::move(helper));
......
...@@ -43,12 +43,7 @@ std::tuple<torch::Tensor, torch::Tensor> IndexSelectCSC( ...@@ -43,12 +43,7 @@ std::tuple<torch::Tensor, torch::Tensor> IndexSelectCSC(
TORCH_CHECK( TORCH_CHECK(
c10::isIntegralType(indices.scalar_type(), false), c10::isIntegralType(indices.scalar_type(), false),
"IndexSelectCSC is not implemented to slice noninteger types yet."); "IndexSelectCSC is not implemented to slice noninteger types yet.");
torch::optional<torch::Tensor> temp; sampling::FusedCSCSamplingGraph g(indptr, indices);
torch::optional<sampling::FusedCSCSamplingGraph::NodeTypeToIDMap> temp2;
torch::optional<sampling::FusedCSCSamplingGraph::EdgeTypeToIDMap> temp3;
torch::optional<sampling::FusedCSCSamplingGraph::EdgeAttrMap> temp4;
sampling::FusedCSCSamplingGraph g(
indptr, indices, temp, temp, temp2, temp3, temp4);
const auto res = g.InSubgraph(nodes); const auto res = g.InSubgraph(nodes);
return std::make_tuple(res->indptr, res->indices); return std::make_tuple(res->indptr, res->indices);
} }
......
...@@ -37,6 +37,7 @@ TORCH_LIBRARY(graphbolt, m) { ...@@ -37,6 +37,7 @@ TORCH_LIBRARY(graphbolt, m) {
.def("type_per_edge", &FusedCSCSamplingGraph::TypePerEdge) .def("type_per_edge", &FusedCSCSamplingGraph::TypePerEdge)
.def("node_type_to_id", &FusedCSCSamplingGraph::NodeTypeToID) .def("node_type_to_id", &FusedCSCSamplingGraph::NodeTypeToID)
.def("edge_type_to_id", &FusedCSCSamplingGraph::EdgeTypeToID) .def("edge_type_to_id", &FusedCSCSamplingGraph::EdgeTypeToID)
.def("node_attributes", &FusedCSCSamplingGraph::NodeAttributes)
.def("edge_attributes", &FusedCSCSamplingGraph::EdgeAttributes) .def("edge_attributes", &FusedCSCSamplingGraph::EdgeAttributes)
.def("set_csc_indptr", &FusedCSCSamplingGraph::SetCSCIndptr) .def("set_csc_indptr", &FusedCSCSamplingGraph::SetCSCIndptr)
.def("set_indices", &FusedCSCSamplingGraph::SetIndices) .def("set_indices", &FusedCSCSamplingGraph::SetIndices)
...@@ -44,6 +45,7 @@ TORCH_LIBRARY(graphbolt, m) { ...@@ -44,6 +45,7 @@ TORCH_LIBRARY(graphbolt, m) {
.def("set_type_per_edge", &FusedCSCSamplingGraph::SetTypePerEdge) .def("set_type_per_edge", &FusedCSCSamplingGraph::SetTypePerEdge)
.def("set_node_type_to_id", &FusedCSCSamplingGraph::SetNodeTypeToID) .def("set_node_type_to_id", &FusedCSCSamplingGraph::SetNodeTypeToID)
.def("set_edge_type_to_id", &FusedCSCSamplingGraph::SetEdgeTypeToID) .def("set_edge_type_to_id", &FusedCSCSamplingGraph::SetEdgeTypeToID)
.def("set_node_attributes", &FusedCSCSamplingGraph::SetNodeAttributes)
.def("set_edge_attributes", &FusedCSCSamplingGraph::SetEdgeAttributes) .def("set_edge_attributes", &FusedCSCSamplingGraph::SetEdgeAttributes)
.def("in_subgraph", &FusedCSCSamplingGraph::InSubgraph) .def("in_subgraph", &FusedCSCSamplingGraph::InSubgraph)
.def("sample_neighbors", &FusedCSCSamplingGraph::SampleNeighbors) .def("sample_neighbors", &FusedCSCSamplingGraph::SampleNeighbors)
......
...@@ -279,6 +279,27 @@ class FusedCSCSamplingGraph(SamplingGraph): ...@@ -279,6 +279,27 @@ class FusedCSCSamplingGraph(SamplingGraph):
"""Sets the edge type to id dictionary if present.""" """Sets the edge type to id dictionary if present."""
self._c_csc_graph.set_edge_type_to_id(edge_type_to_id) self._c_csc_graph.set_edge_type_to_id(edge_type_to_id)
@property
def node_attributes(self) -> Optional[Dict[str, torch.Tensor]]:
"""Returns the node attributes dictionary.
Returns
-------
Dict[str, torch.Tensor] or None
If present, returns a dictionary of node attributes. Each key
represents the attribute's name, while the corresponding value
holds the attribute's specific value. The length of each value
should match the total number of nodes."
"""
return self._c_csc_graph.node_attributes()
@node_attributes.setter
def node_attributes(
self, node_attributes: Optional[Dict[str, torch.Tensor]]
) -> None:
"""Sets the node attributes dictionary."""
self._c_csc_graph.set_node_attributes(node_attributes)
@property @property
def edge_attributes(self) -> Optional[Dict[str, torch.Tensor]]: def edge_attributes(self) -> Optional[Dict[str, torch.Tensor]]:
"""Returns the edge attributes dictionary. """Returns the edge attributes dictionary.
...@@ -892,6 +913,9 @@ class FusedCSCSamplingGraph(SamplingGraph): ...@@ -892,6 +913,9 @@ class FusedCSCSamplingGraph(SamplingGraph):
self.type_per_edge = recursive_apply( self.type_per_edge = recursive_apply(
self.type_per_edge, lambda x: _to(x, device) self.type_per_edge, lambda x: _to(x, device)
) )
self.node_attributes = recursive_apply(
self.node_attributes, lambda x: _to(x, device)
)
self.edge_attributes = recursive_apply( self.edge_attributes = recursive_apply(
self.edge_attributes, lambda x: _to(x, device) self.edge_attributes, lambda x: _to(x, device)
) )
...@@ -906,6 +930,7 @@ def fused_csc_sampling_graph( ...@@ -906,6 +930,7 @@ def fused_csc_sampling_graph(
type_per_edge: Optional[torch.tensor] = None, type_per_edge: Optional[torch.tensor] = None,
node_type_to_id: Optional[Dict[str, int]] = None, node_type_to_id: Optional[Dict[str, int]] = None,
edge_type_to_id: Optional[Dict[str, int]] = None, edge_type_to_id: Optional[Dict[str, int]] = None,
node_attributes: Optional[Dict[str, torch.tensor]] = None,
edge_attributes: Optional[Dict[str, torch.tensor]] = None, edge_attributes: Optional[Dict[str, torch.tensor]] = None,
) -> FusedCSCSamplingGraph: ) -> FusedCSCSamplingGraph:
"""Create a FusedCSCSamplingGraph object from a CSC representation. """Create a FusedCSCSamplingGraph object from a CSC representation.
...@@ -926,6 +951,8 @@ def fused_csc_sampling_graph( ...@@ -926,6 +951,8 @@ def fused_csc_sampling_graph(
Map node types to ids, by default None. Map node types to ids, by default None.
edge_type_to_id : Optional[Dict[str, int]], optional edge_type_to_id : Optional[Dict[str, int]], optional
Map edge types to ids, by default None. Map edge types to ids, by default None.
node_attributes: Optional[Dict[str, torch.tensor]], optional
Node attributes of the graph, by default None.
edge_attributes: Optional[Dict[str, torch.tensor]], optional edge_attributes: Optional[Dict[str, torch.tensor]], optional
Edge attributes of the graph, by default None. Edge attributes of the graph, by default None.
...@@ -946,7 +973,7 @@ def fused_csc_sampling_graph( ...@@ -946,7 +973,7 @@ def fused_csc_sampling_graph(
... node_type_offset=node_type_offset, ... node_type_offset=node_type_offset,
... type_per_edge=type_per_edge, ... type_per_edge=type_per_edge,
... node_type_to_id=ntypes, edge_type_to_id=etypes, ... node_type_to_id=ntypes, edge_type_to_id=etypes,
... edge_attributes=None,) ... node_attributes=None, edge_attributes=None,)
>>> print(graph) >>> print(graph)
FusedCSCSamplingGraph(csc_indptr=tensor([0, 2, 5, 7]), FusedCSCSamplingGraph(csc_indptr=tensor([0, 2, 5, 7]),
indices=tensor([1, 3, 0, 1, 2, 0, 3]), indices=tensor([1, 3, 0, 1, 2, 0, 3]),
...@@ -997,6 +1024,7 @@ def fused_csc_sampling_graph( ...@@ -997,6 +1024,7 @@ def fused_csc_sampling_graph(
type_per_edge, type_per_edge,
node_type_to_id, node_type_to_id,
edge_type_to_id, edge_type_to_id,
node_attributes,
edge_attributes, edge_attributes,
), ),
) )
...@@ -1037,6 +1065,8 @@ def _csc_sampling_graph_str(graph: FusedCSCSamplingGraph) -> str: ...@@ -1037,6 +1065,8 @@ def _csc_sampling_graph_str(graph: FusedCSCSamplingGraph) -> str:
meta_str += f", node_type_to_id={graph.node_type_to_id}" meta_str += f", node_type_to_id={graph.node_type_to_id}"
if graph.edge_type_to_id is not None: if graph.edge_type_to_id is not None:
meta_str += f", edge_type_to_id={graph.edge_type_to_id}" meta_str += f", edge_type_to_id={graph.edge_type_to_id}"
if graph.node_attributes is not None:
meta_str += f", node_attributes={graph.node_attributes}"
if graph.edge_attributes is not None: if graph.edge_attributes is not None:
meta_str += f", edge_attributes={graph.edge_attributes}" meta_str += f", edge_attributes={graph.edge_attributes}"
...@@ -1094,6 +1124,8 @@ def from_dglgraph( ...@@ -1094,6 +1124,8 @@ def from_dglgraph(
# Assign edge type according to the order of CSC matrix. # Assign edge type according to the order of CSC matrix.
type_per_edge = None if is_homogeneous else homo_g.edata[ETYPE][edge_ids] type_per_edge = None if is_homogeneous else homo_g.edata[ETYPE][edge_ids]
node_attributes = {}
edge_attributes = {} edge_attributes = {}
if include_original_edge_id: if include_original_edge_id:
# Assign edge attributes according to the original eids mapping. # Assign edge attributes according to the original eids mapping.
...@@ -1107,6 +1139,7 @@ def from_dglgraph( ...@@ -1107,6 +1139,7 @@ def from_dglgraph(
type_per_edge, type_per_edge,
node_type_to_id, node_type_to_id,
edge_type_to_id, edge_type_to_id,
node_attributes,
edge_attributes, edge_attributes,
), ),
) )
...@@ -126,12 +126,19 @@ def test_homo_graph(total_num_nodes, total_num_edges): ...@@ -126,12 +126,19 @@ def test_homo_graph(total_num_nodes, total_num_edges):
csc_indptr, indices = gbt.random_homo_graph( csc_indptr, indices = gbt.random_homo_graph(
total_num_nodes, total_num_edges total_num_nodes, total_num_edges
) )
node_attributes = {
"A1": torch.arange(total_num_nodes),
"A2": torch.arange(total_num_nodes),
}
edge_attributes = { edge_attributes = {
"A1": torch.randn(total_num_edges), "A1": torch.randn(total_num_edges),
"A2": torch.randn(total_num_edges), "A2": torch.randn(total_num_edges),
} }
graph = gb.fused_csc_sampling_graph( graph = gb.fused_csc_sampling_graph(
csc_indptr, indices, edge_attributes=edge_attributes csc_indptr,
indices,
node_attributes=node_attributes,
edge_attributes=edge_attributes,
) )
assert graph.total_num_nodes == total_num_nodes assert graph.total_num_nodes == total_num_nodes
...@@ -140,6 +147,7 @@ def test_homo_graph(total_num_nodes, total_num_edges): ...@@ -140,6 +147,7 @@ def test_homo_graph(total_num_nodes, total_num_edges):
assert torch.equal(csc_indptr, graph.csc_indptr) assert torch.equal(csc_indptr, graph.csc_indptr)
assert torch.equal(indices, graph.indices) assert torch.equal(indices, graph.indices)
assert graph.node_attributes == node_attributes
assert graph.edge_attributes == edge_attributes assert graph.edge_attributes == edge_attributes
assert graph.node_type_offset is None assert graph.node_type_offset is None
assert graph.type_per_edge is None assert graph.type_per_edge is None
...@@ -167,6 +175,10 @@ def test_hetero_graph(total_num_nodes, total_num_edges, num_ntypes, num_etypes): ...@@ -167,6 +175,10 @@ def test_hetero_graph(total_num_nodes, total_num_edges, num_ntypes, num_etypes):
) = gbt.random_hetero_graph( ) = gbt.random_hetero_graph(
total_num_nodes, total_num_edges, num_ntypes, num_etypes total_num_nodes, total_num_edges, num_ntypes, num_etypes
) )
node_attributes = {
"A1": torch.arange(total_num_nodes),
"A2": torch.arange(total_num_nodes),
}
edge_attributes = { edge_attributes = {
"A1": torch.randn(total_num_edges), "A1": torch.randn(total_num_edges),
"A2": torch.randn(total_num_edges), "A2": torch.randn(total_num_edges),
...@@ -178,6 +190,7 @@ def test_hetero_graph(total_num_nodes, total_num_edges, num_ntypes, num_etypes): ...@@ -178,6 +190,7 @@ def test_hetero_graph(total_num_nodes, total_num_edges, num_ntypes, num_etypes):
type_per_edge=type_per_edge, type_per_edge=type_per_edge,
node_type_to_id=node_type_to_id, node_type_to_id=node_type_to_id,
edge_type_to_id=edge_type_to_id, edge_type_to_id=edge_type_to_id,
node_attributes=node_attributes,
edge_attributes=edge_attributes, edge_attributes=edge_attributes,
) )
...@@ -188,6 +201,7 @@ def test_hetero_graph(total_num_nodes, total_num_edges, num_ntypes, num_etypes): ...@@ -188,6 +201,7 @@ def test_hetero_graph(total_num_nodes, total_num_edges, num_ntypes, num_etypes):
assert torch.equal(indices, graph.indices) assert torch.equal(indices, graph.indices)
assert torch.equal(node_type_offset, graph.node_type_offset) assert torch.equal(node_type_offset, graph.node_type_offset)
assert torch.equal(type_per_edge, graph.type_per_edge) assert torch.equal(type_per_edge, graph.type_per_edge)
assert graph.node_attributes == node_attributes
assert graph.edge_attributes == edge_attributes assert graph.edge_attributes == edge_attributes
assert node_type_to_id == graph.node_type_to_id assert node_type_to_id == graph.node_type_to_id
assert edge_type_to_id == graph.edge_type_to_id assert edge_type_to_id == graph.edge_type_to_id
...@@ -327,11 +341,32 @@ def test_node_type_offset_wrong_legnth(node_type_offset): ...@@ -327,11 +341,32 @@ def test_node_type_offset_wrong_legnth(node_type_offset):
"total_num_nodes, total_num_edges", "total_num_nodes, total_num_edges",
[(1, 1), (100, 1), (10, 50), (1000, 50000)], [(1, 1), (100, 1), (10, 50), (1000, 50000)],
) )
def test_load_save_homo_graph(total_num_nodes, total_num_edges): @pytest.mark.parametrize("has_node_attrs", [True, False])
@pytest.mark.parametrize("has_edge_attrs", [True, False])
def test_load_save_homo_graph(
total_num_nodes, total_num_edges, has_node_attrs, has_edge_attrs
):
csc_indptr, indices = gbt.random_homo_graph( csc_indptr, indices = gbt.random_homo_graph(
total_num_nodes, total_num_edges total_num_nodes, total_num_edges
) )
graph = gb.fused_csc_sampling_graph(csc_indptr, indices) node_attributes = None
if has_node_attrs:
node_attributes = {
"A": torch.arange(total_num_nodes),
"B": torch.arange(total_num_nodes),
}
edge_attributes = None
if has_edge_attrs:
edge_attributes = {
"A": torch.arange(total_num_edges),
"B": torch.arange(total_num_edges),
}
graph = gb.fused_csc_sampling_graph(
csc_indptr,
indices,
node_attributes=node_attributes,
edge_attributes=edge_attributes,
)
with tempfile.TemporaryDirectory() as test_dir: with tempfile.TemporaryDirectory() as test_dir:
filename = os.path.join(test_dir, "fused_csc_sampling_graph.pt") filename = os.path.join(test_dir, "fused_csc_sampling_graph.pt")
...@@ -348,7 +383,22 @@ def test_load_save_homo_graph(total_num_nodes, total_num_edges): ...@@ -348,7 +383,22 @@ def test_load_save_homo_graph(total_num_nodes, total_num_edges):
assert graph.type_per_edge is None and graph2.type_per_edge is None assert graph.type_per_edge is None and graph2.type_per_edge is None
assert graph.node_type_to_id is None and graph2.node_type_to_id is None assert graph.node_type_to_id is None and graph2.node_type_to_id is None
assert graph.edge_type_to_id is None and graph2.edge_type_to_id is None assert graph.edge_type_to_id is None and graph2.edge_type_to_id is None
assert graph.edge_attributes is None and graph2.edge_attributes is None if has_node_attrs:
assert graph.node_attributes.keys() == graph2.node_attributes.keys()
for key in graph.node_attributes.keys():
assert torch.equal(
graph.node_attributes[key], graph2.node_attributes[key]
)
else:
assert graph.node_attributes is None and graph2.node_attributes is None
if has_edge_attrs:
assert graph.edge_attributes.keys() == graph2.edge_attributes.keys()
for key in graph.edge_attributes.keys():
assert torch.equal(
graph.edge_attributes[key], graph2.edge_attributes[key]
)
else:
assert graph.edge_attributes is None and graph2.edge_attributes is None
@unittest.skipIf( @unittest.skipIf(
...@@ -360,8 +410,15 @@ def test_load_save_homo_graph(total_num_nodes, total_num_edges): ...@@ -360,8 +410,15 @@ def test_load_save_homo_graph(total_num_nodes, total_num_edges):
[(1, 1), (100, 1), (10, 50), (1000, 50000)], [(1, 1), (100, 1), (10, 50), (1000, 50000)],
) )
@pytest.mark.parametrize("num_ntypes, num_etypes", [(1, 1), (3, 5), (100, 1)]) @pytest.mark.parametrize("num_ntypes, num_etypes", [(1, 1), (3, 5), (100, 1)])
@pytest.mark.parametrize("has_node_attrs", [True, False])
@pytest.mark.parametrize("has_edge_attrs", [True, False])
def test_load_save_hetero_graph( def test_load_save_hetero_graph(
total_num_nodes, total_num_edges, num_ntypes, num_etypes total_num_nodes,
total_num_edges,
num_ntypes,
num_etypes,
has_node_attrs,
has_edge_attrs,
): ):
( (
csc_indptr, csc_indptr,
...@@ -373,6 +430,18 @@ def test_load_save_hetero_graph( ...@@ -373,6 +430,18 @@ def test_load_save_hetero_graph(
) = gbt.random_hetero_graph( ) = gbt.random_hetero_graph(
total_num_nodes, total_num_edges, num_ntypes, num_etypes total_num_nodes, total_num_edges, num_ntypes, num_etypes
) )
node_attributes = None
if has_node_attrs:
node_attributes = {
"A": torch.arange(total_num_nodes),
"B": torch.arange(total_num_nodes),
}
edge_attributes = None
if has_edge_attrs:
edge_attributes = {
"A": torch.arange(total_num_edges),
"B": torch.arange(total_num_edges),
}
graph = gb.fused_csc_sampling_graph( graph = gb.fused_csc_sampling_graph(
csc_indptr, csc_indptr,
indices, indices,
...@@ -380,6 +449,8 @@ def test_load_save_hetero_graph( ...@@ -380,6 +449,8 @@ def test_load_save_hetero_graph(
type_per_edge=type_per_edge, type_per_edge=type_per_edge,
node_type_to_id=node_type_to_id, node_type_to_id=node_type_to_id,
edge_type_to_id=edge_type_to_id, edge_type_to_id=edge_type_to_id,
node_attributes=node_attributes,
edge_attributes=edge_attributes,
) )
with tempfile.TemporaryDirectory() as test_dir: with tempfile.TemporaryDirectory() as test_dir:
...@@ -396,6 +467,22 @@ def test_load_save_hetero_graph( ...@@ -396,6 +467,22 @@ def test_load_save_hetero_graph(
assert torch.equal(graph.type_per_edge, graph2.type_per_edge) assert torch.equal(graph.type_per_edge, graph2.type_per_edge)
assert graph.node_type_to_id == graph2.node_type_to_id assert graph.node_type_to_id == graph2.node_type_to_id
assert graph.edge_type_to_id == graph2.edge_type_to_id assert graph.edge_type_to_id == graph2.edge_type_to_id
if has_node_attrs:
assert graph.node_attributes.keys() == graph2.node_attributes.keys()
for key in graph.node_attributes.keys():
assert torch.equal(
graph.node_attributes[key], graph2.node_attributes[key]
)
else:
assert graph.node_attributes is None and graph2.node_attributes is None
if has_edge_attrs:
assert graph.edge_attributes.keys() == graph2.edge_attributes.keys()
for key in graph.edge_attributes.keys():
assert torch.equal(
graph.edge_attributes[key], graph2.edge_attributes[key]
)
else:
assert graph.edge_attributes is None and graph2.edge_attributes is None
@unittest.skipIf( @unittest.skipIf(
...@@ -406,11 +493,32 @@ def test_load_save_hetero_graph( ...@@ -406,11 +493,32 @@ def test_load_save_hetero_graph(
"total_num_nodes, total_num_edges", "total_num_nodes, total_num_edges",
[(1, 1), (100, 1), (10, 50), (1000, 50000)], [(1, 1), (100, 1), (10, 50), (1000, 50000)],
) )
def test_pickle_homo_graph(total_num_nodes, total_num_edges): @pytest.mark.parametrize("has_node_attrs", [True, False])
@pytest.mark.parametrize("has_edge_attrs", [True, False])
def test_pickle_homo_graph(
total_num_nodes, total_num_edges, has_node_attrs, has_edge_attrs
):
csc_indptr, indices = gbt.random_homo_graph( csc_indptr, indices = gbt.random_homo_graph(
total_num_nodes, total_num_edges total_num_nodes, total_num_edges
) )
graph = gb.fused_csc_sampling_graph(csc_indptr, indices) node_attributes = None
if has_node_attrs:
node_attributes = {
"A": torch.arange(total_num_nodes),
"B": torch.arange(total_num_nodes),
}
edge_attributes = None
if has_edge_attrs:
edge_attributes = {
"A": torch.arange(total_num_edges),
"B": torch.arange(total_num_edges),
}
graph = gb.fused_csc_sampling_graph(
csc_indptr,
indices,
node_attributes=node_attributes,
edge_attributes=edge_attributes,
)
serialized = pickle.dumps(graph) serialized = pickle.dumps(graph)
graph2 = pickle.loads(serialized) graph2 = pickle.loads(serialized)
...@@ -425,7 +533,22 @@ def test_pickle_homo_graph(total_num_nodes, total_num_edges): ...@@ -425,7 +533,22 @@ def test_pickle_homo_graph(total_num_nodes, total_num_edges):
assert graph.type_per_edge is None and graph2.type_per_edge is None assert graph.type_per_edge is None and graph2.type_per_edge is None
assert graph.node_type_to_id is None and graph2.node_type_to_id is None assert graph.node_type_to_id is None and graph2.node_type_to_id is None
assert graph.edge_type_to_id is None and graph2.edge_type_to_id is None assert graph.edge_type_to_id is None and graph2.edge_type_to_id is None
assert graph.edge_attributes is None and graph2.edge_attributes is None if has_node_attrs:
assert graph.node_attributes.keys() == graph2.node_attributes.keys()
for key in graph.node_attributes.keys():
assert torch.equal(
graph.node_attributes[key], graph2.node_attributes[key]
)
else:
assert graph.node_attributes is None and graph2.node_attributes is None
if has_edge_attrs:
assert graph.edge_attributes.keys() == graph2.edge_attributes.keys()
for key in graph.edge_attributes.keys():
assert torch.equal(
graph.edge_attributes[key], graph2.edge_attributes[key]
)
else:
assert graph.edge_attributes is None and graph2.edge_attributes is None
@unittest.skipIf( @unittest.skipIf(
...@@ -437,8 +560,15 @@ def test_pickle_homo_graph(total_num_nodes, total_num_edges): ...@@ -437,8 +560,15 @@ def test_pickle_homo_graph(total_num_nodes, total_num_edges):
[(1, 1), (100, 1), (10, 50), (1000, 50000)], [(1, 1), (100, 1), (10, 50), (1000, 50000)],
) )
@pytest.mark.parametrize("num_ntypes, num_etypes", [(1, 1), (3, 5), (100, 1)]) @pytest.mark.parametrize("num_ntypes, num_etypes", [(1, 1), (3, 5), (100, 1)])
@pytest.mark.parametrize("has_node_attrs", [True, False])
@pytest.mark.parametrize("has_edge_attrs", [True, False])
def test_pickle_hetero_graph( def test_pickle_hetero_graph(
total_num_nodes, total_num_edges, num_ntypes, num_etypes total_num_nodes,
total_num_edges,
num_ntypes,
num_etypes,
has_node_attrs,
has_edge_attrs,
): ):
( (
csc_indptr, csc_indptr,
...@@ -450,10 +580,18 @@ def test_pickle_hetero_graph( ...@@ -450,10 +580,18 @@ def test_pickle_hetero_graph(
) = gbt.random_hetero_graph( ) = gbt.random_hetero_graph(
total_num_nodes, total_num_edges, num_ntypes, num_etypes total_num_nodes, total_num_edges, num_ntypes, num_etypes
) )
edge_attributes = { node_attributes = None
"a": torch.randn((total_num_edges,)), if has_node_attrs:
"b": torch.randint(1, 10, (total_num_edges,)), node_attributes = {
} "A": torch.arange(total_num_nodes),
"B": torch.arange(total_num_nodes),
}
edge_attributes = None
if has_edge_attrs:
edge_attributes = {
"A": torch.arange(total_num_edges),
"B": torch.arange(total_num_edges),
}
graph = gb.fused_csc_sampling_graph( graph = gb.fused_csc_sampling_graph(
csc_indptr, csc_indptr,
indices, indices,
...@@ -461,6 +599,7 @@ def test_pickle_hetero_graph( ...@@ -461,6 +599,7 @@ def test_pickle_hetero_graph(
type_per_edge=type_per_edge, type_per_edge=type_per_edge,
node_type_to_id=node_type_to_id, node_type_to_id=node_type_to_id,
edge_type_to_id=edge_type_to_id, edge_type_to_id=edge_type_to_id,
node_attributes=node_attributes,
edge_attributes=edge_attributes, edge_attributes=edge_attributes,
) )
...@@ -480,9 +619,22 @@ def test_pickle_hetero_graph( ...@@ -480,9 +619,22 @@ def test_pickle_hetero_graph(
assert graph.edge_type_to_id.keys() == graph2.edge_type_to_id.keys() assert graph.edge_type_to_id.keys() == graph2.edge_type_to_id.keys()
for i in graph.edge_type_to_id.keys(): for i in graph.edge_type_to_id.keys():
assert graph.edge_type_to_id[i] == graph2.edge_type_to_id[i] assert graph.edge_type_to_id[i] == graph2.edge_type_to_id[i]
assert graph.edge_attributes.keys() == graph2.edge_attributes.keys() if has_node_attrs:
for i in graph.edge_attributes.keys(): assert graph.node_attributes.keys() == graph2.node_attributes.keys()
assert torch.equal(graph.edge_attributes[i], graph2.edge_attributes[i]) for key in graph.node_attributes.keys():
assert torch.equal(
graph.node_attributes[key], graph2.node_attributes[key]
)
else:
assert graph.node_attributes is None and graph2.node_attributes is None
if has_edge_attrs:
assert graph.edge_attributes.keys() == graph2.edge_attributes.keys()
for key in graph.edge_attributes.keys():
assert torch.equal(
graph.edge_attributes[key], graph2.edge_attributes[key]
)
else:
assert graph.edge_attributes is None and graph2.edge_attributes is None
def process_csc_sampling_graph_multiprocessing(graph): def process_csc_sampling_graph_multiprocessing(graph):
...@@ -1258,6 +1410,18 @@ def check_tensors_on_the_same_shared_memory(t1: torch.Tensor, t2: torch.Tensor): ...@@ -1258,6 +1410,18 @@ def check_tensors_on_the_same_shared_memory(t1: torch.Tensor, t2: torch.Tensor):
t1[:] = old_t1 t1[:] = old_t1
def check_node_edge_attributes(graph1, graph2, attributes, attr_name):
for name, attr in attributes.items():
edge_attributes_1 = getattr(graph1, attr_name)
edge_attributes_2 = getattr(graph2, attr_name)
assert name in edge_attributes_1
assert name in edge_attributes_2
assert torch.equal(edge_attributes_1[name], attr)
check_tensors_on_the_same_shared_memory(
edge_attributes_1[name], edge_attributes_2[name]
)
@unittest.skipIf( @unittest.skipIf(
F._default_context_str == "gpu", F._default_context_str == "gpu",
reason="FusedCSCSamplingGraph is only supported on CPU.", reason="FusedCSCSamplingGraph is only supported on CPU.",
...@@ -1266,22 +1430,31 @@ def check_tensors_on_the_same_shared_memory(t1: torch.Tensor, t2: torch.Tensor): ...@@ -1266,22 +1430,31 @@ def check_tensors_on_the_same_shared_memory(t1: torch.Tensor, t2: torch.Tensor):
"total_num_nodes, total_num_edges", "total_num_nodes, total_num_edges",
[(1, 1), (100, 1), (10, 50), (1000, 50000)], [(1, 1), (100, 1), (10, 50), (1000, 50000)],
) )
@pytest.mark.parametrize("test_node_attrs", [True, False])
@pytest.mark.parametrize("test_edge_attrs", [True, False]) @pytest.mark.parametrize("test_edge_attrs", [True, False])
def test_homo_graph_on_shared_memory( def test_homo_graph_on_shared_memory(
total_num_nodes, total_num_edges, test_edge_attrs total_num_nodes, total_num_edges, test_node_attrs, test_edge_attrs
): ):
csc_indptr, indices = gbt.random_homo_graph( csc_indptr, indices = gbt.random_homo_graph(
total_num_nodes, total_num_edges total_num_nodes, total_num_edges
) )
node_attributes = None
if test_node_attrs:
node_attributes = {
"A1": torch.arange(total_num_nodes),
"A2": torch.arange(total_num_nodes),
}
edge_attributes = None
if test_edge_attrs: if test_edge_attrs:
edge_attributes = { edge_attributes = {
"A1": torch.randn(total_num_edges), "A1": torch.randn(total_num_edges),
"A2": torch.randn(total_num_edges), "A2": torch.randn(total_num_edges),
} }
else:
edge_attributes = None
graph = gb.fused_csc_sampling_graph( graph = gb.fused_csc_sampling_graph(
csc_indptr, indices, edge_attributes=edge_attributes csc_indptr,
indices,
node_attributes=node_attributes,
edge_attributes=edge_attributes,
) )
shm_name = "test_homo_g" shm_name = "test_homo_g"
...@@ -1307,14 +1480,14 @@ def test_homo_graph_on_shared_memory( ...@@ -1307,14 +1480,14 @@ def test_homo_graph_on_shared_memory(
) )
check_tensors_on_the_same_shared_memory(graph1.indices, graph2.indices) check_tensors_on_the_same_shared_memory(graph1.indices, graph2.indices)
if test_node_attrs:
check_node_edge_attributes(
graph1, graph2, node_attributes, "node_attributes"
)
if test_edge_attrs: if test_edge_attrs:
for name, edge_attr in edge_attributes.items(): check_node_edge_attributes(
assert name in graph1.edge_attributes graph1, graph2, edge_attributes, "edge_attributes"
assert name in graph2.edge_attributes )
assert torch.equal(graph1.edge_attributes[name], edge_attr)
check_tensors_on_the_same_shared_memory(
graph1.edge_attributes[name], graph2.edge_attributes[name]
)
assert graph1.node_type_offset is None and graph2.node_type_offset is None assert graph1.node_type_offset is None and graph2.node_type_offset is None
assert graph1.type_per_edge is None and graph2.type_per_edge is None assert graph1.type_per_edge is None and graph2.type_per_edge is None
...@@ -1333,9 +1506,15 @@ def test_homo_graph_on_shared_memory( ...@@ -1333,9 +1506,15 @@ def test_homo_graph_on_shared_memory(
@pytest.mark.parametrize( @pytest.mark.parametrize(
"num_ntypes, num_etypes", [(1, 1), (3, 5), (100, 1), (1000, 1000)] "num_ntypes, num_etypes", [(1, 1), (3, 5), (100, 1), (1000, 1000)]
) )
@pytest.mark.parametrize("test_node_attrs", [True, False])
@pytest.mark.parametrize("test_edge_attrs", [True, False]) @pytest.mark.parametrize("test_edge_attrs", [True, False])
def test_hetero_graph_on_shared_memory( def test_hetero_graph_on_shared_memory(
total_num_nodes, total_num_edges, num_ntypes, num_etypes, test_edge_attrs total_num_nodes,
total_num_edges,
num_ntypes,
num_etypes,
test_node_attrs,
test_edge_attrs,
): ):
( (
csc_indptr, csc_indptr,
...@@ -1348,13 +1527,20 @@ def test_hetero_graph_on_shared_memory( ...@@ -1348,13 +1527,20 @@ def test_hetero_graph_on_shared_memory(
total_num_nodes, total_num_edges, num_ntypes, num_etypes total_num_nodes, total_num_edges, num_ntypes, num_etypes
) )
node_attributes = None
if test_node_attrs:
node_attributes = {
"A1": torch.arange(total_num_nodes),
"A2": torch.arange(total_num_nodes),
}
edge_attributes = None
if test_edge_attrs: if test_edge_attrs:
edge_attributes = { edge_attributes = {
"A1": torch.randn(total_num_edges), "A1": torch.randn(total_num_edges),
"A2": torch.randn(total_num_edges), "A2": torch.randn(total_num_edges),
} }
else:
edge_attributes = None
graph = gb.fused_csc_sampling_graph( graph = gb.fused_csc_sampling_graph(
csc_indptr, csc_indptr,
indices, indices,
...@@ -1362,6 +1548,7 @@ def test_hetero_graph_on_shared_memory( ...@@ -1362,6 +1548,7 @@ def test_hetero_graph_on_shared_memory(
type_per_edge=type_per_edge, type_per_edge=type_per_edge,
node_type_to_id=node_type_to_id, node_type_to_id=node_type_to_id,
edge_type_to_id=edge_type_to_id, edge_type_to_id=edge_type_to_id,
node_attributes=node_attributes,
edge_attributes=edge_attributes, edge_attributes=edge_attributes,
) )
...@@ -1398,14 +1585,14 @@ def test_hetero_graph_on_shared_memory( ...@@ -1398,14 +1585,14 @@ def test_hetero_graph_on_shared_memory(
graph1.type_per_edge, graph2.type_per_edge graph1.type_per_edge, graph2.type_per_edge
) )
if test_node_attrs:
check_node_edge_attributes(
graph1, graph2, node_attributes, "node_attributes"
)
if test_edge_attrs: if test_edge_attrs:
for name, edge_attr in edge_attributes.items(): check_node_edge_attributes(
assert name in graph1.edge_attributes graph1, graph2, edge_attributes, "edge_attributes"
assert name in graph2.edge_attributes )
assert torch.equal(graph1.edge_attributes[name], edge_attr)
check_tensors_on_the_same_shared_memory(
graph1.edge_attributes[name], graph2.edge_attributes[name]
)
assert node_type_to_id == graph1.node_type_to_id assert node_type_to_id == graph1.node_type_to_id
assert edge_type_to_id == graph1.edge_type_to_id assert edge_type_to_id == graph1.edge_type_to_id
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment