"git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "9efe1e52c3f20ed02ad91097b3e471a45e1ba8ed"
Unverified Commit 4b456507 authored by peizhou001's avatar peizhou001 Committed by GitHub
Browse files

[Graphbolt] Add edge attributes (#5966)

parent aa795b28
......@@ -33,6 +33,7 @@ namespace sampling {
*/
class CSCSamplingGraph : public torch::CustomClassHolder {
public:
using EdgeAttrMap = torch::Dict<std::string, torch::Tensor>;
/** @brief Default constructor. */
CSCSamplingGraph() = default;
......@@ -48,7 +49,8 @@ class CSCSamplingGraph : public torch::CustomClassHolder {
CSCSamplingGraph(
const torch::Tensor& indptr, const torch::Tensor& indices,
const torch::optional<torch::Tensor>& node_type_offset,
const torch::optional<torch::Tensor>& type_per_edge);
const torch::optional<torch::Tensor>& type_per_edge,
const torch::optional<EdgeAttrMap>& edge_attributes);
/**
* @brief Create a homogeneous CSC graph from tensors of CSC format.
......@@ -64,7 +66,8 @@ class CSCSamplingGraph : public torch::CustomClassHolder {
static c10::intrusive_ptr<CSCSamplingGraph> FromCSC(
const torch::Tensor& indptr, const torch::Tensor& indices,
const torch::optional<torch::Tensor>& node_type_offset,
const torch::optional<torch::Tensor>& type_per_edge);
const torch::optional<torch::Tensor>& type_per_edge,
const torch::optional<EdgeAttrMap>& edge_attributes);
/**
 * @brief Get the number of nodes.
 * @return The size of indptr_ along its first dimension, minus one.
 */
int64_t NumNodes() const { return indptr_.size(0) - 1; }
......@@ -88,6 +91,11 @@ class CSCSamplingGraph : public torch::CustomClassHolder {
return type_per_edge_;
}
/**
 * @brief Get the edge attributes dictionary.
 * @return The optional edge-attribute map stored in edge_attributes_,
 * returned by value. It holds no value when no edge attributes were supplied
 * at construction.
 */
inline const torch::optional<EdgeAttrMap> EdgeAttributes() const {
return edge_attributes_;
}
/**
* @brief Magic number to indicate graph version in serialize/deserialize
* stage.
......@@ -231,6 +239,13 @@ class CSCSamplingGraph : public torch::CustomClassHolder {
*/
torch::optional<torch::Tensor> type_per_edge_;
/**
* @brief A dictionary of edge attributes. Each key represents the attribute's
* name, while the corresponding value holds the attribute's specific value.
* The length of each value should match the total number of edges.
*/
torch::optional<EdgeAttrMap> edge_attributes_;
/**
* @brief Maximum number of bytes used to serialize the metadata of the
* member tensors, including tensor shape and dtype. The constant is estimated
......
......@@ -20,11 +20,13 @@ namespace sampling {
CSCSamplingGraph::CSCSamplingGraph(
const torch::Tensor& indptr, const torch::Tensor& indices,
const torch::optional<torch::Tensor>& node_type_offset,
const torch::optional<torch::Tensor>& type_per_edge)
const torch::optional<torch::Tensor>& type_per_edge,
const torch::optional<EdgeAttrMap>& edge_attributes)
: indptr_(indptr),
indices_(indices),
node_type_offset_(node_type_offset),
type_per_edge_(type_per_edge) {
type_per_edge_(type_per_edge),
edge_attributes_(edge_attributes) {
TORCH_CHECK(indptr.dim() == 1);
TORCH_CHECK(indices.dim() == 1);
TORCH_CHECK(indptr.device() == indices.device());
......@@ -33,7 +35,8 @@ CSCSamplingGraph::CSCSamplingGraph(
c10::intrusive_ptr<CSCSamplingGraph> CSCSamplingGraph::FromCSC(
const torch::Tensor& indptr, const torch::Tensor& indices,
const torch::optional<torch::Tensor>& node_type_offset,
const torch::optional<torch::Tensor>& type_per_edge) {
const torch::optional<torch::Tensor>& type_per_edge,
const torch::optional<EdgeAttrMap>& edge_attributes) {
if (node_type_offset.has_value()) {
auto& offset = node_type_offset.value();
TORCH_CHECK(offset.dim() == 1);
......@@ -42,9 +45,13 @@ c10::intrusive_ptr<CSCSamplingGraph> CSCSamplingGraph::FromCSC(
TORCH_CHECK(type_per_edge.value().dim() == 1);
TORCH_CHECK(type_per_edge.value().size(0) == indices.size(0));
}
if (edge_attributes.has_value()) {
for (const auto& pair : edge_attributes.value()) {
TORCH_CHECK(pair.value().size(0) == indices.size(0));
}
}
return c10::make_intrusive<CSCSamplingGraph>(
indptr, indices, node_type_offset, type_per_edge);
indptr, indices, node_type_offset, type_per_edge, edge_attributes);
}
void CSCSamplingGraph::Load(torch::serialize::InputArchive& archive) {
......@@ -217,7 +224,7 @@ CSCSamplingGraph::BuildGraphFromSharedMemoryTensors(
auto& optional_tensors = std::get<2>(shared_memory_tensors);
auto graph = c10::make_intrusive<CSCSamplingGraph>(
optional_tensors[0].value(), optional_tensors[1].value(),
optional_tensors[2], optional_tensors[3]);
optional_tensors[2], optional_tensors[3], torch::nullopt);
graph->tensor_meta_shm_ = std::move(std::get<0>(shared_memory_tensors));
graph->tensor_data_shm_ = std::move(std::get<1>(shared_memory_tensors));
return graph;
......
......@@ -28,6 +28,7 @@ TORCH_LIBRARY(graphbolt, m) {
.def("indices", &CSCSamplingGraph::Indices)
.def("node_type_offset", &CSCSamplingGraph::NodeTypeOffset)
.def("type_per_edge", &CSCSamplingGraph::TypePerEdge)
.def("edge_attributes", &CSCSamplingGraph::EdgeAttributes)
.def("in_subgraph", &CSCSamplingGraph::InSubgraph)
.def("sample_neighbors", &CSCSamplingGraph::SampleNeighbors)
.def(
......
......@@ -163,6 +163,20 @@ class CSCSamplingGraph:
"""
return self._c_csc_graph.type_per_edge()
@property
def edge_attributes(self) -> Optional[Dict[str, torch.Tensor]]:
"""Returns the edge attributes dictionary.
Returns
-------
Dict[str, torch.Tensor] or None
If present, returns a dictionary of edge attributes. Each key
represents the attribute's name, while the corresponding value
holds the attribute's specific value. The length of each value
should match the total number of edges.
"""
return self._c_csc_graph.edge_attributes()
@property
def metadata(self) -> Optional[GraphMetadata]:
"""Returns the metadata of the graph.
......@@ -383,6 +397,7 @@ def from_csc(
indices: torch.Tensor,
node_type_offset: Optional[torch.tensor] = None,
type_per_edge: Optional[torch.tensor] = None,
edge_attributes: Optional[Dict[str, torch.tensor]] = None,
metadata: Optional[GraphMetadata] = None,
) -> CSCSamplingGraph:
"""Create a CSCSamplingGraph object from a CSC representation.
......@@ -399,6 +414,8 @@ def from_csc(
Offset of node types in the graph, by default None.
type_per_edge : Optional[torch.tensor], optional
Type ids of each edge in the graph, by default None.
edge_attributes: Optional[Dict[str, torch.tensor]], optional
Edge attributes of the graph, by default None.
metadata: Optional[GraphMetadata], optional
Metadata of the graph, by default None.
Returns
......@@ -416,7 +433,7 @@ def from_csc(
>>> node_type_offset = torch.tensor([0, 1, 2, 3])
>>> type_per_edge = torch.tensor([0, 1, 0, 1, 1, 0, 0])
>>> graph = graphbolt.from_csc(csc_indptr, indices, node_type_offset, \
>>> type_per_edge, metadata)
>>> type_per_edge, None, metadata)
>>> print(graph)
CSCSamplingGraph(csc_indptr=tensor([0, 2, 5, 7]),
indices=tensor([1, 3, 0, 1, 2, 0, 3]),
......@@ -428,7 +445,11 @@ def from_csc(
), "node_type_offset length should be |ntypes| + 1."
return CSCSamplingGraph(
torch.ops.graphbolt.from_csc(
csc_indptr, indices, node_type_offset, type_per_edge
csc_indptr,
indices,
node_type_offset,
type_per_edge,
edge_attributes,
),
metadata,
)
......@@ -535,7 +556,11 @@ def from_dglgraph(g: DGLGraph) -> CSCSamplingGraph:
return CSCSamplingGraph(
torch.ops.graphbolt.from_csc(
indptr, indices, node_type_offset, type_per_edge
indptr,
indices,
node_type_offset,
type_per_edge,
None,
),
metadata,
)
......@@ -53,6 +53,7 @@ def test_hetero_empty_graph(num_nodes):
indices,
node_type_offset,
type_per_edge,
None,
metadata,
)
assert graph.num_edges == 0
......@@ -107,7 +108,11 @@ def test_metadata_with_etype_exception(etypes):
)
def test_homo_graph(num_nodes, num_edges):
csc_indptr, indices = gbt.random_homo_graph(num_nodes, num_edges)
graph = gb.from_csc(csc_indptr, indices)
edge_attributes = {
"A1": torch.randn(num_edges),
"A2": torch.randn(num_edges),
}
graph = gb.from_csc(csc_indptr, indices, edge_attributes=edge_attributes)
assert graph.num_nodes == num_nodes
assert graph.num_edges == num_edges
......@@ -115,6 +120,7 @@ def test_homo_graph(num_nodes, num_edges):
assert torch.equal(csc_indptr, graph.csc_indptr)
assert torch.equal(indices, graph.indices)
assert graph.edge_attributes == edge_attributes
assert graph.metadata is None
assert graph.node_type_offset is None
assert graph.type_per_edge is None
......@@ -136,8 +142,17 @@ def test_hetero_graph(num_nodes, num_edges, num_ntypes, num_etypes):
type_per_edge,
metadata,
) = gbt.random_hetero_graph(num_nodes, num_edges, num_ntypes, num_etypes)
edge_attributes = {
"A1": torch.randn(num_edges),
"A2": torch.randn(num_edges),
}
graph = gb.from_csc(
csc_indptr, indices, node_type_offset, type_per_edge, metadata
csc_indptr,
indices,
node_type_offset,
type_per_edge,
edge_attributes,
metadata,
)
assert graph.num_nodes == num_nodes
......@@ -147,6 +162,7 @@ def test_hetero_graph(num_nodes, num_edges, num_ntypes, num_etypes):
assert torch.equal(indices, graph.indices)
assert torch.equal(node_type_offset, graph.node_type_offset)
assert torch.equal(type_per_edge, graph.type_per_edge)
assert graph.edge_attributes == edge_attributes
assert metadata.node_type_to_id == graph.metadata.node_type_to_id
assert metadata.edge_type_to_id == graph.metadata.edge_type_to_id
......@@ -170,7 +186,7 @@ def test_node_type_offset_wrong_legnth(node_type_offset):
)
with pytest.raises(Exception):
gb.from_csc(
csc_indptr, indices, node_type_offset, type_per_edge, metadata
csc_indptr, indices, node_type_offset, type_per_edge, None, metadata
)
......@@ -218,7 +234,7 @@ def test_load_save_hetero_graph(num_nodes, num_edges, num_ntypes, num_etypes):
metadata,
) = gbt.random_hetero_graph(num_nodes, num_edges, num_ntypes, num_etypes)
graph = gb.from_csc(
csc_indptr, indices, node_type_offset, type_per_edge, metadata
csc_indptr, indices, node_type_offset, type_per_edge, None, metadata
)
with tempfile.TemporaryDirectory() as test_dir:
......@@ -323,7 +339,7 @@ def test_in_subgraph_heterogeneous():
# Construct CSCSamplingGraph.
metadata = gb.GraphMetadata(ntypes, etypes)
graph = gb.from_csc(
indptr, indices, node_type_offset, type_per_edge, metadata
indptr, indices, node_type_offset, type_per_edge, None, metadata
)
# Extract in subgraph.
......@@ -662,7 +678,7 @@ def test_hetero_graph_on_shared_memory(
metadata,
) = gbt.random_hetero_graph(num_nodes, num_edges, num_ntypes, num_etypes)
graph = gb.from_csc(
csc_indptr, indices, node_type_offset, type_per_edge, metadata
csc_indptr, indices, node_type_offset, type_per_edge, None, metadata
)
shm_name = "test_hetero_g"
......
......@@ -684,7 +684,7 @@ def test_OnDiskDataset_Graph_heterogeneous():
metadata,
) = gbt.random_hetero_graph(1000, 10 * 1000, 3, 4)
graph = gb.from_csc(
csc_indptr, indices, node_type_offset, type_per_edge, metadata
csc_indptr, indices, node_type_offset, type_per_edge, None, metadata
)
with tempfile.TemporaryDirectory() as test_dir:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment