Unverified Commit c0184365 authored by yifeim, committed by GitHub

[Example] add latent dirichlet allocation (#2883)



* add lda model

* tweak latent dirichlet allocation

* Update README.md

* Update README.md

* update example index

* update header

* minor tweak

* add example test

* update doc

* Update README.md

* Update README.md

* add partial_fit for free

* Update examples/pytorch/lda/lda_model.py
Co-authored-by: Quan (Andy) Gan <coin2028@hotmail.com>

* Update examples/pytorch/lda/example_20newsgroups.py
Co-authored-by: Quan (Andy) Gan <coin2028@hotmail.com>

* Update lda_model.py

* bugfix torch Gamma uses rate parameter
Co-authored-by: Yifei Ma <yifeim@amazon.com>
Co-authored-by: Quan (Andy) Gan <coin2028@hotmail.com>
parent 657c220d
...@@ -8,6 +8,7 @@ The folder contains example implementations of selected research papers related
| Paper | node classification | link prediction / classification | graph property prediction | sampling | OGB |
| ------------------------------------------------------------ | ------------------- | -------------------------------- | ------------------------- | ------------------ | ------------------ |
| [Latent Dirichlet Allocation](#lda) | :heavy_check_mark: | :heavy_check_mark: | | | |
| [Network Embedding with Completely-imbalanced Labels](#rect) | :heavy_check_mark: | | | | |
| [Boost then Convolve: Gradient Boosting Meets Graph Neural Networks](#bgnn) | :heavy_check_mark: | | | | |
| [Contrastive Multi-View Representation Learning on Graphs](#mvgrl) | :heavy_check_mark: | | :heavy_check_mark: | | |
...@@ -410,6 +411,12 @@ The folder contains example implementations of selected research papers related
- Example code: [PyTorch](https://github.com/awslabs/dgl-ke/tree/master/examples), [PyTorch for custom data](https://aws-dglke.readthedocs.io/en/latest/commands.html)
- Tags: knowledge graph embedding
## 2010
- <a name="lda"></a> Hoffman et al. Online Learning for Latent Dirichlet Allocation. [Paper link](https://papers.nips.cc/paper/2010/file/71f6278d140af599e06ad9bf1ba03cb0-Paper.pdf).
- Example code: [PyTorch](../examples/pytorch/lda)
- Tags: sklearn, decomposition, latent Dirichlet allocation
## 2009
- <a name="astar"></a> Riesen et al. Speeding Up Graph Edit Distance Computation with a Bipartite Heuristic. [Paper link](https://core.ac.uk/download/pdf/33054885.pdf).
...
Latent Dirichlet Allocation
===
LDA is a classical probabilistic graphical model. It is a hierarchical Bayes model
with discrete latent variables on a sparse doc/word graph.
This example shows how it can be implemented with DGL,
where the corpus is represented as a bipartite multi-graph G.
There is no back-propagation, because gradient descent is typically considered
inefficient on the probability simplex.
On the provided small-scale example on the 20 newsgroups dataset, our DGL-LDA model runs
50% faster on GPU than the scikit-learn model without joblib parallelism.
Key equations
---
* The corpus is generated by a hierarchical Bayes process: document (d) -> latent topic (z) -> word (w)
* All positions in the same document share the topic distribution θ_d ~ Dir(α)
* All positions with the same topic share the word distribution β_z ~ Dir(η)
* Words in the same document / topic are therefore correlated.
**MAP**
A simplified MAP model is just a non-conjugate model with an inner summation to integrate out the latent topic variable:
<img src="https://latex.codecogs.com/gif.latex?p(G)=\prod_{(d,w)}\left(\sum_z\theta_{dz}\beta_{zw}\right)" title="map" />
The main complications are that θ_d / β_z are shared within the same document / topic and that the variables live on a probability simplex.
One way to work around this is via expectation maximization (EM):
<img src="https://latex.codecogs.com/gif.latex?\log&space;p(G)&space;=\sum_{(d,w)}\log\left(\sum_z\theta_{dz}\beta_{zw}\right)&space;\geq\sum_{(d,w)}\mathbb{E}_q\log\left(\frac{\theta_{dz}\beta_{zw}}{q(z;\phi_{dw})}\right)" title="map-em" />
* An explicit posterior is ϕ_dwz ∝ θ_dz * β_zw
* E-step: compute summary statistics from the fractional topic memberships
* M-step: set θ_d, β_z proportional to the summary statistics
* With an explicit posterior, the bound is tight. (A dense NumPy sketch of these updates follows this list.)
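As a rough illustration only (not part of the example code; all names such as `counts`, `theta`, `beta`, `phi` are made up for this sketch), the dense EM updates on a tiny random doc-word count matrix look like:

```python
import numpy as np

rng = np.random.default_rng(0)
n_docs, n_words, n_topics = 5, 8, 3
counts = rng.integers(0, 4, size=(n_docs, n_words)).astype(float)  # doc-word counts

# random initialization on the probability simplex
theta = rng.dirichlet(np.ones(n_topics), size=n_docs)   # theta[d, z]
beta = rng.dirichlet(np.ones(n_words), size=n_topics)   # beta[z, w]

for _ in range(50):
    # E-step: explicit posterior phi[d, w, z] proportional to theta[d, z] * beta[z, w]
    phi = theta[:, None, :] * beta.T[None, :, :]
    phi /= phi.sum(-1, keepdims=True)
    frac = counts[:, :, None] * phi                      # fractional topic memberships
    # M-step: set theta_d and beta_z proportional to the summary statistics
    theta = frac.sum(1) + 1e-12
    theta /= theta.sum(-1, keepdims=True)
    beta = frac.sum(0).T + 1e-12
    beta /= beta.sum(-1, keepdims=True)

# MAP objective: sum over (d, w) pairs of log(sum_z theta_dz * beta_zw)
print("log p(G) =", (counts * np.log(theta @ beta)).sum())
```
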
**Variational Bayes**
A Bayesian model adds Dirichlet priors to θ_d & β_z. This causes the posterior to be implicit and the bound to be loose. We will still use an independence assumption and cycle through the variational parameters similarly to coordinate ascent.
* The evidence lower-bound is
<img src="https://latex.codecogs.com/gif.latex?\log&space;p(G)\geq&space;\mathbb{E}_q\left[\sum_{(d,w)}\log\left(&space;\frac{\theta_{dz}\beta_{zw}}{q(z;\phi_{dw})}&space;\right)&space;&plus;\sum_{d}&space;\log\left(&space;\frac{p(\theta_d;\alpha)}{q(\theta_d;\gamma_d)}&space;\right)&space;&plus;\sum_{z}&space;\log\left(&space;\frac{p(\beta_z;\eta)}{q(\beta_z;\lambda_z)}&space;\right)\right]" title="elbo" />
* The ELBO objective function factors as
<img src="https://latex.codecogs.com/gif.latex?\sum_{(d,w)}&space;\phi_{dw}^{\top}\left(&space;\mathbb{E}_{\gamma_d}[\log\theta_d]&space;&plus;\mathbb{E}_{\lambda}[\log\beta_{:w}]&space;-\log\phi_{dw}&space;\right)&space;\\&space;&plus;&space;\sum_d&space;(\alpha-\gamma_d)^\top\mathbb{E}_{\gamma_d}[\log&space;\theta_d]-(\log&space;B(\alpha)-\log&space;B(\gamma_d))&space;\\&space;&plus;&space;\sum_z&space;(\eta-\lambda_z)^\top\mathbb{E}_{\lambda_z}[\log&space;\beta_z]-(\log&space;B(\eta)-\log&space;B(\lambda_z))" title="factors" />
* Similarly, optimization alternates between ϕ, γ, λ. Since θ and β are random, we use the closed-form expression for E[log X] under a Dirichlet distribution via the digamma function (see the short sketch after this list).
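The Dirichlet expectation has a simple closed form; a minimal PyTorch sketch, mirroring the `_weight_exp` helper in `lda_model.py` (the `gamma` values below are arbitrary):

```python
import torch

# For theta ~ Dirichlet(gamma): E[log theta_k] = digamma(gamma_k) - digamma(sum_j gamma_j)
gamma = torch.tensor([[0.4, 1.3, 2.1],
                      [3.0, 0.1, 0.9]])   # one row of variational parameters per document
Elog = torch.digamma(gamma) - torch.digamma(gamma.sum(1, keepdim=True))
print(Elog.exp())                          # exp(E[log theta]), stored as the node 'weight'
```
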
DGL usage
---
The corpus is represented as a bipartite multi-graph G.
We use DGL to propagate information through the edges and aggregate the distributions at the doc/word nodes.
For scalability, the phi variables are transient and updated during message passing.
The gamma / lambda variables are updated after the nodes receive all edge messages.
Following the conventions in [1], the gamma update is called the E-step and the lambda update is called the M-step, because the beta variable has smaller variance.
The lambda variable is also recorded by the trainer, and its MAP estimate can be approximated by using a large step size for the word nodes.
A separate function produces perplexity, which is based on the ELBO objective divided by the total number of word/doc occurrences. A minimal sketch of the message-passing pattern is shown below.
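A minimal sketch of that pattern on a toy graph (the edge lists and topic count are made up; the real model in `lda_model.py` additionally handles priors, step sizes, and the ELBO terms):

```python
import torch
import dgl
from dgl import function as fn

# toy corpus: 3 docs, 4 words, one edge per word occurrence (multi-graph)
edges = ([0, 0, 1, 2, 2, 2], [1, 3, 0, 2, 2, 3])     # (doc ids, word ids)
G = dgl.heterograph({('doc', 'topic', 'word'): edges})

k = 2  # number of topics
G.nodes['doc'].data['weight'] = torch.rand(G.num_nodes('doc'), k)
G.nodes['word'].data['weight'] = torch.rand(G.num_nodes('word'), k)

def msg(edges):
    # transient phi on each edge, proportional to theta_dz * beta_zw
    q = edges.src['weight'] * edges.dst['weight']
    return {'z': q / q.sum(1, keepdim=True)}

# propagate word -> doc and sum the fractional topic counts on the doc nodes
Gr = G.reverse()                     # node data is carried over to the reversed view
Gr.update_all(msg, fn.sum('z', 'z'))
print(Gr.nodes['doc'].data['z'])     # doc-side summary statistics (gamma minus the prior)
```
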
Example
---
`%run example_20newsgroups.py`
* Approximately matches scikit-learn training perplexity after 10 rounds of training.
* Exactly matches scikit-learn training perplexity if word_z is set to lda.components_.T
* To compute testing perplexity, we need to fix the word beta variables via a MAP estimate. scikit-learn does not take this step, and its beta part seems to contain another bug: the training loss is divided by the testing word counts. Nonetheless, I recommend setting `step_size["word"]` to a larger value to approximate the corresponding MAP estimate.
* The DGL-LDA model runs 50% faster on GPU devices compared with scikit-learn without joblib parallelism.
Advanced configurations
---
* Set `step_size["word"]` to a large value to obtain a MAP estimate for beta.
* Set `0 < word_rho < 1` for online learning (see the usage sketch below).
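For instance, assuming a doc/word graph `G` built as in `example_20newsgroups.py`, both options can be passed to the constructor (the concrete values below are illustrative, not tuned recommendations):

```python
from lda_model import LatentDirichletAllocation as LDAModel

model = LDAModel(
    G, 10,                               # 10 topics
    step_size={'doc': 1, 'word': 1e6},   # large word step size approximates a MAP estimate of beta
    word_rho=0.3,                        # 0 < word_rho < 1 decays old word statistics (online learning)
)
model.partial_fit(G)                     # one online round: E-step on docs, then M-step on words
doc_data = model.transform(G)            # per-document variational statistics
```
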
References
---
1. Matthew Hoffman, Francis Bach, David Blei. Online Learning for Latent
Dirichlet Allocation. Advances in Neural Information Processing Systems 23
(NIPS 2010).
2. Reactive LDA Library blogpost by Yingjie Miao for a similar Gibbs model
# Copyright 2021 Yifei Ma
# Modified from scikit-learn example "plot_topics_extraction_with_nmf_lda.py"
# with the following original authors with BSD 3-Clause:
# * Olivier Grisel <olivier.grisel@ensta.org>
# * Lars Buitinck
# * Chyi-Kwei Yau <chyikwei.yau@gmail.com>
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from time import time
import matplotlib.pyplot as plt
import warnings
import numpy as np
import scipy.sparse as ss
import torch
import dgl
from dgl import function as fn
from sklearn.feature_extraction.text import TfidfVectorizer, CountVectorizer
from sklearn.decomposition import NMF, LatentDirichletAllocation
from sklearn.datasets import fetch_20newsgroups
from lda_model import LatentDirichletAllocation as LDAModel
n_samples = 2000
n_features = 1000
n_components = 10
n_top_words = 20
device = 'cuda'
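# Note: device = 'cuda' assumes a GPU is available; set device = 'cpu' to run without one.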
def plot_top_words(model, feature_names, n_top_words, title):
fig, axes = plt.subplots(2, 5, figsize=(30, 15), sharex=True)
axes = axes.flatten()
for topic_idx, topic in enumerate(model.components_):
top_features_ind = topic.argsort()[:-n_top_words - 1:-1]
top_features = [feature_names[i] for i in top_features_ind]
weights = topic[top_features_ind]
ax = axes[topic_idx]
ax.barh(top_features, weights, height=0.7)
ax.set_title(f'Topic {topic_idx +1}',
fontdict={'fontsize': 30})
ax.invert_yaxis()
ax.tick_params(axis='both', which='major', labelsize=20)
for i in 'top right left'.split():
ax.spines[i].set_visible(False)
fig.suptitle(title, fontsize=40)
plt.subplots_adjust(top=0.90, bottom=0.05, wspace=0.90, hspace=0.3)
plt.show()
# Load the 20 newsgroups dataset and vectorize it. We use a few heuristics
# to filter out useless terms early on: the posts are stripped of headers,
# footers and quoted replies, and common English words, words occurring in
# only one document or in at least 95% of the documents are removed.
print("Loading dataset...")
t0 = time()
data, _ = fetch_20newsgroups(shuffle=True, random_state=1,
remove=('headers', 'footers', 'quotes'),
return_X_y=True)
data_samples = data[:n_samples]
data_test = data[n_samples:2*n_samples]
print("done in %0.3fs." % (time() - t0))
# Use tf (raw term count) features for LDA.
print("Extracting tf features for LDA...")
tf_vectorizer = CountVectorizer(max_df=0.95, min_df=2,
max_features=n_features,
stop_words='english')
t0 = time()
tf_vectorizer.fit(data)
tf = tf_vectorizer.transform(data_samples)
tt = tf_vectorizer.transform(data_test)
tf_feature_names = tf_vectorizer.get_feature_names()
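# Expand each nonzero (doc, word) count into that many repeated (u, v) pairs,
# so the graphs below become multi-graphs with one edge per word occurrence.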
tf_uv = [(u,v)
for u,v,e in zip(tf.tocoo().row, tf.tocoo().col, tf.tocoo().data)
for _ in range(e)]
tt_uv = [(u,v)
for u,v,e in zip(tt.tocoo().row, tt.tocoo().col, tt.tocoo().data)
for _ in range(e)]
print("done in %0.3fs." % (time() - t0))
print()
print("Preparing dgl graphs...")
t0 = time()
G = dgl.heterograph({('doc','topic','word'): tf_uv}, device=device)
Gt = dgl.heterograph({('doc','topic','word'): tt_uv}, device=device)
print("done in %0.3fs." % (time() - t0))
print()
print("Training dgl-lda model...")
t0 = time()
model = LDAModel(G, n_components)
model.fit(G)
print("done in %0.3fs." % (time() - t0))
print()
print(f"dgl-lda training perplexity {model.perplexity(G):.3f}")
print(f"dgl-lda testing perplexity {model.perplexity(Gt):.3f}")
plot_top_words(
type('dummy', (object,), {'components_': G.ndata['z']['word'].cpu().numpy().T}),
tf_feature_names, n_top_words, 'Topics in LDA model')
print("Training scikit-learn model...")
print('\n' * 2, "Fitting LDA models with tf features, "
"n_samples=%d and n_features=%d..."
% (n_samples, n_features))
lda = LatentDirichletAllocation(n_components=n_components, max_iter=5,
learning_method='online',
learning_offset=50.,
random_state=0,
verbose=1,
)
t0 = time()
lda.fit(tf)
print("done in %0.3fs." % (time() - t0))
print()
print(f"scikit-learn training perplexity {lda.perplexity(tf):.3f}")
print(f"scikit-learn testing perplexity {lda.perplexity(tt):.3f}")
# Copyright 2021 Yifei Ma
# with references from "sklearn.decomposition.LatentDirichletAllocation"
# with the following original authors:
# * Chyi-Kwei Yau (the said scikit-learn implementation)
# * Matthew D. Hoffman (original onlineldavb implementation)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os, functools, warnings
import numpy as np, scipy as sp
import torch
import dgl
from dgl import function as fn
def _lbeta(alpha, axis):
return torch.lgamma(alpha).sum(axis) - torch.lgamma(alpha.sum(axis))
# Taken from scikit-learn. Worked better than uniform.
# Perhaps this is due to concentration around one.
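# Note: torch.distributions.Gamma takes (concentration, rate), so Gamma(100, 100)
# yields samples with mean 1 and small variance.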
_sklearn_random_init = torch.distributions.gamma.Gamma(100, 100)
def _edge_update(edges, step_size=1):
""" the Gibbs posterior distribution of z propto theta*beta.
As step_size -> infty, the result becomes MAP estimate on the dst nodes.
"""
q = edges.src['weight'] * edges.dst['weight']
marg = q.sum(axis=1, keepdims=True) + np.finfo(float).eps
p = q / marg
return {
'z': p * step_size,
'edge_elbo': marg.squeeze(1).log() * step_size,
}
def _weight_exp(z, ntype, prior):
"""Node weight is approximately normalized for VB along the ntype
direction.
"""
prior = prior + z * 0 # convert numpy to torch
gamma = prior + z
axis = 1 if ntype == 'doc' else 0 # word
Elog = torch.digamma(gamma) - torch.digamma(gamma.sum(axis, keepdims=True))
return Elog.exp()
def _node_update(nodes, prior):
return {
'z': nodes.data['z'],
'weight': _weight_exp(nodes.data['z'], nodes.ntype, prior)
}
def _update_all(G, ntype, prior, step_size, return_obj=False):
"""Follows Eq (5) of Hoffman et al., 2010.
"""
G_prop = G.reverse() if ntype == 'doc' else G # word
msg_fn = lambda edges: _edge_update(edges, step_size)
node_fn = lambda nodes: _node_update(nodes, prior)
G_prop.update_all(msg_fn, fn.sum('z','z'), node_fn)
if return_obj:
G_prop.update_all(msg_fn, fn.sum('edge_elbo', 'edge_elbo'))
G.nodes[ntype].data.update(G_prop.nodes[ntype].data)
return dict(G.nodes[ntype].data)
class LatentDirichletAllocation:
"""LDA model that works with a HeteroGraph with doc/word node types.
    The model may alter the attributes of G arbitrarily,
    but it always loads word_z when needed.
This is inspired by [1] and its corresponding scikit-learn implementation.
References
---
[1] Matthew Hoffman, Francis Bach, David Blei. Online Learning for Latent
Dirichlet Allocation. Advances in Neural Information Processing Systems 23
(NIPS 2010).
[2] Reactive LDA Library blogpost by Yingjie Miao for a similar Gibbs model
"""
def __init__(
self, G, n_components,
prior=None,
step_size={'doc': 1, 'word': 1}, # use larger value to get MAP
word_rho=1, # use smaller value for online update
verbose=True,
):
self.n_components = n_components
if prior is None:
prior = {'doc': 1./n_components, 'word': 1./n_components}
self.prior = prior
self.step_size = step_size
self.word_rho = word_rho
self.word_z = self._load_or_init(G, 'word')['z']
self.verbose = verbose
def _load_or_init(self, G, ntype, z=None):
if z is None:
z = _sklearn_random_init.sample(
(G.num_nodes(ntype), self.n_components)
).to(G.device)
G.nodes[ntype].data['z'] = z
G.apply_nodes(
lambda nodes: _node_update(nodes, self.prior[ntype]),
ntype=ntype)
return dict(G.nodes[ntype].data)
def _e_step(self, G, reinit_doc=True, mean_change_tol=1e-3, max_iters=100):
"""_e_step implements doc data sampling until convergence or max_iters
"""
self._load_or_init(G, 'word', self.word_z)
if reinit_doc or ('weight' not in G.nodes['doc'].data):
self._load_or_init(G, 'doc')
for i in range(max_iters):
doc_z = dict(G.nodes['doc'].data)['z']
doc_data = _update_all(
G, 'doc', self.prior['doc'], self.step_size['doc'])
mean_change = (doc_data['z'] - doc_z).abs().mean()
if mean_change < mean_change_tol:
break
if self.verbose:
print(f'e-step num_iters={i+1} with mean_change={mean_change:.4f}')
return doc_data
transform = _e_step
def _m_step(self, G):
"""_m_step implements word data sampling and stores word_z stats
"""
# assume G.nodes['doc'].data has been up to date
word_data = _update_all(
G, 'word', self.prior['word'], self.step_size['word'])
# online update
self.word_z = (
(1-self.word_rho) * self.word_z
+self.word_rho * word_data['z']
)
return word_data
def partial_fit(self, G):
self._last_word_z = self.word_z
self._e_step(G)
self._m_step(G)
return self
def fit(self, G, mean_change_tol=1e-3, max_epochs=10):
for i in range(max_epochs):
self.partial_fit(G)
mean_change = (self.word_z - self._last_word_z).abs().mean()
if self.verbose:
print(f'epoch {i+1}, '
f'perplexity: {self.perplexity(G, False)}, '
f'mean_change: {mean_change:.4f}')
if mean_change < mean_change_tol:
break
return self
def perplexity(self, G, reinit_doc=True):
"""ppl = exp{-sum[log(p(w1,...,wn|d))] / n}
Follows Eq (15) in Hoffman et al., 2010.
"""
word_data = self._load_or_init(G, 'word', self.word_z)
if reinit_doc or ('weight' not in G.nodes['doc'].data):
self._e_step(G, reinit_doc)
doc_data = _update_all(
G, 'doc', self.prior['doc'], self.step_size['doc'],
return_obj=True)
# compute E[log p(docs | theta, beta)]
edge_elbo = (
doc_data['edge_elbo'].sum() / doc_data['z'].sum()
).cpu().numpy()
if self.verbose:
print(f'neg_elbo phi: {-edge_elbo:.3f}', end=' ')
# compute E[log p(theta | alpha) - log q(theta | gamma)]
doc_elbo = (
(-doc_data['z'] * doc_data['weight'].log()).sum(axis=1)
-_lbeta(self.prior['doc'] + doc_data['z'] * 0, axis=1)
+_lbeta(self.prior['doc'] + doc_data['z'], axis=1)
)
doc_elbo = (doc_elbo.sum() / doc_data['z'].sum()).cpu().numpy()
if self.verbose:
print(f'theta: {-doc_elbo:.3f}', end=' ')
# compute E[log p(beta | eta) - log q (beta | lambda)]
word_elbo = (
(-word_data['z'] * word_data['weight'].log()).sum(axis=0)
-_lbeta(self.prior['word'] + word_data['z'] * 0, axis=0)
+_lbeta(self.prior['word'] + word_data['z'], axis=0)
)
word_elbo = (word_elbo.sum() / word_data['z'].sum()).cpu().numpy()
if self.verbose:
print(f'beta: {-word_elbo:.3f}')
return np.exp(-edge_elbo - doc_elbo - word_elbo)
if __name__ == '__main__':
print('Testing LatentDirichletAllocation via task_example_test.sh ...')
tf_uv = np.array(np.nonzero(np.random.rand(20,10)<0.5)).T
G = dgl.heterograph({('doc','topic','word'): tf_uv.tolist()})
model = LatentDirichletAllocation(G, 5, verbose=False)
model.fit(G)
model.transform(G)
model.perplexity(G)
...@@ -40,5 +40,6 @@ pushd $GCN_EXAMPLE_DIR> /dev/null
python3 pagerank.py || fail "run pagerank.py on $1"
python3 gcn/gcn.py --dataset cora --gpu $dev || fail "run gcn/gcn.py on $1"
python3 lda/lda_model.py || fail "run lda/lda_model.py on $1"
popd > /dev/null