Commit 9d701e90 (unverified), authored by Steve Westerhouse, committed by GitHub

[Doc] Clarify FP8 KV cache computation workflow (#31071)


Signed-off-by: westers <steve.westerhouse@origami-analytics.com>
parent 06d49028
@@ -139,18 +139,18 @@ token data.
const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE;
``` ```
<figure markdown="span"> <p align="center">
![](../assets/design/paged_attention/query.png){ align="center" alt="query" width="70%" } <img src="../assets/design/paged_attention/query.png" alt="query" width="70%" />
</figure> </p>
Each thread defines its own `q_ptr` which points to the assigned
query token data on global memory. For example, if `VEC_SIZE` is 4
and `HEAD_SIZE` is 128, the `q_ptr` points to data that contains
total of 128 elements divided into 128 / 4 = 32 vecs.
<figure markdown="span"> <p align="center">
![](../assets/design/paged_attention/q_vecs.png){ align="center" alt="q_vecs" width="70%" } <img src="../assets/design/paged_attention/q_vecs.png" alt="q_vecs" width="70%" />
</figure> </p>
```cpp ```cpp
__shared__ Q_vec q_vecs[THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD];
@@ -187,9 +187,9 @@ key token at different iterations. As shown above, that `k_ptr`
points to key token data based on `k_cache` at assigned block,
assigned head and assigned token.
<figure markdown="span"> <p align="center">
![](../assets/design/paged_attention/key.png){ align="center" alt="key" width="70%" } <img src="../assets/design/paged_attention/key.png" alt="key" width="70%" />
</figure> </p>
The diagram above illustrates the memory layout for key data. It
assumes that the `BLOCK_SIZE` is 16, `HEAD_SIZE` is 128, `x` is
@@ -202,9 +202,9 @@ iterations. Inside each rectangle, there are a total 32 vecs (128
elements for one token) that will be processed by 2 threads (one
thread group) separately.
<figure markdown="span"> <p align="center">
![](../assets/design/paged_attention/k_vecs.png){ align="center" alt="k_vecs" width="70%" } <img src="../assets/design/paged_attention/k_vecs.png" alt="k_vecs" width="70%" />
</figure> </p>
```cpp ```cpp
K_vec k_vecs[NUM_VECS_PER_THREAD]
@@ -361,17 +361,17 @@ later steps. Now, it should store the normalized softmax result of
## Value
<figure markdown="span"> <p align="center">
![](../assets/design/paged_attention/value.png){ align="center" alt="value" width="70%" } <img src="../assets/design/paged_attention/value.png" alt="value" width="70%" />
</figure> </p>
<figure markdown="span"> <p align="center">
![](../assets/design/paged_attention/logits_vec.png){ align="center" alt="logits_vec" width="50%" } <img src="../assets/design/paged_attention/logits_vec.png" alt="logits_vec" width="50%" />
</figure> </p>
<figure markdown="span"> <p align="center">
![](../assets/design/paged_attention/v_vec.png){ align="center" alt="v_vec" width="70%" } <img src="../assets/design/paged_attention/v_vec.png" alt="v_vec" width="70%" />
</figure> </p>
Now we need to retrieve the value data and perform dot multiplication
with `logits`. Unlike query and key, there is no thread group
......
@@ -17,6 +17,16 @@ The E4M3 format offers higher precision compared to E5M2. However, due to its sm
For now, only per-tensor (scalar) scaling factors are supported. Development is ongoing to support scaling factors of a finer granularity (e.g. per-channel).
### How FP8 KV Cache Works
The FP8 KV cache implementation follows this workflow:
1. **Storage**: Key and Value tensors are quantized to FP8 format using scaling factors before being stored in the KV cache
2. **Retrieval**: When needed for attention computation, cached KV tensors are dequantized back to higher precision (FP16/BF16)
3. **Attention**: The attention-value multiplication (softmax output × V) is performed using the dequantized higher-precision V tensor
This means the final attention computation operates on dequantized values, not FP8 tensors. The quantization reduces memory usage during storage but maintains computation accuracy by using higher precision during the actual attention operations.
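The three steps above can be sketched in plain Python. This is an illustration only, not vLLM's kernel code: the helper names (`round_to_e4m3`, `quantize`, `dequantize`, `attend`) are hypothetical, FP8 rounding is simulated with a simplified saturating E4M3 grid, Python floats stand in for FP16/BF16, and the scale is assumed to be the tensor's maximum magnitude divided by 448.

```python
import math

FP8_E4M3_MAX = 448.0  # largest finite magnitude in the E4M3 (fn) format


def round_to_e4m3(x: float) -> float:
    """Round x to the nearest value on a simulated E4M3 grid, saturating at the max."""
    if x == 0.0:
        return 0.0
    x = max(-FP8_E4M3_MAX, min(FP8_E4M3_MAX, x))
    _, exp = math.frexp(x)    # x = m * 2**exp with 0.5 <= |m| < 1
    exp = max(exp - 1, -6)    # binade exponent; -6 is the smallest normal exponent
    step = 2.0 ** (exp - 3)   # 3 mantissa bits give 8 steps per binade
    return round(x / step) * step


def quantize(row, scale):
    """Step 1 (storage): divide by the per-tensor scale and round to FP8."""
    return [round_to_e4m3(v / scale) for v in row]


def dequantize(row, scale):
    """Step 2 (retrieval): multiply stored values back to higher precision."""
    return [q * scale for q in row]


def attend(weights, v_rows):
    """Step 3 (attention): softmax output x V, computed on dequantized rows."""
    head_size = len(v_rows[0])
    return [sum(w * row[j] for w, row in zip(weights, v_rows))
            for j in range(head_size)]


# Assumed toy data: two cached tokens with a head size of 3.
v_rows = [[0.5, -1.25, 3.0], [2.0, 0.1, -0.7]]
scale = max(abs(v) for row in v_rows for v in row) / FP8_E4M3_MAX

stored = [quantize(row, scale) for row in v_rows]        # what the cache holds
restored = [dequantize(row, scale) for row in stored]    # higher-precision V
weights = [0.25, 0.75]                                   # assumed softmax output
output = attend(weights, restored)

print(stored)
print(output)
```

The rounding error is introduced once, at storage time. The weighted sum in `attend` never sees FP8 values, which matches the point made above that attention operates on dequantized tensors.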
### Performance Impact
The current FP8 KV cache implementation primarily benefits throughput by allowing approximately double the amount of space for KV cache allocation. This enables either:
......