"src/array/git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "889798fec5323070bef950c23f7f1d36a22588b7"
Unverified Commit 4b456507 authored by peizhou001's avatar peizhou001 Committed by GitHub
Browse files

[Graphbolt] Add edge attributes (#5966)

parent aa795b28
...@@ -33,6 +33,7 @@ namespace sampling { ...@@ -33,6 +33,7 @@ namespace sampling {
*/ */
class CSCSamplingGraph : public torch::CustomClassHolder { class CSCSamplingGraph : public torch::CustomClassHolder {
public: public:
using EdgeAttrMap = torch::Dict<std::string, torch::Tensor>;
/** @brief Default constructor. */ /** @brief Default constructor. */
CSCSamplingGraph() = default; CSCSamplingGraph() = default;
...@@ -48,7 +49,8 @@ class CSCSamplingGraph : public torch::CustomClassHolder { ...@@ -48,7 +49,8 @@ class CSCSamplingGraph : public torch::CustomClassHolder {
CSCSamplingGraph( CSCSamplingGraph(
const torch::Tensor& indptr, const torch::Tensor& indices, const torch::Tensor& indptr, const torch::Tensor& indices,
const torch::optional<torch::Tensor>& node_type_offset, const torch::optional<torch::Tensor>& node_type_offset,
const torch::optional<torch::Tensor>& type_per_edge); const torch::optional<torch::Tensor>& type_per_edge,
const torch::optional<EdgeAttrMap>& edge_attributes);
/** /**
* @brief Create a homogeneous CSC graph from tensors of CSC format. * @brief Create a homogeneous CSC graph from tensors of CSC format.
...@@ -64,7 +66,8 @@ class CSCSamplingGraph : public torch::CustomClassHolder { ...@@ -64,7 +66,8 @@ class CSCSamplingGraph : public torch::CustomClassHolder {
static c10::intrusive_ptr<CSCSamplingGraph> FromCSC( static c10::intrusive_ptr<CSCSamplingGraph> FromCSC(
const torch::Tensor& indptr, const torch::Tensor& indices, const torch::Tensor& indptr, const torch::Tensor& indices,
const torch::optional<torch::Tensor>& node_type_offset, const torch::optional<torch::Tensor>& node_type_offset,
const torch::optional<torch::Tensor>& type_per_edge); const torch::optional<torch::Tensor>& type_per_edge,
const torch::optional<EdgeAttrMap>& edge_attributes);
/** @brief Get the number of nodes. */ /** @brief Get the number of nodes. */
int64_t NumNodes() const { return indptr_.size(0) - 1; } int64_t NumNodes() const { return indptr_.size(0) - 1; }
...@@ -88,6 +91,11 @@ class CSCSamplingGraph : public torch::CustomClassHolder { ...@@ -88,6 +91,11 @@ class CSCSamplingGraph : public torch::CustomClassHolder {
return type_per_edge_; return type_per_edge_;
} }
/** @brief Get the edge attributes dictionary. */
inline const torch::optional<EdgeAttrMap> EdgeAttributes() const {
return edge_attributes_;
}
/** /**
* @brief Magic number to indicate graph version in serialize/deserialize * @brief Magic number to indicate graph version in serialize/deserialize
* stage. * stage.
...@@ -231,6 +239,13 @@ class CSCSamplingGraph : public torch::CustomClassHolder { ...@@ -231,6 +239,13 @@ class CSCSamplingGraph : public torch::CustomClassHolder {
*/ */
torch::optional<torch::Tensor> type_per_edge_; torch::optional<torch::Tensor> type_per_edge_;
/**
* @brief A dictionary of edge attributes. Each key represents the attribute's
* name, while the corresponding value holds the attribute's specific value.
* The length of each value should match the total number of edges."
*/
torch::optional<EdgeAttrMap> edge_attributes_;
/** /**
* @brief Maximum number of bytes used to serialize the metadata of the * @brief Maximum number of bytes used to serialize the metadata of the
* member tensors, including tensor shape and dtype. The constant is estimated * member tensors, including tensor shape and dtype. The constant is estimated
......
...@@ -20,11 +20,13 @@ namespace sampling { ...@@ -20,11 +20,13 @@ namespace sampling {
CSCSamplingGraph::CSCSamplingGraph( CSCSamplingGraph::CSCSamplingGraph(
const torch::Tensor& indptr, const torch::Tensor& indices, const torch::Tensor& indptr, const torch::Tensor& indices,
const torch::optional<torch::Tensor>& node_type_offset, const torch::optional<torch::Tensor>& node_type_offset,
const torch::optional<torch::Tensor>& type_per_edge) const torch::optional<torch::Tensor>& type_per_edge,
const torch::optional<EdgeAttrMap>& edge_attributes)
: indptr_(indptr), : indptr_(indptr),
indices_(indices), indices_(indices),
node_type_offset_(node_type_offset), node_type_offset_(node_type_offset),
type_per_edge_(type_per_edge) { type_per_edge_(type_per_edge),
edge_attributes_(edge_attributes) {
TORCH_CHECK(indptr.dim() == 1); TORCH_CHECK(indptr.dim() == 1);
TORCH_CHECK(indices.dim() == 1); TORCH_CHECK(indices.dim() == 1);
TORCH_CHECK(indptr.device() == indices.device()); TORCH_CHECK(indptr.device() == indices.device());
...@@ -33,7 +35,8 @@ CSCSamplingGraph::CSCSamplingGraph( ...@@ -33,7 +35,8 @@ CSCSamplingGraph::CSCSamplingGraph(
c10::intrusive_ptr<CSCSamplingGraph> CSCSamplingGraph::FromCSC( c10::intrusive_ptr<CSCSamplingGraph> CSCSamplingGraph::FromCSC(
const torch::Tensor& indptr, const torch::Tensor& indices, const torch::Tensor& indptr, const torch::Tensor& indices,
const torch::optional<torch::Tensor>& node_type_offset, const torch::optional<torch::Tensor>& node_type_offset,
const torch::optional<torch::Tensor>& type_per_edge) { const torch::optional<torch::Tensor>& type_per_edge,
const torch::optional<EdgeAttrMap>& edge_attributes) {
if (node_type_offset.has_value()) { if (node_type_offset.has_value()) {
auto& offset = node_type_offset.value(); auto& offset = node_type_offset.value();
TORCH_CHECK(offset.dim() == 1); TORCH_CHECK(offset.dim() == 1);
...@@ -42,9 +45,13 @@ c10::intrusive_ptr<CSCSamplingGraph> CSCSamplingGraph::FromCSC( ...@@ -42,9 +45,13 @@ c10::intrusive_ptr<CSCSamplingGraph> CSCSamplingGraph::FromCSC(
TORCH_CHECK(type_per_edge.value().dim() == 1); TORCH_CHECK(type_per_edge.value().dim() == 1);
TORCH_CHECK(type_per_edge.value().size(0) == indices.size(0)); TORCH_CHECK(type_per_edge.value().size(0) == indices.size(0));
} }
if (edge_attributes.has_value()) {
for (const auto& pair : edge_attributes.value()) {
TORCH_CHECK(pair.value().size(0) == indices.size(0));
}
}
return c10::make_intrusive<CSCSamplingGraph>( return c10::make_intrusive<CSCSamplingGraph>(
indptr, indices, node_type_offset, type_per_edge); indptr, indices, node_type_offset, type_per_edge, edge_attributes);
} }
void CSCSamplingGraph::Load(torch::serialize::InputArchive& archive) { void CSCSamplingGraph::Load(torch::serialize::InputArchive& archive) {
...@@ -217,7 +224,7 @@ CSCSamplingGraph::BuildGraphFromSharedMemoryTensors( ...@@ -217,7 +224,7 @@ CSCSamplingGraph::BuildGraphFromSharedMemoryTensors(
auto& optional_tensors = std::get<2>(shared_memory_tensors); auto& optional_tensors = std::get<2>(shared_memory_tensors);
auto graph = c10::make_intrusive<CSCSamplingGraph>( auto graph = c10::make_intrusive<CSCSamplingGraph>(
optional_tensors[0].value(), optional_tensors[1].value(), optional_tensors[0].value(), optional_tensors[1].value(),
optional_tensors[2], optional_tensors[3]); optional_tensors[2], optional_tensors[3], torch::nullopt);
graph->tensor_meta_shm_ = std::move(std::get<0>(shared_memory_tensors)); graph->tensor_meta_shm_ = std::move(std::get<0>(shared_memory_tensors));
graph->tensor_data_shm_ = std::move(std::get<1>(shared_memory_tensors)); graph->tensor_data_shm_ = std::move(std::get<1>(shared_memory_tensors));
return graph; return graph;
......
...@@ -28,6 +28,7 @@ TORCH_LIBRARY(graphbolt, m) { ...@@ -28,6 +28,7 @@ TORCH_LIBRARY(graphbolt, m) {
.def("indices", &CSCSamplingGraph::Indices) .def("indices", &CSCSamplingGraph::Indices)
.def("node_type_offset", &CSCSamplingGraph::NodeTypeOffset) .def("node_type_offset", &CSCSamplingGraph::NodeTypeOffset)
.def("type_per_edge", &CSCSamplingGraph::TypePerEdge) .def("type_per_edge", &CSCSamplingGraph::TypePerEdge)
.def("edge_attributes", &CSCSamplingGraph::EdgeAttributes)
.def("in_subgraph", &CSCSamplingGraph::InSubgraph) .def("in_subgraph", &CSCSamplingGraph::InSubgraph)
.def("sample_neighbors", &CSCSamplingGraph::SampleNeighbors) .def("sample_neighbors", &CSCSamplingGraph::SampleNeighbors)
.def( .def(
......
...@@ -163,6 +163,20 @@ class CSCSamplingGraph: ...@@ -163,6 +163,20 @@ class CSCSamplingGraph:
""" """
return self._c_csc_graph.type_per_edge() return self._c_csc_graph.type_per_edge()
@property
def edge_attributes(self) -> Optional[Dict[str, torch.Tensor]]:
"""Returns the edge attributes dictionary.
Returns
-------
torch.Tensor or None
If present, returns a dictionary of edge attributes. Each key
represents the attribute's name, while the corresponding value
holds the attribute's specific value. The length of each value
should match the total number of edges."
"""
return self._c_csc_graph.edge_attributes()
@property @property
def metadata(self) -> Optional[GraphMetadata]: def metadata(self) -> Optional[GraphMetadata]:
"""Returns the metadata of the graph. """Returns the metadata of the graph.
...@@ -383,6 +397,7 @@ def from_csc( ...@@ -383,6 +397,7 @@ def from_csc(
indices: torch.Tensor, indices: torch.Tensor,
node_type_offset: Optional[torch.tensor] = None, node_type_offset: Optional[torch.tensor] = None,
type_per_edge: Optional[torch.tensor] = None, type_per_edge: Optional[torch.tensor] = None,
edge_attributes: Optional[Dict[str, torch.tensor]] = None,
metadata: Optional[GraphMetadata] = None, metadata: Optional[GraphMetadata] = None,
) -> CSCSamplingGraph: ) -> CSCSamplingGraph:
"""Create a CSCSamplingGraph object from a CSC representation. """Create a CSCSamplingGraph object from a CSC representation.
...@@ -399,6 +414,8 @@ def from_csc( ...@@ -399,6 +414,8 @@ def from_csc(
Offset of node types in the graph, by default None. Offset of node types in the graph, by default None.
type_per_edge : Optional[torch.tensor], optional type_per_edge : Optional[torch.tensor], optional
Type ids of each edge in the graph, by default None. Type ids of each edge in the graph, by default None.
edge_attributes: Optional[Dict[str, torch.tensor]], optional
Edge attributes of the graph, by default None.
metadata: Optional[GraphMetadata], optional metadata: Optional[GraphMetadata], optional
Metadata of the graph, by default None. Metadata of the graph, by default None.
Returns Returns
...@@ -416,7 +433,7 @@ def from_csc( ...@@ -416,7 +433,7 @@ def from_csc(
>>> node_type_offset = torch.tensor([0, 1, 2, 3]) >>> node_type_offset = torch.tensor([0, 1, 2, 3])
>>> type_per_edge = torch.tensor([0, 1, 0, 1, 1, 0, 0]) >>> type_per_edge = torch.tensor([0, 1, 0, 1, 1, 0, 0])
>>> graph = graphbolt.from_csc(csc_indptr, indices, node_type_offset, \ >>> graph = graphbolt.from_csc(csc_indptr, indices, node_type_offset, \
>>> type_per_edge, metadata) >>> type_per_edge, None, metadata)
>>> print(graph) >>> print(graph)
CSCSamplingGraph(csc_indptr=tensor([0, 2, 5, 7]), CSCSamplingGraph(csc_indptr=tensor([0, 2, 5, 7]),
indices=tensor([1, 3, 0, 1, 2, 0, 3]), indices=tensor([1, 3, 0, 1, 2, 0, 3]),
...@@ -428,7 +445,11 @@ def from_csc( ...@@ -428,7 +445,11 @@ def from_csc(
), "node_type_offset length should be |ntypes| + 1." ), "node_type_offset length should be |ntypes| + 1."
return CSCSamplingGraph( return CSCSamplingGraph(
torch.ops.graphbolt.from_csc( torch.ops.graphbolt.from_csc(
csc_indptr, indices, node_type_offset, type_per_edge csc_indptr,
indices,
node_type_offset,
type_per_edge,
edge_attributes,
), ),
metadata, metadata,
) )
...@@ -535,7 +556,11 @@ def from_dglgraph(g: DGLGraph) -> CSCSamplingGraph: ...@@ -535,7 +556,11 @@ def from_dglgraph(g: DGLGraph) -> CSCSamplingGraph:
return CSCSamplingGraph( return CSCSamplingGraph(
torch.ops.graphbolt.from_csc( torch.ops.graphbolt.from_csc(
indptr, indices, node_type_offset, type_per_edge indptr,
indices,
node_type_offset,
type_per_edge,
None,
), ),
metadata, metadata,
) )
...@@ -53,6 +53,7 @@ def test_hetero_empty_graph(num_nodes): ...@@ -53,6 +53,7 @@ def test_hetero_empty_graph(num_nodes):
indices, indices,
node_type_offset, node_type_offset,
type_per_edge, type_per_edge,
None,
metadata, metadata,
) )
assert graph.num_edges == 0 assert graph.num_edges == 0
...@@ -107,7 +108,11 @@ def test_metadata_with_etype_exception(etypes): ...@@ -107,7 +108,11 @@ def test_metadata_with_etype_exception(etypes):
) )
def test_homo_graph(num_nodes, num_edges): def test_homo_graph(num_nodes, num_edges):
csc_indptr, indices = gbt.random_homo_graph(num_nodes, num_edges) csc_indptr, indices = gbt.random_homo_graph(num_nodes, num_edges)
graph = gb.from_csc(csc_indptr, indices) edge_attributes = {
"A1": torch.randn(num_edges),
"A2": torch.randn(num_edges),
}
graph = gb.from_csc(csc_indptr, indices, edge_attributes=edge_attributes)
assert graph.num_nodes == num_nodes assert graph.num_nodes == num_nodes
assert graph.num_edges == num_edges assert graph.num_edges == num_edges
...@@ -115,6 +120,7 @@ def test_homo_graph(num_nodes, num_edges): ...@@ -115,6 +120,7 @@ def test_homo_graph(num_nodes, num_edges):
assert torch.equal(csc_indptr, graph.csc_indptr) assert torch.equal(csc_indptr, graph.csc_indptr)
assert torch.equal(indices, graph.indices) assert torch.equal(indices, graph.indices)
assert graph.edge_attributes == edge_attributes
assert graph.metadata is None assert graph.metadata is None
assert graph.node_type_offset is None assert graph.node_type_offset is None
assert graph.type_per_edge is None assert graph.type_per_edge is None
...@@ -136,8 +142,17 @@ def test_hetero_graph(num_nodes, num_edges, num_ntypes, num_etypes): ...@@ -136,8 +142,17 @@ def test_hetero_graph(num_nodes, num_edges, num_ntypes, num_etypes):
type_per_edge, type_per_edge,
metadata, metadata,
) = gbt.random_hetero_graph(num_nodes, num_edges, num_ntypes, num_etypes) ) = gbt.random_hetero_graph(num_nodes, num_edges, num_ntypes, num_etypes)
edge_attributes = {
"A1": torch.randn(num_edges),
"A2": torch.randn(num_edges),
}
graph = gb.from_csc( graph = gb.from_csc(
csc_indptr, indices, node_type_offset, type_per_edge, metadata csc_indptr,
indices,
node_type_offset,
type_per_edge,
edge_attributes,
metadata,
) )
assert graph.num_nodes == num_nodes assert graph.num_nodes == num_nodes
...@@ -147,6 +162,7 @@ def test_hetero_graph(num_nodes, num_edges, num_ntypes, num_etypes): ...@@ -147,6 +162,7 @@ def test_hetero_graph(num_nodes, num_edges, num_ntypes, num_etypes):
assert torch.equal(indices, graph.indices) assert torch.equal(indices, graph.indices)
assert torch.equal(node_type_offset, graph.node_type_offset) assert torch.equal(node_type_offset, graph.node_type_offset)
assert torch.equal(type_per_edge, graph.type_per_edge) assert torch.equal(type_per_edge, graph.type_per_edge)
assert graph.edge_attributes == edge_attributes
assert metadata.node_type_to_id == graph.metadata.node_type_to_id assert metadata.node_type_to_id == graph.metadata.node_type_to_id
assert metadata.edge_type_to_id == graph.metadata.edge_type_to_id assert metadata.edge_type_to_id == graph.metadata.edge_type_to_id
...@@ -170,7 +186,7 @@ def test_node_type_offset_wrong_legnth(node_type_offset): ...@@ -170,7 +186,7 @@ def test_node_type_offset_wrong_legnth(node_type_offset):
) )
with pytest.raises(Exception): with pytest.raises(Exception):
gb.from_csc( gb.from_csc(
csc_indptr, indices, node_type_offset, type_per_edge, metadata csc_indptr, indices, node_type_offset, type_per_edge, None, metadata
) )
...@@ -218,7 +234,7 @@ def test_load_save_hetero_graph(num_nodes, num_edges, num_ntypes, num_etypes): ...@@ -218,7 +234,7 @@ def test_load_save_hetero_graph(num_nodes, num_edges, num_ntypes, num_etypes):
metadata, metadata,
) = gbt.random_hetero_graph(num_nodes, num_edges, num_ntypes, num_etypes) ) = gbt.random_hetero_graph(num_nodes, num_edges, num_ntypes, num_etypes)
graph = gb.from_csc( graph = gb.from_csc(
csc_indptr, indices, node_type_offset, type_per_edge, metadata csc_indptr, indices, node_type_offset, type_per_edge, None, metadata
) )
with tempfile.TemporaryDirectory() as test_dir: with tempfile.TemporaryDirectory() as test_dir:
...@@ -323,7 +339,7 @@ def test_in_subgraph_heterogeneous(): ...@@ -323,7 +339,7 @@ def test_in_subgraph_heterogeneous():
# Construct CSCSamplingGraph. # Construct CSCSamplingGraph.
metadata = gb.GraphMetadata(ntypes, etypes) metadata = gb.GraphMetadata(ntypes, etypes)
graph = gb.from_csc( graph = gb.from_csc(
indptr, indices, node_type_offset, type_per_edge, metadata indptr, indices, node_type_offset, type_per_edge, None, metadata
) )
# Extract in subgraph. # Extract in subgraph.
...@@ -662,7 +678,7 @@ def test_hetero_graph_on_shared_memory( ...@@ -662,7 +678,7 @@ def test_hetero_graph_on_shared_memory(
metadata, metadata,
) = gbt.random_hetero_graph(num_nodes, num_edges, num_ntypes, num_etypes) ) = gbt.random_hetero_graph(num_nodes, num_edges, num_ntypes, num_etypes)
graph = gb.from_csc( graph = gb.from_csc(
csc_indptr, indices, node_type_offset, type_per_edge, metadata csc_indptr, indices, node_type_offset, type_per_edge, None, metadata
) )
shm_name = "test_hetero_g" shm_name = "test_hetero_g"
......
...@@ -684,7 +684,7 @@ def test_OnDiskDataset_Graph_heterogeneous(): ...@@ -684,7 +684,7 @@ def test_OnDiskDataset_Graph_heterogeneous():
metadata, metadata,
) = gbt.random_hetero_graph(1000, 10 * 1000, 3, 4) ) = gbt.random_hetero_graph(1000, 10 * 1000, 3, 4)
graph = gb.from_csc( graph = gb.from_csc(
csc_indptr, indices, node_type_offset, type_per_edge, metadata csc_indptr, indices, node_type_offset, type_per_edge, None, metadata
) )
with tempfile.TemporaryDirectory() as test_dir: with tempfile.TemporaryDirectory() as test_dir:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment