"git@developer.sourcefind.cn:change/sglang.git" did not exist on "016fd2512788f029d6d9a7c866c12dadb34cde9c"
Prepare Data
============

In this section, we prepare the data for the Graphormer model introduced earlier. Any dataset containing :class:`~dgl.DGLGraph` objects can be used, together with a standard PyTorch ``DataLoader``, to feed data to the model. The key is to define a collate function that groups the features of multiple graphs into padded batches. An example collate function is shown below:


.. code:: python

    import dgl
    import torch as th


    def collate(graphs):
        # compute shortest path features, can be done in advance
        for g in graphs:
            spd, path = dgl.shortest_dist(g, root=None, return_paths=True)
            g.ndata["spd"] = spd
            g.ndata["path"] = path

        num_graphs = len(graphs)
        num_nodes = [g.num_nodes() for g in graphs]
        max_num_nodes = max(num_nodes)

        attn_mask = th.zeros(num_graphs, max_num_nodes, max_num_nodes)
        node_feat = []
        in_degree, out_degree = [], []
        path_data = []
        # shortest_dist returns -1 for unreachable node pairs. Padded nodes
        # are unreachable from all other nodes, so distances involving
        # padded nodes are filled with -1 as well.
        dist = -th.ones(
            (num_graphs, max_num_nodes, max_num_nodes), dtype=th.long
        )

        for i in range(num_graphs):
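            # A binary mask where invalid (padded) positions are marked with 1.
            # Avoid the case where all positions are invalid.
            # No virtual node is prepended in this simplified example, so the
            # padded positions start at index num_nodes[i].
            attn_mask[i, :, num_nodes[i] :] = 1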

            # +1 to distinguish padded non-existing nodes from real nodes
            node_feat.append(graphs[i].ndata["feat"] + 1)

            # 0 for padding
            in_degree.append(
                th.clamp(graphs[i].in_degrees() + 1, min=0, max=512)
            )
            out_degree.append(
                th.clamp(graphs[i].out_degrees() + 1, min=0, max=512)
            )

            # Pad or truncate all paths to the same length, max_len.
            path = graphs[i].ndata["path"]
            path_len = path.size(dim=2)
            # shape of shortest_path: [n, n, max_len]
            max_len = 5
            if path_len >= max_len:
                shortest_path = path[:, :, :max_len]
            else:
                p1d = (0, max_len - path_len)
                # Use the same -1 padding as shortest_dist for
                # invalid edge IDs.
                shortest_path = th.nn.functional.pad(path, p1d, "constant", -1)
            pad_num_nodes = max_num_nodes - num_nodes[i]
            p3d = (0, 0, 0, pad_num_nodes, 0, pad_num_nodes)
            shortest_path = th.nn.functional.pad(shortest_path, p3d, "constant", -1)
            # +1 to distinguish padded non-existing edges from real edges
            edata = graphs[i].edata["feat"] + 1

            # shortest_dist pads non-existing edges (at the end of shortest
            # paths) with edge IDs -1, and th.zeros(1, edata.shape[1]) stands
            # for all padded edge features.
            edata = th.cat(
                (edata, th.zeros(1, edata.shape[1]).to(edata.device)), dim=0
            )
            path_data.append(edata[shortest_path])

            dist[i, : num_nodes[i], : num_nodes[i]] = graphs[i].ndata["spd"]

        # node feat padding
        node_feat = th.nn.utils.rnn.pad_sequence(node_feat, batch_first=True)

        # degree padding
        in_degree = th.nn.utils.rnn.pad_sequence(in_degree, batch_first=True)
        out_degree = th.nn.utils.rnn.pad_sequence(out_degree, batch_first=True)

        return (
            node_feat,
            in_degree,
            out_degree,
            attn_mask,
            th.stack(path_data),
            dist,
        )
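
The ``-1`` padding of edge IDs works together with the extra all-zero row appended to ``edata``: indexing with ``-1`` selects the last row, so padded path steps receive all-zero features. The following is a minimal sketch of that lookup on hypothetical toy tensors. The values and the ``edge_feat`` name are illustrative assumptions, and the integer dtype is kept only to make the result easy to read:

.. code:: python

    import torch as th

    # Hypothetical toy graph: 3 edges with 2 integer features each.
    edge_feat = th.tensor([[0, 1], [2, 0], [1, 1]])

    # Shift by 1 so that 0 is reserved for padding, then append one
    # all-zero row that stands for "no edge".
    edata = edge_feat + 1
    edata = th.cat(
        (edata, th.zeros(1, edata.shape[1], dtype=edata.dtype)), dim=0
    )

    # Edge IDs along two shortest paths, padded with -1 to length 3.
    shortest_path = th.tensor([[0, 2, -1], [1, -1, -1]])

    # Index -1 selects the appended last row, so every padded step
    # is mapped to all-zero features.
    path_data = edata[shortest_path]  # shape: [2, 3, 2]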

The ``collate`` function above omits some details, such as the addition of a virtual node. For the complete implementation, please refer to the `Graphormer example <https://github.com/dmlc/dgl/tree/master/examples/core/Graphormer>`_.
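
To illustrate how a collate function plugs into a standard PyTorch ``DataLoader``, here is a minimal sketch. It uses plain tensors as hypothetical stand-ins for the node features of each graph. ``toy_collate`` and the toy dataset are illustrative assumptions and are not part of DGL or the Graphormer example:

.. code:: python

    import torch as th
    from torch.utils.data import DataLoader


    def toy_collate(samples):
        """Pad per-graph node features and build a padding mask.

        ``samples`` is a list of [num_nodes, feat_dim] integer tensors,
        hypothetical stand-ins for ``g.ndata["feat"]``.
        """
        num_nodes = [s.shape[0] for s in samples]
        max_num_nodes = max(num_nodes)

        # +1 so that 0 is reserved for padded (non-existing) nodes.
        node_feat = th.nn.utils.rnn.pad_sequence(
            [s + 1 for s in samples], batch_first=True
        )

        # True marks padded positions that attention should ignore.
        attn_mask = th.zeros(
            len(samples), max_num_nodes, max_num_nodes, dtype=th.bool
        )
        for i, n in enumerate(num_nodes):
            attn_mask[i, :, n:] = True

        return node_feat, attn_mask


    # Three toy "graphs" with 2, 4 and 3 nodes, 5 features per node.
    dataset = [th.randint(0, 10, (n, 5)) for n in (2, 4, 3)]
    loader = DataLoader(
        dataset, batch_size=3, shuffle=False, collate_fn=toy_collate
    )

    node_feat, attn_mask = next(iter(loader))
    print(node_feat.shape, attn_mask.shape)

With a dataset of :class:`~dgl.DGLGraph` objects, the same pattern should apply by passing the ``collate`` function defined above as ``collate_fn``.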