rgnn.py 3.37 KB
Newer Older
yangzhong's avatar
yangzhong committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
# This script was taken from:
# https://github.com/mlcommons/training/blob/master/graph_neural_network/rgnn.py

import torch
import torch.nn.functional as F

from torch_geometric.nn import HeteroConv, GATConv, GCNConv, SAGEConv
from torch_geometric.utils import trim_to_layer


class RGNN(torch.nn.Module):
    r"""[Relational GNN model](https://arxiv.org/abs/1703.06103).

    The network is a stack of ``num_layers`` ``HeteroConv`` layers. Each
    layer holds one ``SAGEConv`` (``model="rsage"``) or one ``GATConv``
    (``model="rgat"``) per edge type. When ``node_type`` is given, a final
    linear layer maps the hidden representation of that node type to
    ``out_dim``; otherwise the last conv layer itself produces ``out_dim``.

    Args:
      etypes: edge types. Used as keys of the per-layer ``HeteroConv``
        dictionaries.
      in_dim: input size.
      h_dim: Dimension of hidden layer.
      out_dim: Output dimension.
      num_layers: Number of conv layers.
      dropout: Dropout probability for hidden layers.
      model: "rsage" or "rgat".
      heads: Number of multi-head-attentions for GAT. Each head gets
        ``dim // heads`` channels, so dimensions that are not a multiple
        of ``heads`` are rounded down.
      node_type: The predict node type for node classification. If
        ``None``, ``forward`` returns the full dictionary of node
        embeddings instead of predictions for one node type.
      with_trim: If ``True``, ``forward`` passes its inputs through
        ``trim_to_layer`` before every conv layer.

    Raises:
      ValueError: If ``model`` is neither "rsage" nor "rgat".

    """

    def __init__(
        self,
        etypes,
        in_dim,
        h_dim,
        out_dim,
        num_layers=2,
        dropout=0.2,
        model="rgat",
        heads=4,
        node_type=None,
        with_trim=False,
    ):
        super().__init__()
        # Fail loudly: an unknown model name used to yield a network with
        # no conv layers at all.
        if model not in ("rsage", "rgat"):
            raise ValueError(
                f"Unsupported model {model!r}; expected 'rsage' or 'rgat'."
            )

        self.node_type = node_type
        if node_type is not None:
            # Classification head applied to the hidden features of
            # `node_type` after the last conv layer.
            self.lin = torch.nn.Linear(h_dim, out_dim)

        self.convs = torch.nn.ModuleList()
        for i in range(num_layers):
            layer_in = in_dim if i == 0 else h_dim
            # Without a classification head the last conv layer has to
            # emit `out_dim` directly.
            is_last = i == num_layers - 1
            layer_out = out_dim if (is_last and node_type is None) else h_dim

            if model == "rsage":
                per_etype = {
                    etype: SAGEConv(layer_in, layer_out, root_weight=False)
                    for etype in etypes
                }
            else:  # model == "rgat"
                # NOTE: the concatenated head outputs have size
                # (layer_out // heads) * heads, which only equals
                # `layer_out` when it is divisible by `heads`.
                per_etype = {
                    etype: GATConv(
                        layer_in,
                        layer_out // heads,
                        heads=heads,
                        add_self_loops=False,
                    )
                    for etype in etypes
                }
            self.convs.append(HeteroConv(per_etype))

        self.dropout = torch.nn.Dropout(dropout)
        self.with_trim = with_trim

    def forward(
        self,
        x_dict,
        edge_index_dict,
        num_sampled_edges_dict=None,
        num_sampled_nodes_dict=None,
    ):
        """Run the conv stack and, if configured, the classification head.

        Args:
          x_dict: Mapping from node type to node feature tensor.
          edge_index_dict: Mapping from edge type to edge index. Keys are
            indexed with ``key[0]`` and ``key[-1]`` to find the two
            endpoint node types. The mapping passed in is not modified.
          num_sampled_edges_dict: Forwarded to ``trim_to_layer`` as
            ``num_sampled_edges_per_hop``; only used when ``with_trim``
            is set.
          num_sampled_nodes_dict: Forwarded to ``trim_to_layer`` as
            ``num_sampled_nodes_per_hop``; only used when ``with_trim``
            is set.

        Returns:
          ``self.lin(x_dict[self.node_type])`` when the model was built
          with a ``node_type``; otherwise the dictionary of node
          embeddings produced by the last conv layer.
        """
        last_layer = len(self.convs) - 1
        for i, conv in enumerate(self.convs):
            if self.with_trim:
                x_dict, edge_index_dict, _ = trim_to_layer(
                    layer=i,
                    num_sampled_nodes_per_hop=num_sampled_nodes_dict,
                    num_sampled_edges_per_hop=num_sampled_edges_dict,
                    x=x_dict,
                    edge_index=edge_index_dict,
                )

            # Drop edge types whose endpoint node types have no features.
            # A new dict is built so the caller's mapping is left intact.
            edge_index_dict = {
                key: edge_index
                for key, edge_index in edge_index_dict.items()
                if key[0] in x_dict and key[-1] in x_dict
            }

            x_dict = conv(x_dict, edge_index_dict)
            if i != last_layer:
                # Non-linearity and dropout on every layer but the last.
                x_dict = {
                    key: self.dropout(F.leaky_relu(x))
                    for key, x in x_dict.items()
                }

        if hasattr(self, "lin"):  # for node classification
            return self.lin(x_dict[self.node_type])
        return x_dict