Commit 466d3fce authored by Philip Maybank

Extend Softmax section of Primitives Guide

* rename l to z

* add material on applying softmax row-wise to matrix

* define macro for diag operator (represents diagonal matrix)
parent 11edcb0e
@@ -2,15 +2,12 @@
Supported Primitives Guide
==========================
------------
Introduction
------------
This document contains details of supported primitives in Composable Kernel (CK). In contrast to the API Reference
Guide, the Supported Primitives Guide is an introduction to the math which underpins the algorithms implemented in CK.
------------
Softmax
------------
For vectors :math:`x^{(1)}, x^{(2)}, \ldots, x^{(T)}` of size :math:`B` we can decompose the softmax of concatenated
:math:`x = [ x^{(1)}\ | \ \ldots \ | \ x^{(T)} ]` as,
@@ -21,9 +18,58 @@ For vectors :math:`x^{(1)}, x^{(2)}, \ldots, x^{(T)}` of size :math:`B` we can d
\begin{align}
m(x) & = m( [ x^{(1)}\ | \ \ldots \ | \ x^{(T)} ] ) = \max( m(x^{(1)}),\ldots, m(x^{(T)}) ) \\
f(x) & = [\exp( m(x^{(1)}) - m(x) ) f( x^{(1)} )\ | \ \ldots \ | \ \exp( m(x^{(T)}) - m(x) ) f( x^{(T)} )] \\
z(x) & = \exp( m(x^{(1)}) - m(x) )\ z(x^{(1)}) + \ldots + \exp( m(x^{(T)}) - m(x) )\ z(x^{(T)}) \\
\operatorname{softmax}(x) &= f(x)\ / \ z(x)
\end{align}
where :math:`f(x^{(j)}) = \exp( x^{(j)} - m(x^{(j)}) )` is of size :math:`B` and
:math:`z(x^{(j)}) = f(x_1^{(j)})+ \ldots+ f(x_B^{(j)})` is a scalar.
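
As a rough illustration of the decomposition above, the following pure-Python sketch merges per-block statistics into the softmax of the concatenated vector. The function names (``block_stats``, ``blockwise_softmax``, ``reference_softmax``) and the sample data are hypothetical and not part of CK; the sketch only mirrors the formulas for the block maximum, the shifted exponentials, and their sum.

```python
import math


def block_stats(block):
    """Return (m, f, z) for one block: max, shifted exponentials, and their sum."""
    m = max(block)
    f = [math.exp(v - m) for v in block]
    return m, f, sum(f)


def blockwise_softmax(blocks):
    """Softmax of the concatenation of `blocks`, built from per-block statistics."""
    stats = [block_stats(b) for b in blocks]
    m_all = max(m for m, _, _ in stats)                  # global max m(x)
    scales = [math.exp(m - m_all) for m, _, _ in stats]  # exp(m(x^(j)) - m(x))
    # Global normalizer: rescaled sum of the per-block sums.
    z_all = sum(s * z for s, (_, _, z) in zip(scales, stats))
    # Rescale each block of exponentials and divide by the global normalizer.
    return [s * v / z_all for s, (_, f, _) in zip(scales, stats) for v in f]


def reference_softmax(x):
    """Plain single-pass softmax, used only for comparison."""
    m = max(x)
    e = [math.exp(v - m) for v in x]
    total = sum(e)
    return [v / total for v in e]


blocks = [[1.0, 2.0, 3.0], [0.5, -1.0, 4.0]]  # T = 2 blocks of size B = 3
flat = [v for b in blocks for v in b]
result = blockwise_softmax(blocks)
```

Under these assumptions ``result`` should agree with ``reference_softmax(flat)`` up to floating-point rounding, which is the property the decomposition relies on.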
For a matrix :math:`X` composed of :math:`T_r \times T_c` tiles, :math:`X_{ij}`, of size :math:`B_r \times B_c` we can
compute the row-wise softmax as follows.
For :math:`j` from :math:`1` to :math:`T_c`, and :math:`i` from :math:`1` to :math:`T_r` calculate,

.. math::
   :nowrap:

   \begin{align}
      \tilde{m}_{ij} &= \operatorname{rowmax}( X_{ij} ) \\
      \tilde{P}_{ij} &= \exp(X_{ij} - \tilde{m}_{ij} ) \\
      \tilde{z}_{ij} &= \operatorname{rowsum}( \tilde{P}_{ij} )
   \end{align}

If :math:`j=1`, initialize running max, running sum, and the first column block of the output,

.. math::
   :nowrap:

   \begin{align}
      m_i &= \tilde{m}_{i1} \\
      z_i &= \tilde{z}_{i1} \\
      Y_{i1} &= \diag(\tilde{z}_{i1})^{-1} \tilde{P}_{i1}
   \end{align}

Else if :math:`j>1`,
1. Update running max, running sum and column blocks :math:`k=1` to :math:`k=j-1`

.. math::
   :nowrap:

   \begin{align}
      m^{new}_i &= \max(m_i, \tilde{m}_{ij} ) \\
      z^{new}_i &= \exp(m_i - m^{new}_i)\ z_i + \exp( \tilde{m}_{ij} - m^{new}_i )\ \tilde{z}_{ij} \\
      Y_{ik} &= \diag(z^{new}_{i})^{-1} \diag(z_{i}) \exp(m_i - m^{new}_i)\ Y_{ik}
   \end{align}

2. Initialize column block :math:`j` of output and reset running max and running sum variables:

.. math::
   :nowrap:

   \begin{align}
      Y_{ij} &= \diag(z^{new}_{i})^{-1} \exp(\tilde{m}_{ij} - m^{new}_i ) \tilde{P}_{ij} \\
      z_i &= z^{new}_i \\
      m_i &= m^{new}_i
   \end{align}

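
The tiled row-wise procedure above can be sketched in pure Python as follows. This is a minimal sketch under simplifying assumptions, not CK's implementation: rows are processed one at a time (effectively a row block of height 1, which is valid because rows are independent), the matrix is a list of lists, and the names ``tiled_row_softmax``, ``row_softmax`` and the tile width ``Bc`` are hypothetical.

```python
import math


def tiled_row_softmax(X, Bc):
    """Row-wise softmax of X (list of rows), visiting column tiles of width Bc."""
    n_rows, n_cols = len(X), len(X[0])
    Y = [[0.0] * n_cols for _ in range(n_rows)]
    m = [0.0] * n_rows  # running row max
    z = [0.0] * n_rows  # running row sum
    for start in range(0, n_cols, Bc):          # column block j
        stop = min(start + Bc, n_cols)
        for r in range(n_rows):                 # one row at a time
            tile = X[r][start:stop]
            m_t = max(tile)                                 # tile row max
            p_t = [math.exp(v - m_t) for v in tile]         # shifted exponentials
            z_t = sum(p_t)                                  # tile row sum
            if start == 0:
                # First column block: initialize running max, sum, and output.
                m[r], z[r] = m_t, z_t
                Y[r][start:stop] = [p / z_t for p in p_t]
            else:
                m_new = max(m[r], m_t)
                z_new = (math.exp(m[r] - m_new) * z[r]
                         + math.exp(m_t - m_new) * z_t)
                # Rescale the already written column blocks k < j.
                scale = z[r] * math.exp(m[r] - m_new) / z_new
                for c in range(start):
                    Y[r][c] *= scale
                # Write column block j, then update the running statistics.
                Y[r][start:stop] = [math.exp(m_t - m_new) * p / z_new for p in p_t]
                m[r], z[r] = m_new, z_new
    return Y


def row_softmax(X):
    """Plain row-wise softmax, used only for comparison."""
    out = []
    for row in X:
        mx = max(row)
        e = [math.exp(v - mx) for v in row]
        s = sum(e)
        out.append([v / s for v in e])
    return out


X = [[1.0, 2.0, 3.0, 4.0], [0.0, -2.0, 5.0, 1.0], [3.0, 3.0, 3.0, 3.0]]
Y = tiled_row_softmax(X, Bc=2)
```

Under these assumptions ``Y`` should match ``row_softmax(X)`` up to rounding. Rescaling earlier column blocks on every step keeps the sketch close to the equations; a practical kernel would more likely defer the normalization to the end.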
\ No newline at end of file
@@ -148,6 +148,13 @@ html_theme_options = {
# ]
# }
mathjax3_config = {
    'tex': {
        'macros': {
            'diag': '\\operatorname{diag}',
        }
    }
}
# -- Options for HTMLHelp output ------------------------------------------
@@ -168,7 +175,10 @@ latex_elements = {
    # Additional stuff for the LaTeX preamble.
    #
    'preamble': r'''
    \setcounter{tocdepth}{5}
    \newcommand{\diag}{\operatorname{diag}}
    ''',
    # Latex figure (float) alignment
    #
......