"vscode:/vscode.git/clone" did not exist on "770d3b3c295ead9b6131e24b6dea7db6e89e3f10"
builder.py 5.01 KB
Newer Older
1
2
"""Graph builder from pandas dataframes"""
from collections import namedtuple
3
4
5
6
7
8
9

from pandas.api.types import (
    is_categorical,
    is_categorical_dtype,
    is_numeric_dtype,
)

10
11
import dgl

12
13
__all__ = ["PandasGraphBuilder"]

14
15
16

def _series_to_tensor(series):
    """Convert a pandas Series into a torch tensor.

    Parameters
    ----------
    series : pandas.Series
        Either a categorical series or a numeric one.

    Returns
    -------
    torch.Tensor
        A ``LongTensor`` holding the category codes when ``series`` has a
        categorical dtype, otherwise a ``FloatTensor`` built from the
        series' values.
    """
    # Imported locally so the helper is self-contained: this module does not
    # import torch at the top level.
    import torch
    from pandas import CategoricalDtype

    # ``pandas.api.types.is_categorical`` was removed in pandas 2.0, so the
    # dtype is inspected directly instead.
    if isinstance(series.dtype, CategoricalDtype):
        return torch.LongTensor(series.cat.codes.values.astype("int64"))
    # Anything that is not categorical is treated as numeric.
    return torch.FloatTensor(series.values)

21

22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
class PandasGraphBuilder(object):
    """Creates a heterogeneous graph from multiple pandas dataframes.

    Entities (node types) are registered with :meth:`add_entities`, relations
    (edge types) with :meth:`add_binary_relations`, and the graph is then
    produced by :meth:`build`.

    Examples
    --------
    Let's say we have the following three pandas dataframes:

    User table ``users``:

    ===========  ===========  =======
    ``user_id``  ``country``  ``age``
    ===========  ===========  =======
    XYZZY        U.S.         25
    FOO          China        24
    BAR          China        23
    ===========  ===========  =======

    Game table ``games``:

    ===========  =========  ==============  ==================
    ``game_id``  ``title``  ``is_sandbox``  ``is_multiplayer``
    ===========  =========  ==============  ==================
    1            Minecraft  True            True
    2            Tetris 99  False           True
    ===========  =========  ==============  ==================

    Play relationship table ``plays``:

    ===========  ===========  =========
    ``user_id``  ``game_id``  ``hours``
    ===========  ===========  =========
    XYZZY        1            24
    FOO          1            20
    FOO          2            16
    BAR          2            28
    ===========  ===========  =========

    One could then create a bidirectional bipartite graph as follows:

    >>> builder = PandasGraphBuilder()
    >>> builder.add_entities(users, 'user_id', 'user')
    >>> builder.add_entities(games, 'game_id', 'game')
    >>> builder.add_binary_relations(plays, 'user_id', 'game_id', 'plays')
    >>> builder.add_binary_relations(plays, 'game_id', 'user_id', 'played-by')
    >>> g = builder.build()
    >>> g.num_nodes('user')
    3
    >>> g.num_edges('plays')
    4
    """

    def __init__(self):
        """Create an empty builder with no entities or relations."""
        self.entity_tables = {}  # entity name -> original entity table
        self.relation_tables = {}  # relation name -> original relation table

        self.entity_pk_to_name = {}  # primary key column name -> entity name
        self.entity_pk = {}  # entity name -> primary key column name
        # entity name -> categorical series of primary key values, whose
        # category order follows the row order of the entity table
        self.entity_key_map = {}

        self.num_nodes_per_type = {}  # entity name -> number of rows
        # (srctype, relation name, dsttype) -> (src ids, dst ids)
        self.edges_per_relation = {}
        self.relation_name_to_etype = {}  # relation name -> edge type triple

        self.relation_src_key = {}  # relation name -> source key column
        self.relation_dst_key = {}  # relation name -> destination key column

    def add_entities(self, entity_table, primary_key, name):
        """Register a node type from a table of entities.

        Each row becomes one node; the node ID of an entity is the position
        of its row in ``entity_table``.

        Parameters
        ----------
        entity_table : pandas.DataFrame
            Table with one row per entity.
        primary_key : str
            Column of ``entity_table`` that uniquely identifies each row.
            Relations refer to this entity type through this column name, so
            registering another entity type with the same column name
            replaces the earlier column-to-entity mapping.
        name : str
            Name of the node type.

        Raises
        ------
        ValueError
            If a primary key value appears in more than one row.
        """
        key_column = entity_table[primary_key]
        entities = key_column.astype("category")
        if not (entities.value_counts() == 1).all():
            raise ValueError(
                "Different entity with the same primary key detected."
            )
        # preserve the category order in the original entity table, so that
        # category codes coincide with row positions
        entities = entities.cat.reorder_categories(key_column.values)

        self.entity_pk_to_name[primary_key] = name
        self.entity_pk[name] = primary_key
        self.num_nodes_per_type[name] = entity_table.shape[0]
        self.entity_key_map[name] = entities
        self.entity_tables[name] = entity_table

    def _entity_type_of(self, key, relation_name):
        """Return the entity name whose primary key column is ``key``."""
        try:
            return self.entity_pk_to_name[key]
        except KeyError:
            raise KeyError(
                "Key column %r used by relation %s is not the primary key of "
                "any registered entity; call add_entities first."
                % (key, relation_name)
            ) from None

    def _encode_keys(self, keys, entity_type):
        """Recode ``keys`` against the registered keys of ``entity_type``.

        Values that are not registered primary keys become missing values.
        """
        known_keys = self.entity_key_map[entity_type].cat.categories
        return keys.astype("category").cat.set_categories(known_keys)

    def add_binary_relations(
        self, relation_table, source_key, destination_key, name
    ):
        """Register an edge type from a table of (source, destination) pairs.

        Each row of ``relation_table`` becomes one edge of type
        ``(source entity name, name, destination entity name)``.

        Parameters
        ----------
        relation_table : pandas.DataFrame
            Table with one row per edge.
        source_key : str
            Column holding the source entities' primary key values.  It must
            be the primary key column name of an entity type previously
            registered with :meth:`add_entities`.
        destination_key : str
            Column holding the destination entities' primary key values,
            with the same requirement as ``source_key``.
        name : str
            Name of the relation.

        Raises
        ------
        KeyError
            If ``source_key`` or ``destination_key`` is not the primary key
            column of a registered entity type, or is not a column of
            ``relation_table``.
        ValueError
            If some source or destination value is not a registered primary
            key of the corresponding entity type.
        """
        srctype = self._entity_type_of(source_key, name)
        dsttype = self._entity_type_of(destination_key, name)

        src = self._encode_keys(relation_table[source_key], srctype)
        dst = self._encode_keys(relation_table[destination_key], dsttype)

        # All validation happens before any state is touched, so a rejected
        # relation leaves the builder unchanged.
        if src.isnull().any():
            raise ValueError(
                "Some source entities in relation %s do not exist in entity %s."
                % (name, source_key)
            )
        if dst.isnull().any():
            raise ValueError(
                "Some destination entities in relation %s do not exist in entity %s."
                % (name, destination_key)
            )

        etype = (srctype, name, dsttype)
        self.relation_name_to_etype[name] = etype
        # Category codes are the node IDs assigned by add_entities.
        self.edges_per_relation[etype] = (
            src.cat.codes.values.astype("int64"),
            dst.cat.codes.values.astype("int64"),
        )
        self.relation_tables[name] = relation_table
        self.relation_src_key[name] = source_key
        self.relation_dst_key[name] = destination_key

    def build(self):
        """Build the graph from the registered entities and relations.

        Returns
        -------
        The object returned by ``dgl.heterograph`` when given the collected
        edges per relation and the number of nodes per node type.
        """
        # Create heterograph
        graph = dgl.heterograph(
            self.edges_per_relation, self.num_nodes_per_type
        )
        return graph