subgraph.cc 5.36 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
/*!
 *  Copyright (c) 2020 by Contributors
 * \file graph/subgraph.cc
 * \brief Functions for extracting subgraphs.
 */
#include "./heterograph.h"
using namespace dgl::runtime;

namespace dgl {

11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
/*!
 * \brief Build the in-edge subgraph of the given vertices through graph->EdgeSubgraph.
 *
 * For each edge type, the IDs of the edges returned by graph->InEdges for the vertices
 * listed under that edge type's destination vertex type are gathered. When that vertex
 * list is a null array, an empty ID array is used instead and InEdges is not called.
 *
 * \param graph The graph whose edges are queried.
 * \param vids One vertex ID array per vertex type; the size is verified with CHECK_EQ.
 * \return Whatever graph->EdgeSubgraph yields for the gathered per-type edge IDs.
 * \note The literal `false` passed to EdgeSubgraph presumably means "do not preserve
 *       nodes" (i.e. relabel them) -- confirm against EdgeSubgraph's declaration.
 */
HeteroSubgraph InEdgeGraphRelabelNodes(
    const HeteroGraphPtr graph, const std::vector<IdArray>& vids) {
  CHECK_EQ(vids.size(), graph->NumVertexTypes())
    << "Invalid input: the input list size must be the same as the number of vertex types.";
  const auto num_etypes = graph->NumEdgeTypes();
  std::vector<IdArray> eids(num_etypes);
  for (dgl_type_t etype = 0; etype < num_etypes; ++etype) {
    // Only the destination endpoint type matters when looking up in-edges.
    const dgl_type_t dst_vtype = graph->meta_graph()->FindEdge(etype).second;
    const IdArray& seeds = vids[dst_vtype];
    // The temporary edge array returned by InEdges stays alive until the assignment is done.
    eids[etype] = aten::IsNullArray(seeds)
      ? IdArray::Empty({0}, graph->DataType(), graph->Context())
      : graph->InEdges(etype, {seeds}).id;
  }
  return graph->EdgeSubgraph(eids, false);
}

/*!
 * \brief Build the in-edge subgraph of the given vertices while keeping the vertex
 *        counts of the input graph.
 *
 * One relation graph is produced per edge type. If the vertex list for the edge type's
 * destination vertex type is a null array, the relation graph is an empty UnitGraph and
 * the induced edge array is empty. Otherwise the relation graph is created from the
 * src/dst arrays returned by graph->InEdges and the induced edge array holds the
 * returned edge IDs. Every relation graph is sized with graph->NumVertices of the
 * source and destination vertex types.
 *
 * \param graph The graph whose edges are queried.
 * \param vids One vertex ID array per vertex type; the size is verified with CHECK_EQ.
 * \return A HeteroSubgraph whose `graph` and `induced_edges` fields are assigned here;
 *         no other field is written by this function.
 */
HeteroSubgraph InEdgeGraphNoRelabelNodes(
    const HeteroGraphPtr graph, const std::vector<IdArray>& vids) {
  // TODO(mufei): This should also use EdgeSubgraph once it is supported for CSR graphs
  CHECK_EQ(vids.size(), graph->NumVertexTypes())
    << "Invalid input: the input list size must be the same as the number of vertex types.";
  const auto num_etypes = graph->NumEdgeTypes();
  std::vector<HeteroGraphPtr> subrels(num_etypes);
  std::vector<IdArray> induced_edges(num_etypes);
  for (dgl_type_t etype = 0; etype < num_etypes; ++etype) {
    const auto endpoints = graph->meta_graph()->FindEdge(etype);
    const dgl_type_t src_vtype = endpoints.first;
    const dgl_type_t dst_vtype = endpoints.second;
    // Dimensions shared by both the placeholder and the populated relation graph.
    const auto num_rel_vtypes = graph->GetRelationGraph(etype)->NumVertexTypes();
    const auto num_src = graph->NumVertices(src_vtype);
    const auto num_dst = graph->NumVertices(dst_vtype);

    if (!aten::IsNullArray(vids[dst_vtype])) {
      const auto edges = graph->InEdges(etype, {vids[dst_vtype]});
      subrels[etype] = UnitGraph::CreateFromCOO(
        num_rel_vtypes, num_src, num_dst, edges.src, edges.dst);
      induced_edges[etype] = edges.id;
      continue;
    }

    // No seed vertices of the destination type: create a placeholder graph.
    subrels[etype] = UnitGraph::Empty(
      num_rel_vtypes, num_src, num_dst, graph->DataType(), graph->Context());
    induced_edges[etype] = IdArray::Empty({0}, graph->DataType(), graph->Context());
  }

  HeteroSubgraph subg;
  subg.graph = CreateHeteroGraph(graph->meta_graph(), subrels, graph->NumVerticesPerType());
  subg.induced_edges = std::move(induced_edges);
  return subg;
}

66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
/*!
 * \brief Extract the subgraph formed by the in-edges of the given vertices.
 *
 * \param graph The graph whose edges are queried.
 * \param vids One vertex ID array per vertex type.
 * \param relabel_nodes Selects the implementation: InEdgeGraphRelabelNodes when true,
 *        InEdgeGraphNoRelabelNodes otherwise.
 * \return The subgraph produced by the selected implementation.
 */
HeteroSubgraph InEdgeGraph(
    const HeteroGraphPtr graph, const std::vector<IdArray>& vids, bool relabel_nodes) {
  if (!relabel_nodes) {
    return InEdgeGraphNoRelabelNodes(graph, vids);
  }
  return InEdgeGraphRelabelNodes(graph, vids);
}

/*!
 * \brief Build the out-edge subgraph of the given vertices through graph->EdgeSubgraph.
 *
 * For each edge type, the IDs of the edges returned by graph->OutEdges for the vertices
 * listed under that edge type's source vertex type are gathered. When that vertex list
 * is a null array, an empty ID array is used instead and OutEdges is not called.
 *
 * \param graph The graph whose edges are queried.
 * \param vids One vertex ID array per vertex type; the size is verified with CHECK_EQ.
 * \return Whatever graph->EdgeSubgraph yields for the gathered per-type edge IDs.
 * \note The literal `false` passed to EdgeSubgraph presumably means "do not preserve
 *       nodes" (i.e. relabel them) -- confirm against EdgeSubgraph's declaration.
 */
HeteroSubgraph OutEdgeGraphRelabelNodes(
    const HeteroGraphPtr graph, const std::vector<IdArray>& vids) {
  CHECK_EQ(vids.size(), graph->NumVertexTypes())
    << "Invalid input: the input list size must be the same as the number of vertex types.";
  const auto num_etypes = graph->NumEdgeTypes();
  std::vector<IdArray> eids(num_etypes);
  for (dgl_type_t etype = 0; etype < num_etypes; ++etype) {
    // Only the source endpoint type matters when looking up out-edges.
    const dgl_type_t src_vtype = graph->meta_graph()->FindEdge(etype).first;
    const IdArray& seeds = vids[src_vtype];
    // The temporary edge array returned by OutEdges stays alive until the assignment is done.
    eids[etype] = aten::IsNullArray(seeds)
      ? IdArray::Empty({0}, graph->DataType(), graph->Context())
      : graph->OutEdges(etype, {seeds}).id;
  }
  return graph->EdgeSubgraph(eids, false);
}

/*!
 * \brief Build the out-edge subgraph of the given vertices while keeping the vertex
 *        counts of the input graph.
 *
 * One relation graph is produced per edge type. If the vertex list for the edge type's
 * source vertex type is a null array, the relation graph is an empty UnitGraph and the
 * induced edge array is empty. Otherwise the relation graph is created from the src/dst
 * arrays returned by graph->OutEdges and the induced edge array holds the returned edge
 * IDs. Every relation graph is sized with graph->NumVertices of the source and
 * destination vertex types.
 *
 * \param graph The graph whose edges are queried.
 * \param vids One vertex ID array per vertex type; the size is verified with CHECK_EQ.
 * \return A HeteroSubgraph whose `graph` and `induced_edges` fields are assigned here;
 *         no other field is written by this function.
 */
HeteroSubgraph OutEdgeGraphNoRelabelNodes(
    const HeteroGraphPtr graph, const std::vector<IdArray>& vids) {
  // TODO(mufei): This should also use EdgeSubgraph once it is supported for CSR graphs
  CHECK_EQ(vids.size(), graph->NumVertexTypes())
    << "Invalid input: the input list size must be the same as the number of vertex types.";
  const auto num_etypes = graph->NumEdgeTypes();
  std::vector<HeteroGraphPtr> subrels(num_etypes);
  std::vector<IdArray> induced_edges(num_etypes);
  for (dgl_type_t etype = 0; etype < num_etypes; ++etype) {
    const auto endpoints = graph->meta_graph()->FindEdge(etype);
    const dgl_type_t src_vtype = endpoints.first;
    const dgl_type_t dst_vtype = endpoints.second;
    // Dimensions shared by both the placeholder and the populated relation graph.
    const auto num_rel_vtypes = graph->GetRelationGraph(etype)->NumVertexTypes();
    const auto num_src = graph->NumVertices(src_vtype);
    const auto num_dst = graph->NumVertices(dst_vtype);

    if (!aten::IsNullArray(vids[src_vtype])) {
      const auto edges = graph->OutEdges(etype, {vids[src_vtype]});
      subrels[etype] = UnitGraph::CreateFromCOO(
        num_rel_vtypes, num_src, num_dst, edges.src, edges.dst);
      induced_edges[etype] = edges.id;
      continue;
    }

    // No seed vertices of the source type: create a placeholder graph.
    subrels[etype] = UnitGraph::Empty(
      num_rel_vtypes, num_src, num_dst, graph->DataType(), graph->Context());
    induced_edges[etype] = IdArray::Empty({0}, graph->DataType(), graph->Context());
  }

  HeteroSubgraph subg;
  subg.graph = CreateHeteroGraph(graph->meta_graph(), subrels, graph->NumVerticesPerType());
  subg.induced_edges = std::move(induced_edges);
  return subg;
}

130
131
132
133
134
135
136
137
138
/*!
 * \brief Extract the subgraph formed by the out-edges of the given vertices.
 *
 * \param graph The graph whose edges are queried.
 * \param vids One vertex ID array per vertex type.
 * \param relabel_nodes Selects the implementation: OutEdgeGraphRelabelNodes when true,
 *        OutEdgeGraphNoRelabelNodes otherwise.
 * \return The subgraph produced by the selected implementation.
 */
HeteroSubgraph OutEdgeGraph(
    const HeteroGraphPtr graph, const std::vector<IdArray>& vids, bool relabel_nodes) {
  if (!relabel_nodes) {
    return OutEdgeGraphNoRelabelNodes(graph, vids);
  }
  return OutEdgeGraphRelabelNodes(graph, vids);
}

139
}  // namespace dgl