Unverified Commit fe6e01ad authored by yifeim, committed by GitHub

[Model] Lda subgraph (#3206)



* add word_ids and simplify

* simplify

* add word_ids to be removed later

* remove word_ids

* seems to work

* tweak

* transpose word_z

* add word_ids example

* check api compatibility

* improve compatibility

* update doc

* tweak verbose

* restore word_z layout; tweak

* tweak

* tweak doc

* word_cT

* use log_weight and some other tweaks

* rewrite README

* update equations

* rewrite for clarity and pass tests

* tweak

* bugfix import

* fix unit test

* fix mult to be the same as old versions

* tweak

* could be a bugfix

* 0/0=nan

* add doc_subgraph utility function

* minor cache optimization

* minor cache tweak

* add environmental variable to trade cache speed for memory

* update README

* tweak

* add sparse update pass unit test

* simplify sparse update

* improve low-memory efficiency

* tweak

* add sample expectation scores to allow resampling

* simplify

* update comment

* avoid edge cases

* bugfix pred scores

* simplify

* add save function
Co-authored-by: Yifei Ma <yifeim@amazon.com>
Co-authored-by: Quan (Andy) Gan <coin2028@hotmail.com>
Co-authored-by: Jinjing Zhou <VoVAllen@users.noreply.github.com>
parent 9c41e97c

@@ -8,64 +8,137 @@ There is no back-propagation, because gradient descent is typically considered
inefficient on the probability simplex.
On the provided small-scale example on the 20 news groups dataset, our DGL-LDA model runs
50% faster on GPU than the sklearn model without joblib parallel.
For larger graphs, thanks to subgraph sampling and a low-memory implementation, we may fit 100 million unique words with 256 topic dimensions on a large multi-GPU machine.
(The runtime memory is often less than 2x the parameter storage.)

Key equations
---
<!-- https://editor.codecogs.com/ -->

Let k be the topic index variable with one-hot encoded vector representation z. The rest of the variables are:

| | z_d\~p(θ_d) | w_k\~p(β_k) | z_dw\~q(ϕ_dw) |
|-------------|-------------|-------------|---------------|
| Prior | Dir(α) | Dir(η) | (n/a) |
| Posterior | Dir(γ_d) | Dir(λ_k) | (n/a) |

We overload w with bold-symbol w, which represents the entire observed document-word multi-graph. The difference is better shown in the original paper.

**Multinomial PCA**

Multinomial PCA is a "latent allocation" model without the "Dirichlet".
Its data likelihood sums over the latent topic-index variable k,
<img src="https://latex.codecogs.com/svg.image?\inline&space;p(w_{di}|\theta_d,\beta)=\sum_k\theta_{dk}\beta_{kw}"/>,
where θ_d and β_k are shared within the same document and topic, respectively.

If we perform gradient descent, we may need additional steps to project the parameters onto the probability simplices:
<img src="https://latex.codecogs.com/svg.image?\inline&space;\sum_k\theta_{dk}=1"/>
and
<img src="https://latex.codecogs.com/svg.image?\inline&space;\sum_w\beta_{kw}=1"/>.

Instead, a more efficient solution is to borrow ideas from evidence lower-bound (ELBO) decomposition:
<!--
\log p(w) \geq \mathcal{L}(w,\phi)
\stackrel{def}{=}
\mathbb{E}_q [\log p(w,z;\theta,\beta) - \log q(z;\phi)]
\\=
\mathbb{E}_q [\log p(w|z;\beta) + \log p(z;\theta) - \log q(z;\phi)]
\\=
\sum_{dwk}n_{dw}\phi_{dwk} [\log\beta_{kw} + \log \theta_{dk} - \log \phi_{dwk}]
-->
<img src="https://latex.codecogs.com/svg.image?\log&space;p(w)&space;\geq&space;\mathcal{L}(w,\phi)\stackrel{def}{=}\mathbb{E}_q&space;[\log&space;p(w,z;\theta,\beta)&space;-&space;\log&space;q(z;\phi)]\\=\mathbb{E}_q&space;[\log&space;p(w|z;\beta)&space;&plus;&space;\log&space;p(z;\theta)&space;-&space;\log&space;q(z;\phi)]\\=\sum_{dwk}n_{dw}\phi_{dwk}&space;[\log\beta_{kw}&space;&plus;&space;\log&space;\theta_{dk}&space;-&space;\log&space;\phi_{dwk}]"/>
The solutions for
<img src="https://latex.codecogs.com/svg.image?\inline&space;\theta_{dk}\propto\sum_wn_{dw}\phi_{dwk}"/>
and
<img src="https://latex.codecogs.com/svg.image?\inline&space;\beta_{kw}\propto\sum_dn_{dw}\phi_{dwk}"/>
follow from maximizing the cross-entropy terms in the bound.
The solution for
<img src="https://latex.codecogs.com/svg.image?\inline&space;\phi_{dwk}\propto&space;\theta_{dk}\beta_{kw}"/>
follows from minimizing a Kullback-Leibler divergence.
After normalizing to
<img src="https://latex.codecogs.com/svg.image?\inline&space;\sum_k\phi_{dwk}=1"/>,
the difference
<img src="https://latex.codecogs.com/svg.image?\inline&space;\ell_{dw}=\log\beta_{kw}+\log\theta_{dk}-\log\phi_{dwk}"/>
becomes constant in k,
which is connected to the likelihood for the observed document-word pairs.
Note that, after learning, the document vector θ_d captures the correlations between all words in d, and similarly the topic-distribution vector β_k captures the correlations across all observed documents.
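
To make the alternating updates concrete, here is a minimal dense sketch of the multinomial-PCA EM cycle on a toy corpus. The tensors (`n_dw`, `theta`, `beta`, `phi`) mirror the symbols above; the sizes and the plain-PyTorch layout are illustrative assumptions, not the DGL example's actual code path.

```python
import torch

n_docs, n_words, n_topics = 4, 6, 3
n_dw = torch.randint(1, 5, (n_docs, n_words)).float()       # toy document-word counts
theta = torch.rand(n_docs, n_topics); theta /= theta.sum(1, keepdim=True)
beta = torch.rand(n_topics, n_words); beta /= beta.sum(1, keepdim=True)

for _ in range(50):
    # E-step: phi_dwk ∝ theta_dk * beta_kw, normalized over k
    phi = theta.unsqueeze(1) * beta.T.unsqueeze(0)           # shape (d, w, k)
    phi /= phi.sum(-1, keepdim=True)
    # M-step: theta_dk ∝ Σ_w n_dw φ_dwk and beta_kw ∝ Σ_d n_dw φ_dwk
    nphi = n_dw.unsqueeze(-1) * phi                          # shape (d, w, k)
    theta = nphi.sum(1); theta /= theta.sum(1, keepdim=True)
    beta = nphi.sum(0).T; beta /= beta.sum(1, keepdim=True)

print((n_dw * (theta @ beta).log()).sum())                   # monitors the data log-likelihood
```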

**Variational Bayes**

A Bayesian model adds Dirichlet priors to θ_d and β_k, which leads to a similar ELBO if we assume independence
<img src="https://latex.codecogs.com/svg.image?\inline&space;q(z,\theta,\beta;\phi,\gamma,\lambda)=q(z;\phi)q(\theta;\gamma)q(\beta;\lambda)"/>,
i.e.:
<!--
\log p(w;\alpha,\eta) \geq \mathcal{L}(w,\phi,\gamma,\lambda)
\stackrel{def}{=}
\mathbb{E}_q [\log p(w,z,\theta,\beta;\alpha,\eta) - \log q(z,\theta,\beta;\phi,\gamma,\lambda)]
\\=
\mathbb{E}_q \left[
\log p(w|z,\beta) + \log p(z|\theta) - \log q(z;\phi)
+\log p(\theta;\alpha) - \log q(\theta;\gamma)
+\log p(\beta;\eta) - \log q(\beta;\lambda)
\right]
\\=
\sum_{dwk}n_{dw}\phi_{dwk} (\mathbb{E}_{\lambda_k}[\log\beta_{kw}] + \mathbb{E}_{\gamma_d}[\log \theta_{dk}] - \log \phi_{dwk})
\\+\sum_{d}\left[
(\alpha-\gamma_d)^\top\mathbb{E}_{\gamma_d}[\log\theta_d]
-(\log B(\alpha 1_K) - \log B(\gamma_d))
\right]
\\+\sum_{k}\left[
(\eta-\lambda_k)^\top\mathbb{E}_{\lambda_k}[\log\beta_k]
-(\log B(\eta 1_W) - \log B(\lambda_k))
\right]
-->
<img src="https://latex.codecogs.com/svg.image?\log&space;p(w;\alpha,\eta)&space;\geq&space;\mathcal{L}(w,\phi,\gamma,\lambda)\stackrel{def}{=}\mathbb{E}_q&space;[\log&space;p(w,z,\theta,\beta;\alpha,\eta)&space;-&space;\log&space;q(z,\theta,\beta;\phi,\gamma,\lambda)]\\=\mathbb{E}_q&space;\left[\log&space;p(w|z,\beta)&space;&plus;&space;\log&space;p(z|\theta)&space;-&space;\log&space;q(z;\phi)&plus;\log&space;p(\theta;\alpha)&space;-&space;\log&space;q(\theta;\gamma)&plus;\log&space;p(\beta;\eta)&space;-&space;\log&space;q(\beta;\lambda)\right]\\=\sum_{dwk}n_{dw}\phi_{dwk}&space;(\mathbb{E}_{\lambda_k}[\log\beta_{kw}]&space;&plus;&space;\mathbb{E}_{\gamma_d}[\log&space;\theta_{dk}]&space;-&space;\log&space;\phi_{dwk})\\&plus;\sum_{d}\left[(\alpha-\gamma_d)^\top\mathbb{E}_{\gamma_d}[\log\theta_d]-(\log&space;B(\alpha&space;1_K)&space;-&space;\log&space;B(\gamma_d))\right]\\&plus;\sum_{k}\left[(\eta-\lambda_k)^\top\mathbb{E}_{\lambda_k}[\log\beta_k]-(\log&space;B(\eta&space;1_W)&space;-&space;\log&space;B(\lambda_k))\right]"/>
**Solutions**
The solutions to VB subsume the solutions to multinomial PCA as n goes to infinity.
The solution for ϕ is
<img src="https://latex.codecogs.com/svg.image?\inline&space;\log\phi_{dwk}=\mathbb{E}_{\gamma_d}[\log\theta_{dk}]+\mathbb{E}_{\lambda_k}[\log\beta_{kw}]-\ell_{dw}"/>,
where the additional expectation can be expressed via digamma functions
and
<img src="https://latex.codecogs.com/svg.image?\inline&space;\ell_{dw}=\log\sum_k\exp(\mathbb{E}_{\gamma_d}[\log\theta_{dk}]+\mathbb{E}_{\lambda_k}[\log\beta_{kw}])"/>
is the log-partition function.
The solutions for
<img src="https://latex.codecogs.com/svg.image?\inline&space;\gamma_{dk}=\alpha+\sum_wn_{dw}\phi_{dwk}"/>
and
<img src="https://latex.codecogs.com/svg.image?\inline&space;\lambda_{kw}=\eta+\sum_dn_{dw}\phi_{dwk}"/>
come from direct gradient calculation.
After substituting the optimal solutions, we compute the marginal likelihood by adding the three terms, which are all connected to (the negative of) Kullback-Leibler divergence.
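
For comparison, a minimal dense sketch of the variational-Bayes coordinate updates; `Elog` applies the digamma formula for E[log X] under a Dirichlet. Again, the toy sizes and the plain-PyTorch layout are assumptions for illustration only.

```python
import torch

n_docs, n_words, n_topics = 4, 6, 3
alpha, eta = 1. / n_topics, 1. / n_topics                    # Dirichlet priors
n_dw = torch.randint(1, 5, (n_docs, n_words)).float()
gamma = torch.rand(n_docs, n_topics) + 1.                    # q(theta_d) = Dir(gamma_d)
lam = torch.rand(n_topics, n_words) + 1.                     # q(beta_k)  = Dir(lambda_k)

def Elog(x):
    # E[log X] under Dir(x), applied row-wise
    return torch.digamma(x) - torch.digamma(x.sum(-1, keepdim=True))

for _ in range(50):
    # phi_dwk ∝ exp(E[log theta_dk] + E[log beta_kw]), normalized over k
    phi = torch.softmax(Elog(gamma).unsqueeze(1) + Elog(lam).T.unsqueeze(0), dim=-1)
    nphi = n_dw.unsqueeze(-1) * phi                          # shape (d, w, k)
    gamma = alpha + nphi.sum(1)                              # γ_dk = α + Σ_w n_dw φ_dwk
    lam = eta + nphi.sum(0).T                                # λ_kw = η + Σ_d n_dw φ_dwk
```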

DGL usage
---

The corpus is represented as a bipartite multi-graph G.
We use DGL to propagate information through the edges and aggregate the distributions at doc/word nodes.
For scalability, the phi variables are transient and updated during message passing.
The gamma / lambda variables are updated after the nodes receive all edge messages.
Following the conventions in [1], the gamma update is called the E-step and the lambda update is called the M-step.
The lambda variable is further recorded by the trainer.
A separate function is used to produce perplexity, which is based on the ELBO objective function divided by the total number of word/doc occurrences.
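
As a rough illustration of this message-passing pattern, the toy graph and random `Elog` tensors below are made up; the real model additionally handles sparse word updates, multiple devices, and the reverse (word-to-doc) direction.

```python
import torch, dgl

n_docs, n_words, n_topics = 3, 5, 4
G = dgl.heterograph({('doc', '', 'word'): ([0, 1, 2], [0, 3, 4])})
G.nodes['doc'].data['Elog'] = torch.randn(n_docs, n_topics)
G.nodes['word'].data['Elog'] = torch.randn(n_words, n_topics)

def phi_msg(edges):
    # transient phi on each edge: softmax over topics of Elog_theta + Elog_beta
    logit = edges.src['Elog'] + edges.dst['Elog']
    return {'phi': (logit - logit.logsumexp(1, keepdim=True)).exp()}

G.update_all(phi_msg, dgl.function.sum('phi', 'nphi'))       # aggregate phi at word nodes
print(G.nodes['word'].data['nphi'].shape)                    # (n_words, n_topics)
```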

Example
---

`%run example_20newsgroups.py`

* Approximately matches scikit-learn training perplexity after 10 rounds of training.
* Exactly matches scikit-learn training perplexity if word_z is set to lda.components_.T
* There is a difference in how we compute testing perplexity. We weigh the beta contributions by the training word counts, whereas sklearn weighs them by test word counts.
* The DGL-LDA model runs 50% faster on GPU devices compared with sklearn without joblib parallel.

Advanced configurations
---

* Set `0<rho<1` for online learning with `partial_fit` (see the sketch below).
* Set `mult["doc"]=100` or `mult["word"]=100` or some other large value to effectively disable the corresponding Bayesian prior.
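
A hypothetical configuration putting these knobs together (the module path, vocabulary size, and device below are assumptions; the argument names follow the `LatentDirichletAllocation` constructor):

```python
from lda_model import LatentDirichletAllocation, doc_subgraph   # module path is an assumption

model = LatentDirichletAllocation(
    n_words=5000, n_components=64,
    rho=0.7,                        # 0<rho<1 interpolates old and new word statistics online
    mult={'doc': 1, 'word': 100},   # a large mult effectively disables that Bayesian prior
    device_list=['cuda:0'],         # word_data may be split across several devices
)
# online updates then consume one document (sub)graph at a time:
# model.partial_fit(doc_subgraph(G, batch_of_doc_ids))
```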

References
---
...

@@ -103,7 +103,7 @@ print()
print("Training dgl-lda model...")
t0 = time()
model = LDAModel(G.num_nodes('word'), n_components)
model.fit(G)
print("done in %0.3fs." % (time() - t0))
print()
@@ -111,8 +111,9 @@ print()
print(f"dgl-lda training perplexity {model.perplexity(G):.3f}")
print(f"dgl-lda testing perplexity {model.perplexity(Gt):.3f}")
word_nphi = np.vstack([nphi.tolist() for nphi in model.word_data.nphi])
plot_top_words(
    type('dummy', (object,), {'components_': word_nphi}),
    tf_feature_names, n_top_words, 'Topics in LDA model')

print("Training scikit-learn model...")
...

@@ -17,77 +17,200 @@
# limitations under the License.

import os, functools, warnings, torch, collections, dgl, io
import numpy as np, scipy as sp
try:
    from functools import cached_property
except ImportError:
    try:
        from backports.cached_property import cached_property
    except ImportError:
        warnings.warn("cached_property not found - using property instead")
        cached_property = property


class EdgeData:
    """Per-edge quantities computed from the src/dst node data ('Elog', 'expectation')."""
    def __init__(self, src_data, dst_data):
        self.src_data = src_data
        self.dst_data = dst_data

    @property
    def loglike(self):
        return (self.src_data['Elog'] + self.dst_data['Elog']).logsumexp(1)

    @property
    def phi(self):
        return (
            self.src_data['Elog'] + self.dst_data['Elog'] - self.loglike.unsqueeze(1)
        ).exp()

    @property
    def expectation(self):
        return (self.src_data['expectation'] * self.dst_data['expectation']).sum(1)


class _Dirichlet:
    """Dirichlet posterior (prior + nphi) with chunked reductions to limit peak memory."""
    def __init__(self, prior, nphi, _chunksize=int(1e6)):
        self.prior = prior
        self.nphi = nphi
        self.device = nphi.device
        self._sum_by_parts = lambda map_fn: functools.reduce(torch.add, [
            map_fn(slice(i, min(i+_chunksize, nphi.shape[1]))).sum(1)
            for i in list(range(0, nphi.shape[1], _chunksize))
        ])

    def _posterior(self, _ID=slice(None)):
        return self.prior + self.nphi[:, _ID]

    @cached_property
    def posterior_sum(self):
        return self.nphi.sum(1) + self.prior * self.nphi.shape[1]

    def _Elog(self, _ID=slice(None)):
        return torch.digamma(self._posterior(_ID)) - \
            torch.digamma(self.posterior_sum.unsqueeze(1))

    @cached_property
    def loglike(self):
        neg_evid = -self._sum_by_parts(
            lambda s: (self.nphi[:, s] * self._Elog(s))
        )

        prior = torch.as_tensor(self.prior).to(self.nphi)
        K = self.nphi.shape[1]
        log_B_prior = torch.lgamma(prior) * K - torch.lgamma(prior * K)

        log_B_posterior = self._sum_by_parts(
            lambda s: torch.lgamma(self._posterior(s))
        ) - torch.lgamma(self.posterior_sum)

        return neg_evid - log_B_prior + log_B_posterior

    @cached_property
    def n(self):
        return self.nphi.sum(1)

    @cached_property
    def cdf(self):
        cdf = self._posterior()
        torch.cumsum(cdf, 1, out=cdf)
        cdf /= cdf[:, -1:].clone()
        return cdf

    def _expectation(self, _ID=slice(None)):
        expectation = self._posterior(_ID)
        expectation /= self.posterior_sum.unsqueeze(1)
        return expectation

    @cached_property
    def Bayesian_gap(self):
        return 1. - self._sum_by_parts(lambda s: self._Elog(s).exp())

    _cached_properties = ["posterior_sum", "loglike", "n", "cdf", "Bayesian_gap"]

    def clear_cache(self):
        for name in self._cached_properties:
            try:
                delattr(self, name)
            except AttributeError:
                pass

    def update(self, new, _ID=slice(None), rho=1):
        """ inplace: old * (1-rho) + new * rho """
        self.clear_cache()
        mean_change = (self.nphi[:, _ID] - new).abs().mean().tolist()
        self.nphi *= (1 - rho)
        self.nphi[:, _ID] += new * rho
        return mean_change


class DocData(_Dirichlet):
    """ nphi (n_docs by n_topics) """
    def prepare_graph(self, G, key="Elog"):
        G.nodes['doc'].data[key] = getattr(self, '_'+key)().to(G.device)

    def update_from(self, G, mult):
        new = G.nodes['doc'].data['nphi'] * mult
        return self.update(new.to(self.device))


class _Distributed(collections.UserList):
    """ split on dim=0 and store on multiple devices """
    def __init__(self, prior, nphi):
        self.prior = prior
        self.nphi = nphi
        super().__init__([_Dirichlet(self.prior, nphi) for nphi in self.nphi])

    def split_device(self, other, dim=0):
        split_sections = [x.shape[0] for x in self.nphi]
        out = torch.split(other, split_sections, dim)
        return [y.to(x.device) for x, y in zip(self.nphi, out)]


class WordData(_Distributed):
    """ distributed nphi (n_topics by n_words), transpose to/from graph nodes data """
    def prepare_graph(self, G, key="Elog"):
        if '_ID' in G.nodes['word'].data:
            _ID = G.nodes['word'].data['_ID']
        else:
            _ID = slice(None)

        out = [getattr(part, '_'+key)(_ID).to(G.device) for part in self]
        G.nodes['word'].data[key] = torch.cat(out).T

    def update_from(self, G, mult, rho):
        nphi = G.nodes['word'].data['nphi'].T * mult

        if '_ID' in G.nodes['word'].data:
            _ID = G.nodes['word'].data['_ID']
        else:
            _ID = slice(None)

        mean_change = [x.update(y, _ID, rho)
                       for x, y in zip(self, self.split_device(nphi))]
        return np.mean(mean_change)


class Gamma(collections.namedtuple('Gamma', "concentration, rate")):
    """ articulate the difference between torch gamma and numpy gamma """
    @property
    def shape(self):
        return self.concentration

    @property
    def scale(self):
        return 1 / self.rate

    def sample(self, shape, device):
        return torch.distributions.gamma.Gamma(
            torch.as_tensor(self.concentration, device=device),
            torch.as_tensor(self.rate, device=device),
        ).sample(shape)


class LatentDirichletAllocation:
    """LDA model that works with a HeteroGraph with doc->word meta paths.
    The model alters the attributes of G arbitrarily.
    This is inspired by [1] and its corresponding scikit-learn implementation.

    Inputs
    ---
    * G: a template graph or an integer showing n_words
    * n_components: latent feature dimension; automatically set priors if missing.
    * prior: parameters in the Dirichlet prior; default to 1/n_components and 1/n_words
    * rho: new_nphi = (1-rho)*old_nphi + rho*nphi; default to 1 for full gradients.
    * mult: multiplier for nphi-update; a large value effectively disables prior.
    * init: sklearn initializers (100.0, 100.0); the sample points concentrate around 1.0
    * device_list: accelerate word_data updates.

    Notes
    ---
    Some differences between this and sklearn.decomposition.LatentDirichletAllocation:
    * default word perplexity is normalized by training set instead of testing set.

    References
    ---
    [1] Matthew Hoffman, Francis Bach, David Blei. Online Learning for Latent
@@ -96,135 +96,239 @@ class LatentDirichletAllocation:
    [2] Reactive LDA Library blogpost by Yingjie Miao for a similar Gibbs model
    """
    def __init__(
        self, n_words, n_components,
        prior=None,
        rho=1,
        mult={'doc': 1, 'word': 1},
        init={'doc': (100., 100.), 'word': (100., 100.)},
        device_list=['cpu'],
        verbose=True,
        ):
        self.n_words = n_words
        self.n_components = n_components

        if prior is None:
            prior = {'doc': 1./n_components, 'word': 1./n_components}
        self.prior = prior

        self.rho = rho
        self.mult = mult
        self.init = init

        assert not isinstance(device_list, str), "plz wrap devices in a list"
        self.device_list = device_list[:n_components]  # avoid edge cases
        self.verbose = verbose

        self._init_word_data()

    def _init_word_data(self):
        split_sections = np.diff(
            np.linspace(0, self.n_components, len(self.device_list)+1).astype(int)
        )
        word_nphi = [
            Gamma(*self.init['word']).sample((s, self.n_words), device)
            for s, device in zip(split_sections, self.device_list)
        ]
        self.word_data = WordData(self.prior['word'], word_nphi)

    def _init_doc_data(self, n_docs, device):
        doc_nphi = Gamma(*self.init['doc']).sample(
            (n_docs, self.n_components), device)
        return DocData(self.prior['doc'], doc_nphi)

    def save(self, f):
        for w in self.word_data:
            w.clear_cache()
        torch.save({
            'prior': self.prior,
            'rho': self.rho,
            'mult': self.mult,
            'init': self.init,
            'word_data': [part.nphi for part in self.word_data],
        }, f)

    def _prepare_graph(self, G, doc_data, key="Elog"):
        doc_data.prepare_graph(G, key)
        self.word_data.prepare_graph(G, key)

    def _e_step(self, G, doc_data=None, mean_change_tol=1e-3, max_iters=100):
        """_e_step implements doc data sampling until convergence or max_iters
        """
        if doc_data is None:
            doc_data = self._init_doc_data(G.num_nodes('doc'), G.device)

        G_rev = G.reverse()  # word -> doc
        self.word_data.prepare_graph(G_rev)

        for i in range(max_iters):
            doc_data.prepare_graph(G_rev)
            G_rev.update_all(
                lambda edges: {'phi': EdgeData(edges.src, edges.dst).phi},
                dgl.function.sum('phi', 'nphi')
            )
            mean_change = doc_data.update_from(G_rev, self.mult['doc'])
            if mean_change < mean_change_tol:
                break

        if self.verbose:
            print(f"e-step num_iters={i+1} with mean_change={mean_change:.4f}, "
                  f"perplexity={self.perplexity(G, doc_data):.4f}")

        return doc_data

    transform = _e_step

    def predict(self, doc_data):
        pred_scores = [
            # d_exp @ w._expectation()
            (lambda x: x @ w.nphi + x.sum(1, keepdims=True) * w.prior)
            (d_exp / w.posterior_sum.unsqueeze(0))
            for (d_exp, w) in zip(
                self.word_data.split_device(doc_data._expectation(), dim=1),
                self.word_data)
        ]
        x = torch.zeros_like(pred_scores[0], device=doc_data.device)
        for p in pred_scores:
            x += p.to(x.device)
        return x

    def sample(self, doc_data, num_samples):
        """ draw independent words and return the marginal probabilities,
        i.e., the expectations in Dirichlet distributions.
        """
        def fn(cdf):
            u = torch.rand(cdf.shape[0], num_samples, device=cdf.device)
            return torch.searchsorted(cdf, u).to(doc_data.device)
        topic_ids = fn(doc_data.cdf)
        word_ids = torch.cat([fn(part.cdf) for part in self.word_data])
        ids = torch.gather(word_ids, 0, topic_ids)  # pick components by topic_ids

        # compute expectation scores on sampled ids
        src_ids = torch.arange(
            ids.shape[0], dtype=ids.dtype, device=ids.device
        ).reshape((-1, 1)).expand(ids.shape)
        unique_ids, inverse_ids = torch.unique(ids, sorted=False, return_inverse=True)

        G = dgl.heterograph({('doc', '', 'word'): (src_ids.ravel(), inverse_ids.ravel())})
        G.nodes['word'].data['_ID'] = unique_ids
        self._prepare_graph(G, doc_data, "expectation")
        G.apply_edges(lambda e: {'expectation': EdgeData(e.src, e.dst).expectation})
        expectation = G.edata.pop('expectation').reshape(ids.shape)

        return ids, expectation

    def _m_step(self, G, doc_data):
        """_m_step implements word data sampling and stores word_z stats.
        mean_change is in the sense of full graph with rho=1.
        """
        G = G.clone()
        self._prepare_graph(G, doc_data)
        G.update_all(
            lambda edges: {'phi': EdgeData(edges.src, edges.dst).phi},
            dgl.function.sum('phi', 'nphi')
        )
        self._last_mean_change = self.word_data.update_from(
            G, self.mult['word'], self.rho)
        if self.verbose:
            print(f"m-step mean_change={self._last_mean_change:.4f}, ", end="")
            Bayesian_gap = np.mean([
                part.Bayesian_gap.mean().tolist() for part in self.word_data
            ])
            print(f"Bayesian_gap={Bayesian_gap:.4f}")

    def partial_fit(self, G):
        doc_data = self._e_step(G)
        self._m_step(G, doc_data)
        return self

    def fit(self, G, mean_change_tol=1e-3, max_epochs=10):
        for i in range(max_epochs):
            if self.verbose:
                print(f"epoch {i+1}, ", end="")
            self.partial_fit(G)

            if self._last_mean_change < mean_change_tol:
                break
        return self

    def perplexity(self, G, doc_data=None):
        """ppl = exp{-sum[log(p(w1,...,wn|d))] / n}
        Follows Eq (15) in Hoffman et al., 2010.
        """
        if doc_data is None:
            doc_data = self._e_step(G)

        # compute E[log p(docs | theta, beta)]
        G = G.clone()
        self._prepare_graph(G, doc_data)
        G.apply_edges(lambda edges: {'loglike': EdgeData(edges.src, edges.dst).loglike})
        edge_elbo = (G.edata['loglike'].sum() / G.num_edges()).tolist()
        if self.verbose:
            print(f'neg_elbo phi: {-edge_elbo:.3f}', end=' ')

        # compute E[log p(theta | alpha) - log q(theta | gamma)]
        doc_elbo = (doc_data.loglike.sum() / doc_data.n.sum()).tolist()
        if self.verbose:
            print(f'theta: {-doc_elbo:.3f}', end=' ')

        # compute E[log p(beta | eta) - log q(beta | lambda)]
        # The denominator n for extrapolation perplexity is undefined.
        # We use the train set, whereas sklearn uses the test set.
        word_elbo = (
            sum([part.loglike.sum().tolist() for part in self.word_data])
            / sum([part.n.sum().tolist() for part in self.word_data])
        )
        if self.verbose:
            print(f'beta: {-word_elbo:.3f}')

        ppl = np.exp(-edge_elbo - doc_elbo - word_elbo)
        if G.num_edges() > 0 and np.isnan(ppl):
            warnings.warn("numerical issue in perplexity")
        return ppl


def doc_subgraph(G, doc_ids):
    sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1)
    block, *_ = sampler.sample_blocks(G.reverse(), {'doc': torch.as_tensor(doc_ids)})
    B = dgl.DGLHeteroGraph(
        block._graph, ['_', 'word', 'doc', '_'], block.etypes
    ).reverse()
    B.nodes['word'].data['_ID'] = block.nodes['word'].data['_ID']
    return B


if __name__ == '__main__':
    print('Testing LatentDirichletAllocation ...')
    G = dgl.heterograph({('doc', '', 'word'): [(0, 0), (1, 3)]}, {'doc': 2, 'word': 5})
    model = LatentDirichletAllocation(n_words=5, n_components=10, verbose=False)
    model.fit(G)
    model.transform(G)
    model.predict(model.transform(G))
    if hasattr(torch, "searchsorted"):
        model.sample(model.transform(G), 3)
    model.perplexity(G)

    for doc_id in range(2):
        B = doc_subgraph(G, [doc_id])
        model.partial_fit(B)

    with io.BytesIO() as f:
        model.save(f)
        f.seek(0)
        print(torch.load(f))

    print('Testing LatentDirichletAllocation passed!')
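

# For reference, a hedged driver sketch for subgraph-based training on a large corpus.
# The full graph `G`, the batch size, and the checkpoint path are assumptions, not part of the example.
def train_on_subgraphs(model, G, batch_size=1024, epochs=5):
    doc_ids = np.arange(G.num_nodes('doc'))
    for _ in range(epochs):
        np.random.shuffle(doc_ids)
        for batch in np.array_split(doc_ids, max(1, len(doc_ids) // batch_size)):
            # doc_subgraph keeps '_ID' on the sampled word nodes, so the m-step adds new
            # counts only for those words; use 0<rho<1 so untouched words decay gently
            # instead of being zeroed out.
            model.partial_fit(doc_subgraph(G, batch.tolist()))
    return model

# model = LatentDirichletAllocation(G.num_nodes('word'), 256, rho=0.5, device_list=['cuda:0'])
# train_on_subgraphs(model, G)
# with open('lda_checkpoint.pt', 'wb') as f:
#     model.save(f)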