from collections import defaultdict as ddict

import numpy as np
import torch
from ordered_set import OrderedSet
from torch.utils.data import DataLoader, Dataset

import dgl


class TrainDataset(Dataset):
    """
    Training Dataset class.
    Parameters
    ----------
    triples: The triples used for training the model
    num_ent: Number of entities in the knowledge graph
    lbl_smooth: Label smoothing factor

    Returns
    -------
    A training Dataset class instance used by DataLoader
    """

    def __init__(self, triples, num_ent, lbl_smooth):
        self.triples = triples
        self.num_ent = num_ent
        self.lbl_smooth = lbl_smooth
        self.entities = np.arange(self.num_ent, dtype=np.int32)

    def __len__(self):
        return len(self.triples)

    def __getitem__(self, idx):
        ele = self.triples[idx]
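        # each element stores a (sub_id, rel_id, -1) query together with the
        # list of all gold object ids for that (subject, relation) pair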
        triple, label = torch.LongTensor(ele["triple"]), np.int32(ele["label"])
        trp_label = self.get_label(label)
        # label smoothing
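        # smoothed target: (1 - eps) * y + 1 / num_ent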
        if self.lbl_smooth != 0.0:
            trp_label = (1.0 - self.lbl_smooth) * trp_label + (
                1.0 / self.num_ent
            )

        return triple, trp_label

    @staticmethod
    def collate_fn(data):
        triples = []
        labels = []
        for triple, label in data:
            triples.append(triple)
            labels.append(label)
        triple = torch.stack(triples, dim=0)
        trp_label = torch.stack(labels, dim=0)
        return triple, trp_label

    # for edges that exist in the graph, the entry is 1.0, otherwise the entry is 0.0
    def get_label(self, label):
        y = np.zeros([self.num_ent], dtype=np.float32)
        for e2 in label:
            y[e2] = 1.0
        return torch.FloatTensor(y)


class TestDataset(Dataset):
    """
    Evaluation Dataset class.
    Parameters
    ----------
    triples: The triples used for evaluating the model
    num_ent: Number of entities in the knowledge graph

    Returns
    -------
    An evaluation Dataset class instance used by DataLoader for model evaluation
    """

    def __init__(self, triples, num_ent):
        self.triples = triples
        self.num_ent = num_ent

    def __len__(self):
        return len(self.triples)

    def __getitem__(self, idx):
        ele = self.triples[idx]
        triple, label = torch.LongTensor(ele["triple"]), np.int32(ele["label"])
        label = self.get_label(label)

        return triple, label

    @staticmethod
    def collate_fn(data):
        triples = []
        labels = []
        for triple, label in data:
            triples.append(triple)
            labels.append(label)
        triple = torch.stack(triples, dim=0)
        label = torch.stack(labels, dim=0)
        return triple, label

    # for edges that exist in the graph, the entry is 1.0, otherwise the entry is 0.0
    def get_label(self, label):
        y = np.zeros([self.num_ent], dtype=np.float32)
        for e2 in label:
            y[e2] = 1.0
        return torch.FloatTensor(y)


class Data(object):
    def __init__(self, dataset, lbl_smooth, num_workers, batch_size):
        """
        Reads in raw triples and converts them into a standard format.
        Parameters
        ----------
        dataset:           The name of the dataset
        lbl_smooth:        Label smoothing factor
        num_workers:       Number of workers used by the dataloaders
        batch_size:        Batch size of the dataloaders

        Returns
        -------
        self.ent2id:            Entity to unique identifier mapping
        self.rel2id:            Relation to unique identifier mapping
        self.id2ent:            Inverse mapping of self.ent2id
        self.id2rel:            Inverse mapping of self.rel2id
        self.num_ent:           Number of entities in the knowledge graph
        self.num_rel:           Number of relations in the knowledge graph

        self.g:                 The dgl graph constructed from the edges in the training set and all the entities in the knowledge graph
        self.data['train']:     Stores the triples corresponding to the training dataset
        self.data['valid']:     Stores the triples corresponding to the validation dataset
        self.data['test']:      Stores the triples corresponding to the test dataset
        self.data_iter:         The dataloaders for the different data splits
        """
        self.dataset = dataset
        self.lbl_smooth = lbl_smooth
        self.num_workers = num_workers
        self.batch_size = batch_size

        # read in raw data and get mappings
        ent_set, rel_set = OrderedSet(), OrderedSet()
        for split in ["train", "test", "valid"]:
            for line in open("./{}/{}.txt".format(self.dataset, split)):
                sub, rel, obj = map(str.lower, line.strip().split("\t"))
                ent_set.add(sub)
                rel_set.add(rel)
                ent_set.add(obj)

        self.ent2id = {ent: idx for idx, ent in enumerate(ent_set)}
        self.rel2id = {rel: idx for idx, rel in enumerate(rel_set)}
        self.rel2id.update(
            {
                rel + "_reverse": idx + len(self.rel2id)
                for idx, rel in enumerate(rel_set)
            }
        )

        self.id2ent = {idx: ent for ent, idx in self.ent2id.items()}
        self.id2rel = {idx: rel for rel, idx in self.rel2id.items()}

        self.num_ent = len(self.ent2id)
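        # rel2id stores each relation together with its "_reverse" counterpart,
        # so the number of distinct relations is half the mapping size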
        self.num_rel = len(self.rel2id) // 2

        # read in ids of subjects, relations, and objects for train/test/valid
        self.data = ddict(list)  # stores the triples
        sr2o = ddict(
            set
        )  # maps (subject, relation) to the set of all valid objects for that pair
        src = []
        dst = []
        rels = []
        inver_src = []
        inver_dst = []
        inver_rels = []

        for split in ["train", "test", "valid"]:
            for line in open("./{}/{}.txt".format(self.dataset, split)):
                sub, rel, obj = map(str.lower, line.strip().split("\t"))
                sub_id, rel_id, obj_id = (
                    self.ent2id[sub],
                    self.rel2id[rel],
                    self.ent2id[obj],
                )
                self.data[split].append((sub_id, rel_id, obj_id))

                if split == "train":
                    sr2o[(sub_id, rel_id)].add(obj_id)
                    sr2o[(obj_id, rel_id + self.num_rel)].add(
                        sub_id
                    )  # append the reversed edges
                    src.append(sub_id)
                    dst.append(obj_id)
                    rels.append(rel_id)
                    inver_src.append(obj_id)
                    inver_dst.append(sub_id)
                    inver_rels.append(rel_id + self.num_rel)

        # construct dgl graph
        src = src + inver_src
        dst = dst + inver_dst
        rels = rels + inver_rels
        self.g = dgl.graph((src, dst), num_nodes=self.num_ent)
        self.g.edata["etype"] = torch.Tensor(rels).long()

        # identify in and out edges: the first half of the edge list holds the
        # original (subject -> object) edges, the second half their reverses
        in_edges_mask = [True] * (self.g.num_edges() // 2) + [False] * (
            self.g.num_edges() // 2
        )
        out_edges_mask = [False] * (self.g.num_edges() // 2) + [True] * (
            self.g.num_edges() // 2
        )
        self.g.edata["in_edges_mask"] = torch.Tensor(in_edges_mask)
        self.g.edata["out_edges_mask"] = torch.Tensor(out_edges_mask)

        # Prepare train/valid/test data
        self.data = dict(self.data)
        self.sr2o = {
            k: list(v) for k, v in sr2o.items()
        }  # store only the train data

        for split in ["test", "valid"]:
            for sub, rel, obj in self.data[split]:
                sr2o[(sub, rel)].add(obj)
                sr2o[(obj, rel + self.num_rel)].add(sub)

        self.sr2o_all = {
            k: list(v) for k, v in sr2o.items()
        }  # store all the data (train + valid + test), used for filtered evaluation
        self.triples = ddict(list)

        for (sub, rel), obj in self.sr2o.items():
            self.triples["train"].append(
                {"triple": (sub, rel, -1), "label": self.sr2o[(sub, rel)]}
            )

        for split in ["test", "valid"]:
            for sub, rel, obj in self.data[split]:
                rel_inv = rel + self.num_rel
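                # each evaluation triple yields a tail-prediction query
                # (sub, rel, ?) and a head-prediction query (obj, rel_inv, ?)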
                self.triples["{}_{}".format(split, "tail")].append(
                    {
                        "triple": (sub, rel, obj),
                        "label": self.sr2o_all[(sub, rel)],
                    }
                )
                self.triples["{}_{}".format(split, "head")].append(
                    {
                        "triple": (obj, rel_inv, sub),
                        "label": self.sr2o_all[(obj, rel_inv)],
                    }
                )

        self.triples = dict(self.triples)

        def get_train_data_loader(split, batch_size, shuffle=True):
            return DataLoader(
                TrainDataset(
                    self.triples[split], self.num_ent, self.lbl_smooth
                ),
                batch_size=batch_size,
                shuffle=shuffle,
                num_workers=max(0, self.num_workers),
                collate_fn=TrainDataset.collate_fn,
            )

        def get_test_data_loader(split, batch_size, shuffle=True):
            return DataLoader(
                TestDataset(self.triples[split], self.num_ent),
                batch_size=batch_size,
                shuffle=shuffle,
                num_workers=max(0, self.num_workers),
                collate_fn=TestDataset.collate_fn,
            )

        # train/valid/test dataloaders
        self.data_iter = {
            "train": get_train_data_loader("train", self.batch_size),
            "valid_head": get_test_data_loader("valid_head", self.batch_size),
            "valid_tail": get_test_data_loader("valid_tail", self.batch_size),
            "test_head": get_test_data_loader("test_head", self.batch_size),
            "test_tail": get_test_data_loader("test_tail", self.batch_size),
        }
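

if __name__ == "__main__":
    # Minimal usage sketch (not part of the original pipeline). It assumes a
    # ./wn18rr/ directory (a hypothetical dataset name) containing train.txt,
    # valid.txt, and test.txt with one tab-separated
    # (subject, relation, object) triple per line.
    data = Data(dataset="wn18rr", lbl_smooth=0.1, num_workers=0, batch_size=128)
    print(data.g)
    # each training batch: a (batch_size, 3) triple tensor and a
    # (batch_size, num_ent) multi-hot (smoothed) label tensor
    triple, trp_label = next(iter(data.data_iter["train"]))
    print(triple.shape, trp_label.shape)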