Unverified Commit ef1dd687 authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[Doc] Fix indentation problems in V0 Paged Attention docs (#18659)


Signed-off-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
parent e77dc4ba
...@@ -9,6 +9,7 @@ Deploying vLLM on Kubernetes is a scalable and efficient way to serve machine le ...@@ -9,6 +9,7 @@ Deploying vLLM on Kubernetes is a scalable and efficient way to serve machine le
* [Deployment with GPUs](#deployment-with-gpus) * [Deployment with GPUs](#deployment-with-gpus)
Alternatively, you can deploy vLLM to Kubernetes using any of the following: Alternatively, you can deploy vLLM to Kubernetes using any of the following:
* [Helm](frameworks/helm.md) * [Helm](frameworks/helm.md)
* [InftyAI/llmaz](integrations/llmaz.md) * [InftyAI/llmaz](integrations/llmaz.md)
* [KServe](integrations/kserve.md) * [KServe](integrations/kserve.md)
......
...@@ -3,78 +3,76 @@ title: vLLM Paged Attention ...@@ -3,78 +3,76 @@ title: vLLM Paged Attention
--- ---
[](){ #design-paged-attention } [](){ #design-paged-attention }
- Currently, vLLM utilizes its own implementation of a multi-head query Currently, vLLM utilizes its own implementation of a multi-head query
attention kernel (`csrc/attention/attention_kernels.cu`). attention kernel (`csrc/attention/attention_kernels.cu`).
This kernel is designed to be compatible with This kernel is designed to be compatible with
vLLM's paged KV caches, where the key and value cache are stored in vLLM's paged KV caches, where the key and value cache are stored in
separate blocks (note that this block concept differs from the GPU separate blocks (note that this block concept differs from the GPU
thread block. So in a later document, I will refer to vLLM paged thread block. So in a later document, I will refer to vLLM paged
attention block as "block", while refer to GPU thread block as attention block as "block", while refer to GPU thread block as
"thread block"). "thread block").
- To achieve high performance, this kernel relies on a specially
designed memory layout and access method, specifically when threads To achieve high performance, this kernel relies on a specially
read data from global memory to shared memory. The purpose of this designed memory layout and access method, specifically when threads
document is to provide a high-level explanation of the kernel read data from global memory to shared memory. The purpose of this
implementation step by step, aiding those who wish to learn about the document is to provide a high-level explanation of the kernel
vLLM multi-head query attention kernel. After going through this implementation step by step, aiding those who wish to learn about the
document, users will likely have a better understanding and feel easier vLLM multi-head query attention kernel. After going through this
to follow the actual implementation. document, users will likely have a better understanding and feel easier
- Please note that this document may not cover all details, such as how to follow the actual implementation.
to calculate the correct index for the corresponding data or the dot
multiplication implementation. However, after reading this document Please note that this document may not cover all details, such as how
and becoming familiar with the high-level logic flow, it should be to calculate the correct index for the corresponding data or the dot
easier for you to read the actual code and understand the details. multiplication implementation. However, after reading this document
and becoming familiar with the high-level logic flow, it should be
easier for you to read the actual code and understand the details.
## Inputs ## Inputs
- The kernel function takes a list of arguments for the current thread The kernel function takes a list of arguments for the current thread
to perform its assigned work. The three most important arguments are to perform its assigned work. The three most important arguments are
the input pointers `q`, `k_cache`, and `v_cache`, which point the input pointers `q`, `k_cache`, and `v_cache`, which point
to query, key, and value data on global memory that need to be read to query, key, and value data on global memory that need to be read
and processed. The output pointer `out` points to global memory and processed. The output pointer `out` points to global memory
where the result should be written. These four pointers actually where the result should be written. These four pointers actually
refer to multi-dimensional arrays, but each thread only accesses the refer to multi-dimensional arrays, but each thread only accesses the
portion of data assigned to it. I have omitted all other runtime portion of data assigned to it. I have omitted all other runtime
parameters here for simplicity. parameters here for simplicity.
```cpp ```cpp
template< template<typename scalar_t, int HEAD_SIZE, int BLOCK_SIZE, int NUM_THREADS, int PARTITION_SIZE = 0>
typename scalar_t, __device__ void paged_attention_kernel(
int HEAD_SIZE, ... // Other side args.
int BLOCK_SIZE, const scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, head_size]
int NUM_THREADS, const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
int PARTITION_SIZE = 0> const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x]
__device__ void paged_attention_kernel( const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size]
... // Other side args. ... // Other side args.
const scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, head_size] )
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] ```
const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x]
const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size] There are also a list of template arguments above the function
... // Other side args. signature that are determined during compilation time. `scalar_t`
) represents the data type of the query, key, and value data elements,
``` such as FP16. `HEAD_SIZE` indicates the number of elements in each
head. `BLOCK_SIZE` refers to the number of tokens in each block.
- There are also a list of template arguments above the function `NUM_THREADS` denotes the number of threads in each thread block.
signature that are determined during compilation time. `scalar_t` `PARTITION_SIZE` represents the number of tensor parallel GPUs (For
represents the data type of the query, key, and value data elements, simplicity, we assume this is 0 and tensor parallel is disabled).
such as FP16. `HEAD_SIZE` indicates the number of elements in each
head. `BLOCK_SIZE` refers to the number of tokens in each block. With these arguments, we need to perform a sequence of preparations.
`NUM_THREADS` denotes the number of threads in each thread block. This includes calculating the current head index, block index, and
`PARTITION_SIZE` represents the number of tensor parallel GPUs (For other necessary variables. However, for now, we can ignore these
simplicity, we assume this is 0 and tensor parallel is disabled). preparations and proceed directly to the actual calculations. It will
be easier to understand them once we grasp the entire flow.
- With these arguments, we need to perform a sequence of preparations.
This includes calculating the current head index, block index, and
other necessary variables. However, for now, we can ignore these
preparations and proceed directly to the actual calculations. It will
be easier to understand them once we grasp the entire flow.
## Concepts ## Concepts
- Just before we dive into the calculation flow, I want to describe a Just before we dive into the calculation flow, I want to describe a
few concepts that are needed for later sections. However, you may few concepts that are needed for later sections. However, you may
skip this section and return later if you encounter any confusing skip this section and return later if you encounter any confusing
terminologies. terminologies.
- **Sequence**: A sequence represents a client request. For example, - **Sequence**: A sequence represents a client request. For example,
the data pointed to by `q` has a shape of the data pointed to by `q` has a shape of
`[num_seqs, num_heads, head_size]`. That represents there are total `[num_seqs, num_heads, head_size]`. That represents there are total
...@@ -129,236 +127,236 @@ title: vLLM Paged Attention ...@@ -129,236 +127,236 @@ title: vLLM Paged Attention
## Query ## Query
- This section will introduce how query data is stored in memory and This section will introduce how query data is stored in memory and
fetched by each thread. As mentioned above, each thread group fetches fetched by each thread. As mentioned above, each thread group fetches
one query token data, while each thread itself only handles a part of one query token data, while each thread itself only handles a part of
one query token data. Within each warp, every thread group will fetch one query token data. Within each warp, every thread group will fetch
the same query token data, but will multiply it with different key the same query token data, but will multiply it with different key
token data. token data.
```cpp ```cpp
const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE; const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE;
``` ```
<figure markdown="span"> <figure markdown="span">
![](../../assets/kernel/query.png){ align="center" alt="query" width="70%" } ![](../../assets/kernel/query.png){ align="center" alt="query" width="70%" }
</figure> </figure>
- Each thread defines its own `q_ptr` which points to the assigned Each thread defines its own `q_ptr` which points to the assigned
query token data on global memory. For example, if `VEC_SIZE` is 4 query token data on global memory. For example, if `VEC_SIZE` is 4
and `HEAD_SIZE` is 128, the `q_ptr` points to data that contains and `HEAD_SIZE` is 128, the `q_ptr` points to data that contains
total of 128 elements divided into 128 / 4 = 32 vecs. total of 128 elements divided into 128 / 4 = 32 vecs.
<figure markdown="span"> <figure markdown="span">
![](../../assets/kernel/q_vecs.png){ align="center" alt="q_vecs" width="70%" } ![](../../assets/kernel/q_vecs.png){ align="center" alt="q_vecs" width="70%" }
</figure> </figure>
```cpp ```cpp
__shared__ Q_vec q_vecs[THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD]; __shared__ Q_vec q_vecs[THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD];
``` ```
- Next, we need to read the global memory data pointed to by `q_ptr` Next, we need to read the global memory data pointed to by `q_ptr`
into shared memory as `q_vecs`. It is important to note that each into shared memory as `q_vecs`. It is important to note that each
vecs is assigned to a different row. For example, if the vecs is assigned to a different row. For example, if the
`THREAD_GROUP_SIZE` is 2, thread 0 will handle the 0th row vecs, `THREAD_GROUP_SIZE` is 2, thread 0 will handle the 0th row vecs,
while thread 1 handles the 1st row vecs. By reading the query data in while thread 1 handles the 1st row vecs. By reading the query data in
this way, neighboring threads like thread 0 and thread 1 can read this way, neighboring threads like thread 0 and thread 1 can read
neighbor memory, achieving the memory coalescing to improve neighbor memory, achieving the memory coalescing to improve
performance. performance.
## Key ## Key
- Similar to the "Query" section, this section introduces memory layout Similar to the "Query" section, this section introduces memory layout
and assignment for keys. While each thread group only handle one and assignment for keys. While each thread group only handle one
query token one kernel run, it may handle multiple key tokens across query token one kernel run, it may handle multiple key tokens across
multiple iterations. Meanwhile, each warp will process multiple blocks multiple iterations. Meanwhile, each warp will process multiple blocks
of key tokens in multiple iterations, ensuring that all context of key tokens in multiple iterations, ensuring that all context
tokens are processed by the entire thread group after the kernel run. tokens are processed by the entire thread group after the kernel run.
In this context, "handle" refers to performing the dot multiplication In this context, "handle" refers to performing the dot multiplication
between query data and key data. between query data and key data.
```cpp ```cpp
const scalar_t* k_ptr = k_cache + physical_block_number * kv_block_stride const scalar_t* k_ptr = k_cache + physical_block_number * kv_block_stride
+ kv_head_idx * kv_head_stride + kv_head_idx * kv_head_stride
+ physical_block_offset * x; + physical_block_offset * x;
``` ```
- Unlike to `q_ptr`, `k_ptr` in each thread will point to different Unlike to `q_ptr`, `k_ptr` in each thread will point to different
key token at different iterations. As shown above, that `k_ptr` key token at different iterations. As shown above, that `k_ptr`
points to key token data based on `k_cache` at assigned block, points to key token data based on `k_cache` at assigned block,
assigned head and assigned token. assigned head and assigned token.
<figure markdown="span"> <figure markdown="span">
![](../../assets/kernel/key.png){ align="center" alt="key" width="70%" } ![](../../assets/kernel/key.png){ align="center" alt="key" width="70%" }
</figure> </figure>
- The diagram above illustrates the memory layout for key data. It The diagram above illustrates the memory layout for key data. It
assumes that the `BLOCK_SIZE` is 16, `HEAD_SIZE` is 128, `x` is assumes that the `BLOCK_SIZE` is 16, `HEAD_SIZE` is 128, `x` is
8, `THREAD_GROUP_SIZE` is 2, and there are a total of 4 warps. Each 8, `THREAD_GROUP_SIZE` is 2, and there are a total of 4 warps. Each
rectangle represents all the elements for one key token at one head, rectangle represents all the elements for one key token at one head,
which will be processed by one thread group. The left half shows the which will be processed by one thread group. The left half shows the
total 16 blocks of key token data for warp 0, while the right half total 16 blocks of key token data for warp 0, while the right half
represents the remaining key token data for other warps or represents the remaining key token data for other warps or
iterations. Inside each rectangle, there are a total 32 vecs (128 iterations. Inside each rectangle, there are a total 32 vecs (128
elements for one token) that will be processed by 2 threads (one elements for one token) that will be processed by 2 threads (one
thread group) separately. thread group) separately.
<figure markdown="span"> <figure markdown="span">
![](../../assets/kernel/k_vecs.png){ align="center" alt="k_vecs" width="70%" } ![](../../assets/kernel/k_vecs.png){ align="center" alt="k_vecs" width="70%" }
</figure> </figure>
```cpp ```cpp
K_vec k_vecs[NUM_VECS_PER_THREAD] K_vec k_vecs[NUM_VECS_PER_THREAD]
``` ```
- Next, we need to read the key token data from `k_ptr` and store Next, we need to read the key token data from `k_ptr` and store
them on register memory as `k_vecs`. We use register memory for them on register memory as `k_vecs`. We use register memory for
`k_vecs` because it will only be accessed by one thread once, `k_vecs` because it will only be accessed by one thread once,
whereas `q_vecs` will be accessed by multiple threads multiple whereas `q_vecs` will be accessed by multiple threads multiple
times. Each `k_vecs` will contain multiple vectors for later times. Each `k_vecs` will contain multiple vectors for later
calculation. Each vec will be set at each inner iteration. The calculation. Each vec will be set at each inner iteration. The
assignment of vecs allows neighboring threads in a warp to read assignment of vecs allows neighboring threads in a warp to read
neighboring memory together, which again promotes the memory neighboring memory together, which again promotes the memory
coalescing. For instance, thread 0 will read vec 0, while thread 1 coalescing. For instance, thread 0 will read vec 0, while thread 1
will read vec 1. In the next inner loop, thread 0 will read vec 2, will read vec 1. In the next inner loop, thread 0 will read vec 2,
while thread 1 will read vec 3, and so on. while thread 1 will read vec 3, and so on.
- You may still be a little confused about the overall flow. Don't You may still be a little confused about the overall flow. Don't
worry, please keep reading the next "QK" section. It will illustrate worry, please keep reading the next "QK" section. It will illustrate
the query and key calculation flow in a clearer and higher-level the query and key calculation flow in a clearer and higher-level
manner. manner.
## QK ## QK
- As shown the pseudo code below, before the entire for loop block, we As shown the pseudo code below, before the entire for loop block, we
fetch the query data for one token and store it in `q_vecs`. Then, fetch the query data for one token and store it in `q_vecs`. Then,
in the outer for loop, we iterate through different `k_ptrs` that in the outer for loop, we iterate through different `k_ptrs` that
point to different tokens and prepare the `k_vecs` in the inner for point to different tokens and prepare the `k_vecs` in the inner for
loop. Finally, we perform the dot multiplication between the loop. Finally, we perform the dot multiplication between the
`q_vecs` and each `k_vecs`. `q_vecs` and each `k_vecs`.
```cpp ```cpp
q_vecs = ... q_vecs = ...
for ... { for ... {
k_ptr = ... k_ptr = ...
for ... { for ... {
k_vecs[i] = ... k_vecs[i] = ...
} }
... ...
float qk = scale * Qk_dot<scalar_t, THREAD_GROUP_SIZE>::dot(q_vecs[thread_group_offset], k_vecs); float qk = scale * Qk_dot<scalar_t, THREAD_GROUP_SIZE>::dot(q_vecs[thread_group_offset], k_vecs);
} }
``` ```
- As mentioned before, for each thread, it only fetches part of the As mentioned before, for each thread, it only fetches part of the
query and key token data at a time. However, there will be a cross query and key token data at a time. However, there will be a cross
thread group reduction happen in the `Qk_dot<>::dot` . So `qk` thread group reduction happen in the `Qk_dot<>::dot` . So `qk`
returned here is not just between part of the query and key token dot returned here is not just between part of the query and key token dot
multiplication, but actually a full result between entire query and multiplication, but actually a full result between entire query and
key token data. key token data.
- For example, if the value of `HEAD_SIZE` is 128 and For example, if the value of `HEAD_SIZE` is 128 and
`THREAD_GROUP_SIZE` is 2, each thread's `k_vecs` will contain `THREAD_GROUP_SIZE` is 2, each thread's `k_vecs` will contain
total 64 elements. However, the returned `qk` is actually the total 64 elements. However, the returned `qk` is actually the
result of dot multiplication between 128 query elements and 128 key result of dot multiplication between 128 query elements and 128 key
elements. If you want to learn more about the details of the dot elements. If you want to learn more about the details of the dot
multiplication and reduction, you may refer to the implementation of multiplication and reduction, you may refer to the implementation of
`Qk_dot<>::dot`. However, for the sake of simplicity, I will not `Qk_dot<>::dot`. However, for the sake of simplicity, I will not
cover it in this document. cover it in this document.
## Softmax ## Softmax
- Next, we need to calculate the normalized softmax for all `qk`s, Next, we need to calculate the normalized softmax for all `qk`s,
as shown above, where each $x$ represents a `qk`. To do this, as shown above, where each $x$ represents a `qk`. To do this,
we must obtain the reduced value of `qk_max`($m(x)$) and we must obtain the reduced value of `qk_max`($m(x)$) and
the `exp_sum`($\ell(x)$) of all `qk`s. The reduction the `exp_sum`($\ell(x)$) of all `qk`s. The reduction
should be performed across the entire thread block, encompassing should be performed across the entire thread block, encompassing
results between the query token and all context key tokens. results between the query token and all context key tokens.
$$ $$
\begin{gather*} \begin{gather*}
m(x):=\max _i \quad x_i \\ \quad f(x):=\left[\begin{array}{lll}e^{x_1-m(x)} & \ldots & e^{x_B-m(x)}\end{array}\right]\\ \quad \ell(x):=\sum_i f(x)_i \\ m(x):=\max _i \quad x_i \\ \quad f(x):=\left[\begin{array}{lll}e^{x_1-m(x)} & \ldots & e^{x_B-m(x)}\end{array}\right]\\ \quad \ell(x):=\sum_i f(x)_i \\
\quad \operatorname{softmax}(x):=\frac{f(x)}{\ell(x)} \quad \operatorname{softmax}(x):=\frac{f(x)}{\ell(x)}
\end{gather*} \end{gather*}
$$ $$
### `qk_max` and `logits` ### `qk_max` and `logits`
- Just right after we get the `qk` result, we can set the temporary Just right after we get the `qk` result, we can set the temporary
`logits` result with `qk` (In the end, the `logits` should `logits` result with `qk` (In the end, the `logits` should
store the normalized softmax result). Also we can compare and collect store the normalized softmax result). Also we can compare and collect
the `qk_max` for all `qk`s that are calculated by current the `qk_max` for all `qk`s that are calculated by current
thread group. thread group.
```cpp ```cpp
if (thread_group_offset == 0) { if (thread_group_offset == 0) {
const bool mask = token_idx >= context_len; const bool mask = token_idx >= context_len;
logits[token_idx - start_token_idx] = mask ? 0.f : qk; logits[token_idx - start_token_idx] = mask ? 0.f : qk;
qk_max = mask ? qk_max : fmaxf(qk_max, qk); qk_max = mask ? qk_max : fmaxf(qk_max, qk);
} }
``` ```
- Please note that the `logits` here is on shared memory, so each Please note that the `logits` here is on shared memory, so each
thread group will set the fields for its own assigned context tokens. thread group will set the fields for its own assigned context tokens.
Overall, the size of logits should be number of context tokens. Overall, the size of logits should be number of context tokens.
```cpp ```cpp
for (int mask = WARP_SIZE / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) { for (int mask = WARP_SIZE / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) {
qk_max = fmaxf(qk_max, VLLM_SHFL_XOR_SYNC(qk_max, mask)); qk_max = fmaxf(qk_max, VLLM_SHFL_XOR_SYNC(qk_max, mask));
} }
if (lane == 0) { if (lane == 0) {
red_smem[warp_idx] = qk_max; red_smem[warp_idx] = qk_max;
} }
``` ```
- Then we need to get the reduced `qk_max` across each warp. The main Then we need to get the reduced `qk_max` across each warp. The main
idea is to make threads in warp to communicate with each other and idea is to make threads in warp to communicate with each other and
get the final max `qk` . get the final max `qk` .
```cpp ```cpp
for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) { for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
qk_max = fmaxf(qk_max, VLLM_SHFL_XOR_SYNC(qk_max, mask)); qk_max = fmaxf(qk_max, VLLM_SHFL_XOR_SYNC(qk_max, mask));
} }
qk_max = VLLM_SHFL_SYNC(qk_max, 0); qk_max = VLLM_SHFL_SYNC(qk_max, 0);
``` ```
- Finally, we can get the reduced `qk_max` from whole thread block by Finally, we can get the reduced `qk_max` from whole thread block by
compare the `qk_max` from all warps in this thread block. Then we compare the `qk_max` from all warps in this thread block. Then we
need to broadcast the final result to each thread. need to broadcast the final result to each thread.
### `exp_sum` ### `exp_sum`
- Similar to `qk_max`, we need to get the reduced sum value from the Similar to `qk_max`, we need to get the reduced sum value from the
entire thread block too. entire thread block too.
```cpp ```cpp
for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) { for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
float val = __expf(logits[i] - qk_max); float val = __expf(logits[i] - qk_max);
logits[i] = val; logits[i] = val;
exp_sum += val; exp_sum += val;
} }
... ...
exp_sum = block_sum<NUM_WARPS>(&red_smem[NUM_WARPS], exp_sum); exp_sum = block_sum<NUM_WARPS>(&red_smem[NUM_WARPS], exp_sum);
``` ```
- Firstly, sum all exp values from each thread group, and meanwhile, Firstly, sum all exp values from each thread group, and meanwhile,
convert each entry of `logits` from `qk` to `exp(qk - qk_max)`. convert each entry of `logits` from `qk` to `exp(qk - qk_max)`.
Please note, the `qk_max` here is already the max `qk` across the Please note, the `qk_max` here is already the max `qk` across the
whole thread block. And then we can do reduction for `exp_sum` whole thread block. And then we can do reduction for `exp_sum`
across whole thread block just like the `qk_max`. across whole thread block just like the `qk_max`.
```cpp ```cpp
const float inv_sum = __fdividef(1.f, exp_sum + 1e-6f); const float inv_sum = __fdividef(1.f, exp_sum + 1e-6f);
for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) { for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
logits[i] *= inv_sum; logits[i] *= inv_sum;
} }
``` ```
- Finally, with the reduced `qk_max` and `exp_sum`, we can obtain Finally, with the reduced `qk_max` and `exp_sum`, we can obtain
the final normalized softmax result as `logits`. This `logits` the final normalized softmax result as `logits`. This `logits`
variable will be used for dot multiplication with the value data in variable will be used for dot multiplication with the value data in
later steps. Now, it should store the normalized softmax result of later steps. Now, it should store the normalized softmax result of
`qk` for all assigned context tokens. `qk` for all assigned context tokens.
## Value ## Value
...@@ -374,127 +372,127 @@ title: vLLM Paged Attention ...@@ -374,127 +372,127 @@ title: vLLM Paged Attention
![](../../assets/kernel/v_vec.png){ align="center" alt="v_vec" width="70%" } ![](../../assets/kernel/v_vec.png){ align="center" alt="v_vec" width="70%" }
</figure> </figure>
- Now we need to retrieve the value data and perform dot multiplication Now we need to retrieve the value data and perform dot multiplication
with `logits`. Unlike query and key, there is no thread group with `logits`. Unlike query and key, there is no thread group
concept for value data. As shown in diagram, different from key token concept for value data. As shown in diagram, different from key token
memory layout, elements from the same column correspond to the same memory layout, elements from the same column correspond to the same
value token. For one block of value data, there are `HEAD_SIZE` of value token. For one block of value data, there are `HEAD_SIZE` of
rows and `BLOCK_SIZE` of columns that are split into multiple rows and `BLOCK_SIZE` of columns that are split into multiple
`v_vecs`. `v_vecs`.
- Each thread always fetches `V_VEC_SIZE` elements from the same Each thread always fetches `V_VEC_SIZE` elements from the same
`V_VEC_SIZE` of tokens at a time. As a result, a single thread `V_VEC_SIZE` of tokens at a time. As a result, a single thread
retrieves multiple `v_vec`s from different rows and the same retrieves multiple `v_vec`s from different rows and the same
columns through multiple inner iterations. For each `v_vec`, it columns through multiple inner iterations. For each `v_vec`, it
needs to be dot multiplied with the corresponding `logits_vec`, needs to be dot multiplied with the corresponding `logits_vec`,
which is also `V_VEC_SIZE` elements from `logits`. Overall, with which is also `V_VEC_SIZE` elements from `logits`. Overall, with
multiple inner iterations, each warp will process one block of value multiple inner iterations, each warp will process one block of value
tokens. And with multiple outer iterations, the whole context value tokens. And with multiple outer iterations, the whole context value
tokens are processed tokens are processed
```cpp ```cpp
float accs[NUM_ROWS_PER_THREAD]; float accs[NUM_ROWS_PER_THREAD];
for ... { // Iteration over different blocks. for ... { // Iteration over different blocks.
logits_vec = ... logits_vec = ...
for ... { // Iteration over different rows. for ... { // Iteration over different rows.
v_vec = ... v_vec = ...
... ...
accs[i] += dot(logits_vec, v_vec); accs[i] += dot(logits_vec, v_vec);
} }
} }
``` ```
- As shown in the above pseudo code, in the outer loop, similar to As shown in the above pseudo code, in the outer loop, similar to
`k_ptr`, `logits_vec` iterates over different blocks and reads `k_ptr`, `logits_vec` iterates over different blocks and reads
`V_VEC_SIZE` elements from `logits`. In the inner loop, each `V_VEC_SIZE` elements from `logits`. In the inner loop, each
thread reads `V_VEC_SIZE` elements from the same tokens as a thread reads `V_VEC_SIZE` elements from the same tokens as a
`v_vec` and performs dot multiplication. It is important to note `v_vec` and performs dot multiplication. It is important to note
that in each inner iteration, the thread fetches different head that in each inner iteration, the thread fetches different head
position elements for the same tokens. The dot result is then position elements for the same tokens. The dot result is then
accumulated in `accs`. Therefore, each entry of `accs` is mapped accumulated in `accs`. Therefore, each entry of `accs` is mapped
to a head position assigned to the current thread. to a head position assigned to the current thread.
- For example, if `BLOCK_SIZE` is 16 and `V_VEC_SIZE` is 8, each For example, if `BLOCK_SIZE` is 16 and `V_VEC_SIZE` is 8, each
thread fetches 8 value elements for 8 tokens at a time. Each element thread fetches 8 value elements for 8 tokens at a time. Each element
is from different tokens at the same head position. If `HEAD_SIZE` is from different tokens at the same head position. If `HEAD_SIZE`
is 128 and `WARP_SIZE` is 32, for each inner loop, a warp needs to is 128 and `WARP_SIZE` is 32, for each inner loop, a warp needs to
fetch `WARP_SIZE * V_VEC_SIZE = 256` elements. This means there are fetch `WARP_SIZE * V_VEC_SIZE = 256` elements. This means there are
a total of 128 * 16 / 256 = 8 inner iterations for a warp to handle a total of 128 * 16 / 256 = 8 inner iterations for a warp to handle
a whole block of value tokens. And each `accs` in each thread a whole block of value tokens. And each `accs` in each thread
contains 8 elements that accumulated at 8 different head positions. contains 8 elements that accumulated at 8 different head positions.
For the thread 0, the `accs` variable will have 8 elements, which For the thread 0, the `accs` variable will have 8 elements, which
are 0th, 32th … 224th elements of a value head that are accumulated are 0th, 32th … 224th elements of a value head that are accumulated
from all assigned 8 tokens. from all assigned 8 tokens.
## LV ## LV
- Now, we need to perform reduction for `accs` within each warp. This Now, we need to perform reduction for `accs` within each warp. This
process allows each thread to accumulate the `accs` for the process allows each thread to accumulate the `accs` for the
assigned head positions of all tokens in one block. assigned head positions of all tokens in one block.
```cpp ```cpp
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
float acc = accs[i]; float acc = accs[i];
for (int mask = NUM_V_VECS_PER_ROW / 2; mask >= 1; mask /= 2) { for (int mask = NUM_V_VECS_PER_ROW / 2; mask >= 1; mask /= 2) {
acc += VLLM_SHFL_XOR_SYNC(acc, mask); acc += VLLM_SHFL_XOR_SYNC(acc, mask);
} }
accs[i] = acc; accs[i] = acc;
} }
``` ```
- Next, we perform reduction for `accs` across all warps, allowing Next, we perform reduction for `accs` across all warps, allowing
each thread to have the accumulation of `accs` for the assigned each thread to have the accumulation of `accs` for the assigned
head positions of all context tokens. Please note that each `accs` head positions of all context tokens. Please note that each `accs`
in every thread only stores the accumulation for a portion of in every thread only stores the accumulation for a portion of
elements of the entire head for all context tokens. However, overall, elements of the entire head for all context tokens. However, overall,
all results for output have been calculated but are just stored in all results for output have been calculated but are just stored in
different thread register memory. different thread register memory.
```cpp ```cpp
float* out_smem = reinterpret_cast<float*>(shared_mem); float* out_smem = reinterpret_cast<float*>(shared_mem);
for (int i = NUM_WARPS; i > 1; i /= 2) { for (int i = NUM_WARPS; i > 1; i /= 2) {
// Upper warps write to shared memory. // Upper warps write to shared memory.
... ...
float* dst = &out_smem[(warp_idx - mid) * HEAD_SIZE]; float* dst = &out_smem[(warp_idx - mid) * HEAD_SIZE];
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
... ...
dst[row_idx] = accs[i]; dst[row_idx] = accs[i];
} }
// Lower warps update the output. // Lower warps update the output.
const float* src = &out_smem[warp_idx * HEAD_SIZE]; const float* src = &out_smem[warp_idx * HEAD_SIZE];
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
... ...
accs[i] += src[row_idx]; accs[i] += src[row_idx];
} }
// Write out the accs. // Write out the accs.
} }
``` ```
## Output ## Output
- Now we can write all of calculated result from local register memory Now we can write all of calculated result from local register memory
to final output global memory. to final output global memory.
```cpp ```cpp
scalar_t* out_ptr = out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE scalar_t* out_ptr = out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE
+ head_idx * max_num_partitions * HEAD_SIZE + head_idx * max_num_partitions * HEAD_SIZE
+ partition_idx * HEAD_SIZE; + partition_idx * HEAD_SIZE;
``` ```
- First, we need to define the `out_ptr` variable, which points to First, we need to define the `out_ptr` variable, which points to
the start address of the assigned sequence and assigned head. the start address of the assigned sequence and assigned head.
```cpp ```cpp
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) { if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
from_float(*(out_ptr + row_idx), accs[i]); from_float(*(out_ptr + row_idx), accs[i]);
} }
} }
``` ```
- Finally, we need to iterate over different assigned head positions Finally, we need to iterate over different assigned head positions
and write out the corresponding accumulated result based on the and write out the corresponding accumulated result based on the
`out_ptr`. `out_ptr`.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment