# Copyright 2021 Yifei Ma
# with references from "sklearn.decomposition.LatentDirichletAllocation"
# with the following original authors:
# * Chyi-Kwei Yau (the said scikit-learn implementation)
# * Matthew D. Hoffman (original onlineldavb implementation)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import collections
import functools
import io
import os
import warnings

import numpy as np
import scipy as sp
import torch

import dgl

try:
    from functools import cached_property
except ImportError:
    try:
        from backports.cached_property import cached_property
    except ImportError:
        warnings.warn("cached_property not found - using property instead")
        cached_property = property


class EdgeData:
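    """Per-edge view of source and destination node data.

    Instances are created inside DGL edge UDFs (apply_edges / update_all),
    so every property below is evaluated batch-wise over the edges of the
    (possibly reversed) doc<->word graph.
    """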
    def __init__(self, src_data, dst_data):
        self.src_data = src_data
        self.dst_data = dst_data

    @property
    def loglike(self):
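        # per-edge log-likelihood: logsumexp_k(E[log theta_k] + E[log beta_k]),
        # where "Elog" is the Dirichlet expectation prepared on each endpoint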
        return (self.src_data["Elog"] + self.dst_data["Elog"]).logsumexp(1)

    @property
    def phi(self):
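        # variational responsibility of each topic on an edge:
        # phi_k proportional to exp(E[log theta_k] + E[log beta_k]),
        # normalized by the per-edge loglike above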
        return (
            self.src_data["Elog"]
            + self.dst_data["Elog"]
            - self.loglike.unsqueeze(1)
        ).exp()

    @property
    def expectation(self):
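        # marginal edge probability sum_k E[theta_k] * E[beta_k],
        # using plain Dirichlet means rather than exp(E[log .])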
        return (
            self.src_data["expectation"] * self.dst_data["expectation"]
        ).sum(1)


class _Dirichlet:
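    """A batch of independent Dirichlet posteriors sharing a scalar prior.

    nphi holds the accumulated sufficient statistics, one row per unit;
    the posterior concentration of each row is prior + nphi.
    """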
    def __init__(self, prior, nphi, _chunksize=int(1e6)):
        self.prior = prior
        self.nphi = nphi
        self.device = nphi.device
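        # reduce a column-wise map over nphi in chunks of _chunksize to bound
        # peak memory when the second dimension (e.g., the vocabulary) is large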
        self._sum_by_parts = lambda map_fn: functools.reduce(
            torch.add,
            [
                map_fn(slice(i, min(i + _chunksize, nphi.shape[1]))).sum(1)
                for i in list(range(0, nphi.shape[1], _chunksize))
            ],
        )

    def _posterior(self, _ID=slice(None)):
        return self.prior + self.nphi[:, _ID]

    @cached_property
    def posterior_sum(self):
        return self.nphi.sum(1) + self.prior * self.nphi.shape[1]

    def _Elog(self, _ID=slice(None)):
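        # E[log x] under Dirichlet(posterior):
        # digamma(alpha_k) - digamma(sum_j alpha_j)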
        return torch.digamma(self._posterior(_ID)) - torch.digamma(
            self.posterior_sum.unsqueeze(1)
        )

    @cached_property
    def loglike(self):
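        # per-row ELBO term E_q[log p(x | prior)] - E_q[log q(x | posterior)]
        # = -sum_k nphi_k * E[log x_k] - log B(prior) + log B(posterior),
        # where B is the Dirichlet normalizing constant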
        neg_evid = -self._sum_by_parts(
            lambda s: (self.nphi[:, s] * self._Elog(s))
        )

        prior = torch.as_tensor(self.prior).to(self.nphi)
        K = self.nphi.shape[1]
        log_B_prior = torch.lgamma(prior) * K - torch.lgamma(prior * K)

        log_B_posterior = self._sum_by_parts(
            lambda s: torch.lgamma(self._posterior(s))
        ) - torch.lgamma(self.posterior_sum)

        return neg_evid - log_B_prior + log_B_posterior

    @cached_property
    def n(self):
        return self.nphi.sum(1)

    @cached_property
    def cdf(self):
        cdf = self._posterior()
        torch.cumsum(cdf, 1, out=cdf)
        cdf /= cdf[:, -1:].clone()
        return cdf

    def _expectation(self, _ID=slice(None)):
        expectation = self._posterior(_ID)
        expectation /= self.posterior_sum.unsqueeze(1)
        return expectation

    @cached_property
    def Bayesian_gap(self):
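        # 1 - sum_k exp(E[log x_k]); approaches 0 as posterior counts grow,
        # reported as a diagnostic by the verbose m-step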
        return 1.0 - self._sum_by_parts(lambda s: self._Elog(s).exp())

    _cached_properties = [
        "posterior_sum",
        "loglike",
        "n",
        "cdf",
        "Bayesian_gap",
    ]

    def clear_cache(self):
        for name in self._cached_properties:
            try:
                delattr(self, name)
            except AttributeError:
                pass

    def update(self, new, _ID=slice(None), rho=1):
        """inplace: old * (1-rho) + new * rho"""
        self.clear_cache()
        mean_change = (self.nphi[:, _ID] - new).abs().mean().tolist()

        self.nphi *= 1 - rho
        self.nphi[:, _ID] += new * rho
        return mean_change


class DocData(_Dirichlet):
    """nphi (n_docs by n_topics)"""

    def prepare_graph(self, G, key="Elog"):
        G.nodes["doc"].data[key] = getattr(self, "_" + key)().to(G.device)

    def update_from(self, G, mult):
        new = G.nodes["doc"].data["nphi"] * mult
        return self.update(new.to(self.device))


class _Distributed(collections.UserList):
    """split on dim=0 and store on multiple devices"""

    def __init__(self, prior, nphi):
        self.prior = prior
        self.nphi = nphi
        super().__init__([_Dirichlet(self.prior, nphi) for nphi in self.nphi])

    def split_device(self, other, dim=0):
        split_sections = [x.shape[0] for x in self.nphi]
        out = torch.split(other, split_sections, dim)
        return [y.to(x.device) for x, y in zip(self.nphi, out)]


class WordData(_Distributed):
    """distributed nphi (n_topics by n_words), transpose to/from graph nodes data"""

    def prepare_graph(self, G, key="Elog"):
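        # a sampled subgraph stores "_ID" on its word nodes to map them back
        # to the full vocabulary; restrict the prepared columns accordingly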
        if "_ID" in G.nodes["word"].data:
            _ID = G.nodes["word"].data["_ID"]
        else:
            _ID = slice(None)

        out = [getattr(part, "_" + key)(_ID).to(G.device) for part in self]
        G.nodes["word"].data[key] = torch.cat(out).T

    def update_from(self, G, mult, rho):
        nphi = G.nodes["word"].data["nphi"].T * mult

        if "_ID" in G.nodes["word"].data:
            _ID = G.nodes["word"].data["_ID"]
        else:
            _ID = slice(None)

        mean_change = [
            x.update(y, _ID, rho) for x, y in zip(self, self.split_device(nphi))
        ]
        return np.mean(mean_change)


class Gamma(collections.namedtuple("Gamma", "concentration, rate")):
    """articulate the difference between torch gamma and numpy gamma"""

    @property
    def shape(self):
        return self.concentration

    @property
    def scale(self):
        return 1 / self.rate

    def sample(self, shape, device):
        return torch.distributions.gamma.Gamma(
            torch.as_tensor(self.concentration, device=device),
            torch.as_tensor(self.rate, device=device),
        ).sample(shape)


class LatentDirichletAllocation:
    """LDA model that works with a HeteroGraph with doc->word meta paths.
    The model may freely overwrite the node and edge data of G.
    This is inspired by [1] and its corresponding scikit-learn implementation.

    Inputs
    ---
    * n_words: vocabulary size, i.e., the number of word nodes.
    * n_components: latent feature dimension (number of topics); also used to
      set the default priors.
    * prior: concentration parameters of the Dirichlet priors; both default
      to 1/n_components.
    * rho: online update weight, nphi <- (1-rho)*old_nphi + rho*new_nphi;
      default to 1 for full-batch updates.
    * mult: multipliers for the nphi updates; large values effectively
      disable the priors.
    * init: Gamma(100.0, 100.0) initializers as in sklearn; samples
      concentrate around 1.0.
    * device_list: devices across which word_data is split to speed up its
      updates.

    Notes
    ---
    Some differences between this and sklearn.decomposition.LatentDirichletAllocation:
    * default word perplexity is normalized by training set instead of testing set.
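
    Example
    ---
    A minimal usage sketch, mirroring the self-test in __main__ below:

        G = dgl.heterograph(
            {("doc", "", "word"): [(0, 0), (1, 3)]}, {"doc": 2, "word": 5}
        )
        model = LatentDirichletAllocation(n_words=5, n_components=10)
        model.fit(G)
        doc_data = model.transform(G)
        model.perplexity(G, doc_data)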

    References
    ---
    [1] Matthew Hoffman, Francis Bach, David Blei. Online Learning for Latent
    Dirichlet Allocation. Advances in Neural Information Processing Systems 23
    (NIPS 2010).
    [2] Reactive LDA Library blogpost by Yingjie Miao for a similar Gibbs model
    """

    def __init__(
        self,
        n_words,
        n_components,
        prior=None,
        rho=1,
        mult={"doc": 1, "word": 1},
        init={"doc": (100.0, 100.0), "word": (100.0, 100.0)},
        device_list=["cpu"],
        verbose=True,
    ):
        self.n_words = n_words
        self.n_components = n_components

        if prior is None:
            prior = {"doc": 1.0 / n_components, "word": 1.0 / n_components}
        self.prior = prior

        self.rho = rho
        self.mult = mult
        self.init = init

        assert not isinstance(device_list, str), "please wrap devices in a list"
        self.device_list = device_list[:n_components]  # avoid edge cases
        self.verbose = verbose

        self._init_word_data()

    def _init_word_data(self):
        split_sections = np.diff(
            np.linspace(0, self.n_components, len(self.device_list) + 1).astype(
                int
            )
        )
        word_nphi = [
            Gamma(*self.init["word"]).sample((s, self.n_words), device)
            for s, device in zip(split_sections, self.device_list)
        ]
        self.word_data = WordData(self.prior["word"], word_nphi)

    def _init_doc_data(self, n_docs, device):
        doc_nphi = Gamma(*self.init["doc"]).sample(
            (n_docs, self.n_components), device
        )
        return DocData(self.prior["doc"], doc_nphi)

    def save(self, f):
        for w in self.word_data:
            w.clear_cache()
        torch.save(
            {
                "prior": self.prior,
                "rho": self.rho,
                "mult": self.mult,
                "init": self.init,
                "word_data": [part.nphi for part in self.word_data],
            },
            f,
        )

    def _prepare_graph(self, G, doc_data, key="Elog"):
        doc_data.prepare_graph(G, key)
        self.word_data.prepare_graph(G, key)

    def _e_step(self, G, doc_data=None, mean_change_tol=1e-3, max_iters=100):
        """_e_step implements doc data sampling until convergence or max_iters"""
        if doc_data is None:
            doc_data = self._init_doc_data(G.num_nodes("doc"), G.device)

        G_rev = G.reverse()  # word -> doc
        self.word_data.prepare_graph(G_rev)

        for i in range(max_iters):
            doc_data.prepare_graph(G_rev)
            G_rev.update_all(
                lambda edges: {"phi": EdgeData(edges.src, edges.dst).phi},
                dgl.function.sum("phi", "nphi"),
            )
            mean_change = doc_data.update_from(G_rev, self.mult["doc"])
            if mean_change < mean_change_tol:
                break

        if self.verbose:
            print(
                f"e-step num_iters={i+1} with mean_change={mean_change:.4f}, "
                f"perplexity={self.perplexity(G, doc_data):.4f}"
            )

        return doc_data

    transform = _e_step

    def predict(self, doc_data):
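        # doc-word scores E[theta] @ E[beta], computed shard by shard; writing
        # it as x @ nphi + x.sum(1) * prior with x = E[theta] / posterior_sum
        # avoids materializing E[beta] for the whole vocabulary at once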
        pred_scores = [
            # d_exp @ w._expectation()
            (lambda x: x @ w.nphi + x.sum(1, keepdims=True) * w.prior)(
                d_exp / w.posterior_sum.unsqueeze(0)
            )
            for (d_exp, w) in zip(
                self.word_data.split_device(doc_data._expectation(), dim=1),
                self.word_data,
            )
        ]
        x = torch.zeros_like(pred_scores[0], device=doc_data.device)
        for p in pred_scores:
            x += p.to(x.device)
        return x

    def sample(self, doc_data, num_samples):
        """draw independent words and return the marginal probabilities,
        i.e., the expectations in Dirichlet distributions.
        """

        def fn(cdf):
            u = torch.rand(cdf.shape[0], num_samples, device=cdf.device)
            return torch.searchsorted(cdf, u).to(doc_data.device)

        topic_ids = fn(doc_data.cdf)
        word_ids = torch.cat([fn(part.cdf) for part in self.word_data])
        ids = torch.gather(
            word_ids, 0, topic_ids
        )  # pick components by topic_ids

        # compute expectation scores on sampled ids
        src_ids = (
            torch.arange(ids.shape[0], dtype=ids.dtype, device=ids.device)
            .reshape((-1, 1))
            .expand(ids.shape)
        )
        unique_ids, inverse_ids = torch.unique(
            ids, sorted=False, return_inverse=True
        )

        G = dgl.heterograph(
            {("doc", "", "word"): (src_ids.ravel(), inverse_ids.ravel())}
        )
        G.nodes["word"].data["_ID"] = unique_ids
        self._prepare_graph(G, doc_data, "expectation")
        G.apply_edges(
            lambda e: {"expectation": EdgeData(e.src, e.dst).expectation}
        )
        expectation = G.edata.pop("expectation").reshape(ids.shape)

        return ids, expectation

    def _m_step(self, G, doc_data):
        """_m_step implements word data sampling and stores word_z stats.
        mean_change is in the sense of full graph with rho=1.
        """
        G = G.clone()
        self._prepare_graph(G, doc_data)
        G.update_all(
            lambda edges: {"phi": EdgeData(edges.src, edges.dst).phi},
            dgl.function.sum("phi", "nphi"),
        )
        self._last_mean_change = self.word_data.update_from(
            G, self.mult["word"], self.rho
        )

        if self.verbose:
            print(f"m-step mean_change={self._last_mean_change:.4f}, ", end="")
            Bayesian_gap = np.mean(
                [part.Bayesian_gap.mean().tolist() for part in self.word_data]
            )
            print(f"Bayesian_gap={Bayesian_gap:.4f}")

    def partial_fit(self, G):
        doc_data = self._e_step(G)
        self._m_step(G, doc_data)
        return self

    def fit(self, G, mean_change_tol=1e-3, max_epochs=10):
        for i in range(max_epochs):
            if self.verbose:
                print(f"epoch {i+1}, ", end="")
            self.partial_fit(G)

            if self._last_mean_change < mean_change_tol:
                break
        return self

    def perplexity(self, G, doc_data=None):
        """ppl = exp{-sum[log(p(w1,...,wn|d))] / n}
        Follows Eq (15) in Hoffman et al., 2010.
        """
        if doc_data is None:
            doc_data = self._e_step(G)

        # compute E[log p(docs | theta, beta)]
        G = G.clone()
        self._prepare_graph(G, doc_data)
        G.apply_edges(
            lambda edges: {"loglike": EdgeData(edges.src, edges.dst).loglike}
        )
        edge_elbo = (G.edata["loglike"].sum() / G.num_edges()).tolist()
        if self.verbose:
            print(f"neg_elbo phi: {-edge_elbo:.3f}", end=" ")

        # compute E[log p(theta | alpha) - log q(theta | gamma)]
        doc_elbo = (doc_data.loglike.sum() / doc_data.n.sum()).tolist()
        if self.verbose:
            print(f"theta: {-doc_elbo:.3f}", end=" ")

        # compute E[log p(beta | eta) - log q(beta | lambda)]
        # The denominator n for extrapolation perplexity is undefined.
        # We use the train set, whereas sklearn uses the test set.
        word_elbo = sum(
            [part.loglike.sum().tolist() for part in self.word_data]
        ) / sum([part.n.sum().tolist() for part in self.word_data])
        if self.verbose:
            print(f"beta: {-word_elbo:.3f}")

        ppl = np.exp(-edge_elbo - doc_elbo - word_elbo)
        if G.num_edges() > 0 and np.isnan(ppl):
            warnings.warn("numerical issue in perplexity")
        return ppl


def doc_subgraph(G, doc_ids):
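    """Extract the doc->word subgraph induced by doc_ids (all 1-hop words),
    keeping "_ID" on the word nodes so updates map back to the full
    vocabulary; used for minibatch partial_fit in the self-test below.
    """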
    sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1)
    _, _, (block,) = sampler.sample(
        G.reverse(), {"doc": torch.as_tensor(doc_ids)}
    )
    B = dgl.DGLHeteroGraph(
        block._graph, ["_", "word", "doc", "_"], block.etypes
    ).reverse()
    B.nodes["word"].data["_ID"] = block.nodes["word"].data["_ID"]
    return B


if __name__ == "__main__":
    print("Testing LatentDirichletAllocation ...")
    G = dgl.heterograph(
        {("doc", "", "word"): [(0, 0), (1, 3)]}, {"doc": 2, "word": 5}
    )
    model = LatentDirichletAllocation(n_words=5, n_components=10, verbose=False)
    model.fit(G)
    model.transform(G)
    model.predict(model.transform(G))
    if hasattr(torch, "searchsorted"):
        model.sample(model.transform(G), 3)
    model.perplexity(G)

    for doc_id in range(2):
        B = doc_subgraph(G, [doc_id])
        model.partial_fit(B)

    with io.BytesIO() as f:
        model.save(f)
        f.seek(0)
        print(torch.load(f))

    print("Testing LatentDirichletAllocation passed!")