Unverified Commit 70b94685 authored by Shengyu Liu's avatar Shengyu Liu Committed by GitHub
Browse files

Fix LaTeX render error (#74)

parent 6cff5a73
......@@ -25,13 +25,13 @@ To fully utilize GPU compute resources, we need to overlap CUDA Core operations
Our solution involves an additional mathematical transformation beyond FlashAttention's online softmax and accumulation approach. In each step, we take two KV blocks (called $K_0$, $K_1$, $V_0$, and $V_1$). Since the output matrix occupies 32,768 registers (too many for one warpgroup), we split it vertically into $O_L$ and $O_R$ (each $64 \times 256$). We similarly split $V_0$ and $V_1$ into $V_{0L}$, $V_{0R}$, $V_{1L}$, and $V_{1R}$ (each $64 \times 256$). The output matrix is then computed as follows:
0. Maintain a running max $m$ (initialized to $-\infty$, shared between the two warpgroups) and output matrices $\vec o_L, \vec o_R$ (initialized to 0).
1. [0] Compute $\vec p_0 = \vec q K_0^\intercal / qk\_scale$.
2. [1] Compute $\vec p_1 = \vec q K_1^\intercal / qk\_scale$.
3. [0] Compute $mp_0 = \max(\vec p_0)$, $m\_new_0 = \max(m, mp_0)$, and $scale_0 = \exp(m\_new_0 - m)$. Update $m \gets m\_new_0$.
4. [0] Perform softmax on $\vec p_0$: $\vec p_0 \gets \exp(\vec p_0 - m\_new_0)$.
1. [0] Compute $`\vec p_0 = \vec q K_0^\intercal / qk\_scale`$.
2. [1] Compute $`\vec p_1 = \vec q K_1^\intercal / qk\_scale`$.
3. [0] Compute $mp_0 = \max(\vec p_0)$, $`m\_new_0 = \max(m, mp_0)`$, and $`scale_0 = \exp(m\_new_0 - m)`$. Update $`m \gets m\_new_0`$.
4. [0] Perform softmax on $\vec p_0$: $`\vec p_0 \gets \exp(\vec p_0 - m\_new_0)`$.
5. [0] Update $\vec o_L \gets \vec o_L \cdot scale_0 + \vec p_0 V_{0L}$.
6. [1] Compute $mp_1 = \max(\vec p_1)$, $m\_new_1 = \max(m, mp_1)$, and $scale_1 = \exp(m\_new_1 - m)$. Update $m \gets m\_new_1$.
7. [1] Perform softmax on $\vec p_1$: $\vec p_1 \gets \exp(\vec p_1 - m\_new_1)$.
6. [1] Compute $mp_1 = \max(\vec p_1)$, $`m\_new_1 = \max(m, mp_1)`$, and $`scale_1 = \exp(m\_new_1 - m)`$. Update $`m \gets m\_new_1`$.
7. [1] Perform softmax on $\vec p_1$: $`\vec p_1 \gets \exp(\vec p_1 - m\_new_1)`$.
8. [1] Update $\vec o_R \gets \vec o_R \cdot (scale_0 \cdot scale_1) + \vec p_1 V_{1R}$.
9. [0] Update $\vec p_0 \gets \vec p_0 \cdot scale_1$.
10. [1] Update $\vec o_R \gets \vec o_R + \vec p_0 V_{0R}$.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment