OpenDAS / vllm_cscc, commit ef1dd687 (Unverified)
Authored May 24, 2025 by Cyrus Leung; committed by GitHub on May 24, 2025
[Doc] Fix indentation problems in V0 Paged Attention docs (#18659)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Parent: e77dc4ba

Showing 2 changed files with 372 additions and 373 deletions:

docs/deployment/k8s.md: +1 -0
docs/design/kernel/paged_attention.md: +371 -373
docs/deployment/k8s.md @ ef1dd687

@@ -9,6 +9,7 @@ Deploying vLLM on Kubernetes is a scalable and efficient way to serve machine le

* [Deployment with GPUs](#deployment-with-gpus)

Alternatively, you can deploy vLLM to Kubernetes using any of the following:

* [Helm](frameworks/helm.md)
* [InftyAI/llmaz](integrations/llmaz.md)
* [KServe](integrations/kserve.md)

...
docs/design/kernel/paged_attention.md @ ef1dd687

@@ -3,78 +3,76 @@ title: vLLM Paged Attention

---

[](){ #design-paged-attention }
Currently, vLLM uses its own implementation of a multi-head query attention kernel (`csrc/attention/attention_kernels.cu`). This kernel is designed to be compatible with vLLM's paged KV caches, where the key and value caches are stored in separate blocks. Note that this block concept differs from the GPU thread block, so in the rest of this document I will refer to a vLLM paged attention block as a "block" and to a GPU thread block as a "thread block".
To achieve high performance, this kernel relies on a specially designed memory layout and access method, specifically when threads read data from global memory to shared memory. The purpose of this document is to explain the kernel implementation step by step at a high level, for those who wish to learn about the vLLM multi-head query attention kernel. After going through this document, readers should find the actual implementation easier to follow.

Please note that this document may not cover all details, such as how to calculate the correct index for the corresponding data or how the dot product is implemented. However, once you are familiar with the high-level logic flow, it should be easier to read the actual code and understand the details.
## Inputs

The kernel function takes a list of arguments for the current thread to perform its assigned work. The three most important arguments are the input pointers `q`, `k_cache`, and `v_cache`, which point to the query, key, and value data in global memory that need to be read and processed. The output pointer `out` points to the global memory where the result should be written. These four pointers actually refer to multi-dimensional arrays, but each thread only accesses the portion of data assigned to it. I have omitted all other runtime parameters here for simplicity.
```cpp
template<
typename scalar_t,
int HEAD_SIZE,
int BLOCK_SIZE,
int NUM_THREADS,
int PARTITION_SIZE = 0>
__device__ void paged_attention_kernel(
... // Other side args.
scalar_t* __restrict__ out,             // [num_seqs, num_heads, max_num_partitions, head_size]
const scalar_t* __restrict__ q,         // [num_seqs, num_heads, head_size]
const scalar_t* __restrict__ k_cache,   // [num_blocks, num_kv_heads, head_size/x, block_size, x]
const scalar_t* __restrict__ v_cache,   // [num_blocks, num_kv_heads, head_size, block_size]
... // Other side args.
)
```
There is also a list of template arguments above the function signature that are determined at compile time. `scalar_t` represents the data type of the query, key, and value data elements, such as FP16. `HEAD_SIZE` indicates the number of elements in each head. `BLOCK_SIZE` refers to the number of tokens in each block. `NUM_THREADS` denotes the number of threads in each thread block. `PARTITION_SIZE` is the number of context tokens handled per partition when a sequence is split into partitions (for simplicity, we assume this is 0, which disables partitioning).
With these arguments, we need to perform a sequence of preparations. This includes calculating the current head index, block index, and other necessary variables. However, for now, we can ignore these preparations and proceed directly to the actual calculations. It will be easier to understand them once we grasp the entire flow.
## Concepts

Before we dive into the calculation flow, I want to describe a few concepts that are needed for later sections. However, you may skip this section and return later if you encounter any confusing terminology.
- **Sequence**: A sequence represents a client request. For example, the data pointed to by `q` has a shape of `[num_seqs, num_heads, head_size]`. That represents there are total
...

@@ -129,236 +127,236 @@ title: vLLM Paged Attention

## Query

This section introduces how query data is stored in memory and fetched by each thread. As mentioned above, each thread group fetches the data of one query token, while each thread itself only handles a part of that data. Within each warp, every thread group fetches the same query token data, but multiplies it with different key token data.
```cpp
const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE;
```

<figure markdown="span">
  ![](../../assets/kernel/query.png){ align="center" alt="query" width="70%" }
</figure>
Each thread defines its own `q_ptr`, which points to the assigned query token data in global memory. For example, if `VEC_SIZE` is 4 and `HEAD_SIZE` is 128, `q_ptr` points to data that contains a total of 128 elements divided into 128 / 4 = 32 vecs.
<figure markdown="span">
  ![](../../assets/kernel/q_vecs.png){ align="center" alt="q_vecs" width="70%" }
</figure>

```cpp
__shared__ Q_vec q_vecs[THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD];
```
Next, we need to read the global memory data pointed to by `q_ptr` into shared memory as `q_vecs`. It is important to note that each thread's vecs are assigned to a different row. For example, if `THREAD_GROUP_SIZE` is 2, thread 0 handles the vecs in row 0, while thread 1 handles the vecs in row 1. By reading the query data in this way, neighboring threads such as thread 0 and thread 1 read neighboring memory, achieving memory coalescing to improve performance.
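The query addressing and the per-thread vec assignment can be sketched on the host as follows. This is only an illustration: `query_offset` and `vecs_for_thread` are hypothetical helpers that do not exist in vLLM, and the round-robin assignment (thread `t` takes vecs `t`, `t + THREAD_GROUP_SIZE`, and so on) is an assumption inferred from the description of the read pattern.

```cpp
#include <cassert>
#include <vector>

// Hypothetical helper: flat element offset of the query data for one
// (sequence, head) pair, mirroring q + seq_idx * q_stride + head_idx * HEAD_SIZE.
inline int query_offset(int seq_idx, int head_idx, int q_stride, int head_size) {
    return seq_idx * q_stride + head_idx * head_size;
}

// Hypothetical helper: the vec indices read by one thread of a thread group.
// Assumes vecs are dealt out round-robin so that neighboring threads touch
// neighboring vecs (which is what makes the reads coalesce).
inline std::vector<int> vecs_for_thread(int thread_group_offset, int thread_group_size,
                                        int head_size, int vec_size) {
    assert(vec_size > 0 && head_size % vec_size == 0);
    const int num_vecs = head_size / vec_size;
    std::vector<int> vecs;
    for (int v = thread_group_offset; v < num_vecs; v += thread_group_size) {
        vecs.push_back(v);
    }
    return vecs;
}
```

With `HEAD_SIZE` 128, `VEC_SIZE` 4, and `THREAD_GROUP_SIZE` 2, each thread ends up with 16 of the 32 vecs: thread 0 gets the even-numbered vecs and thread 1 the odd-numbered ones.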
## Key

Similar to the "Query" section, this section introduces the memory layout and assignment for keys. While each thread group handles only one query token per kernel run, it may handle multiple key tokens across multiple iterations. Meanwhile, each warp processes multiple blocks of key tokens in multiple iterations, ensuring that all context tokens are processed by the entire thread group after the kernel run. In this context, "handle" refers to performing the dot product between query data and key data.
```cpp
const scalar_t* k_ptr = k_cache + physical_block_number * kv_block_stride
                    + kv_head_idx * kv_head_stride
                    + physical_block_offset * x;
```
Unlike `q_ptr`, the `k_ptr` in each thread points to a different key token at each iteration. As shown above, `k_ptr` points to key token data in `k_cache` at the assigned block, assigned head, and assigned token.
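To make the addressing concrete, here is a host-side sketch of where a single key element lives. It assumes a contiguous `k_cache` with the shape given in the kernel signature, `[num_blocks, num_kv_heads, head_size/x, block_size, x]`, and derives the strides from that shape; the real kernel receives its strides as arguments, and `key_element_offset` is a hypothetical helper rather than vLLM code.

```cpp
#include <cassert>

// Flat offset of element d (0 <= d < head_size) of one key token, assuming a
// contiguous cache of shape [num_blocks, num_kv_heads, head_size / x, block_size, x].
inline long key_element_offset(int block, int kv_head, int token_in_block, int d,
                               int num_kv_heads, int head_size, int block_size, int x) {
    assert(x > 0 && head_size % x == 0);
    assert(d >= 0 && d < head_size);
    assert(token_in_block >= 0 && token_in_block < block_size);
    const long kv_head_stride = static_cast<long>(head_size) * block_size;
    const long kv_block_stride = kv_head_stride * num_kv_heads;
    // Base of the token, matching the k_ptr expression.
    const long base = block * kv_block_stride + kv_head * kv_head_stride
                    + static_cast<long>(token_in_block) * x;
    // Within one token, elements come in chunks of x; consecutive chunks of the
    // same token are block_size * x elements apart.
    return base + static_cast<long>(d / x) * block_size * x + d % x;
}
```

The consequence of this layout is that the `x` elements of one chunk are contiguous for a token, and the same chunk of the next token in the block follows immediately after.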
<figure markdown="span">
  ![](../../assets/kernel/key.png){ align="center" alt="key" width="70%" }
</figure>
The diagram above illustrates the memory layout for key data. It assumes that `BLOCK_SIZE` is 16, `HEAD_SIZE` is 128, `x` is 8, `THREAD_GROUP_SIZE` is 2, and there are a total of 4 warps. Each rectangle represents all the elements for one key token at one head, which will be processed by one thread group. The left half shows the total 16 blocks of key token data for warp 0, while the right half represents the remaining key token data for other warps or iterations. Inside each rectangle, there are a total of 32 vecs (128 elements for one token) that will be processed by 2 threads (one thread group) separately.
<figure markdown="span">
  ![](../../assets/kernel/k_vecs.png){ align="center" alt="k_vecs" width="70%" }
</figure>

```cpp
K_vec k_vecs[NUM_VECS_PER_THREAD];
```
Next, we need to read the key token data from `k_ptr` and store it in register memory as `k_vecs`. We use register memory for `k_vecs` because it will only be accessed by one thread once, whereas `q_vecs` will be accessed by multiple threads multiple times. Each `k_vecs` contains multiple vectors for later calculation, and each vec is set in one inner iteration. The assignment of vecs allows neighboring threads in a warp to read neighboring memory together, which again promotes memory coalescing. For instance, thread 0 reads vec 0, while thread 1 reads vec 1. In the next inner loop, thread 0 reads vec 2, while thread 1 reads vec 3, and so on.
You may still be a little confused about the overall flow. Don't worry, please keep reading the next "QK" section. It illustrates the query and key calculation flow in a clearer and higher-level manner.
## QK

As shown in the pseudocode below, before the entire for loop block, we fetch the query data for one token and store it in `q_vecs`. Then, in the outer for loop, we iterate through different `k_ptrs` that point to different tokens and prepare the `k_vecs` in the inner for loop. Finally, we perform the dot product between the `q_vecs` and each `k_vecs`.
```cpp
q_vecs = ...
for ... {
    k_ptr = ...
    for ... {
        k_vecs[i] = ...
    }
    ...
    float qk = scale * Qk_dot<scalar_t, THREAD_GROUP_SIZE>::dot(q_vecs[thread_group_offset], k_vecs);
}
```
As mentioned before, each thread only fetches part of the query and key token data at a time. However, a reduction across the thread group happens inside `Qk_dot<>::dot`. So the `qk` returned here is not just the dot product between parts of the query and key token data, but the full result between the entire query and key token data.
For example, if the value of `HEAD_SIZE` is 128 and `THREAD_GROUP_SIZE` is 2, each thread's `k_vecs` will contain a total of 64 elements. However, the returned `qk` is actually the result of the dot product between 128 query elements and 128 key elements. If you want to learn more about the details of the dot product and reduction, you may refer to the implementation of `Qk_dot<>::dot`. However, for the sake of simplicity, I will not cover it in this document.
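The idea of "partial dot product per thread, then a reduction across the thread group" can be emulated on the host. The sketch below is not the real `Qk_dot<>::dot`; `group_dot` is a hypothetical function that models each thread as a slot in a vector and replaces the warp shuffle with an explicit XOR butterfly. It assumes the group size is a power of two and that vecs are assigned round-robin.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Host-side emulation: each of group_size "threads" computes a partial dot
// product over its own vecs, then the partials are combined with an XOR
// butterfly, the pattern a shuffle-based reduction follows.
inline std::vector<float> group_dot(const std::vector<float>& q, const std::vector<float>& k,
                                    int group_size, int vec_size) {
    assert(group_size > 0 && vec_size > 0);
    assert(q.size() == k.size());
    assert(q.size() % static_cast<std::size_t>(group_size * vec_size) == 0);
    const int num_vecs = static_cast<int>(q.size()) / vec_size;
    std::vector<float> partial(group_size, 0.0f);
    for (int t = 0; t < group_size; ++t) {
        for (int v = t; v < num_vecs; v += group_size) {
            for (int e = 0; e < vec_size; ++e) {
                partial[t] += q[v * vec_size + e] * k[v * vec_size + e];
            }
        }
    }
    // After the step with a given mask, thread t also holds the contribution
    // of thread t ^ mask.
    for (int mask = group_size / 2; mask >= 1; mask /= 2) {
        std::vector<float> next(partial);
        for (int t = 0; t < group_size; ++t) {
            next[t] = partial[t] + partial[t ^ mask];
        }
        partial = next;
    }
    return partial;  // every thread of the group now holds the full dot product
}
```

With 128 elements and a group of 2, each slot first accumulates 64 products, and after one butterfly step both slots hold the full 128-element result.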
## Softmax

Next, we need to calculate the normalized softmax for all `qk`s, as shown below, where each $x$ represents a `qk`. To do this, we must obtain the reduced value of `qk_max` ($m(x)$) and the `exp_sum` ($\ell(x)$) of all `qk`s. The reduction should be performed across the entire thread block, encompassing results between the query token and all context key tokens.
$$
\begin{gather*}
m(x):=\max _i \quad x_i \\
\quad f(x):=\left[\begin{array}{lll}e^{x_1-m(x)} & \ldots & e^{x_B-m(x)}\end{array}\right] \\
\quad \ell(x):=\sum_i f(x)_i \\
\quad \operatorname{softmax}(x):=\frac{f(x)}{\ell(x)}
\end{gather*}
$$
### `qk_max` and `logits`

Right after we get the `qk` result, we can set the temporary `logits` result with `qk` (in the end, `logits` should store the normalized softmax result). We can also compare and collect the `qk_max` for all `qk`s that are calculated by the current thread group.
```cpp
if (thread_group_offset == 0) {
    const bool mask = token_idx >= context_len;
    logits[token_idx - start_token_idx] = mask ? 0.f : qk;
    qk_max = mask ? qk_max : fmaxf(qk_max, qk);
}
```
Please note that `logits` here is in shared memory, so each thread group will set the fields for its own assigned context tokens. Overall, the size of `logits` should be the number of context tokens.
```cpp
for (int mask = WARP_SIZE / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) {
    qk_max = fmaxf(qk_max, VLLM_SHFL_XOR_SYNC(qk_max, mask));
}

if (lane == 0) {
    red_smem[warp_idx] = qk_max;
}
```
Then we need to get the reduced `qk_max` within each warp. The main idea is to make the threads in a warp communicate with each other and get the final max `qk`.
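The warp-level step can be emulated on the host to see why the loop stops at `THREAD_GROUP_SIZE`. In this sketch, `warp_max` is a hypothetical function, each vector slot stands for one lane, and the shuffle is modeled as reading the slot at `lane ^ mask`. It assumes a thread group consists of consecutive lanes, so that only lanes whose index is a multiple of `THREAD_GROUP_SIZE` carry a meaningful `qk_max`.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Host-side emulation of the XOR-shuffle max reduction inside one warp.
// lanes[i] is the qk_max held by lane i before the reduction; the vector
// length plays the role of WARP_SIZE and must be a power of two.
inline std::vector<float> warp_max(std::vector<float> lanes, int thread_group_size) {
    assert(thread_group_size >= 1);
    const int warp_size = static_cast<int>(lanes.size());
    for (int mask = warp_size / 2; mask >= thread_group_size; mask /= 2) {
        std::vector<float> next(lanes);
        for (int lane = 0; lane < warp_size; ++lane) {
            next[lane] = std::max(lanes[lane], lanes[lane ^ mask]);
        }
        lanes = next;
    }
    return lanes;
}
```

Because every mask is at least `THREAD_GROUP_SIZE`, the XOR never changes a lane's offset within its thread group, so lane 0 only ever combines values from the other group leaders. That is sufficient here, since only those lanes updated `qk_max`.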
```cpp
for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
    qk_max = fmaxf(qk_max, VLLM_SHFL_XOR_SYNC(qk_max, mask));
}
qk_max = VLLM_SHFL_SYNC(qk_max, 0);
```
Finally, we can get the reduced `qk_max` for the whole thread block by comparing the `qk_max` from all warps in this thread block. Then we need to broadcast the final result to each thread.
### `exp_sum`

Similar to `qk_max`, we need to get the reduced sum value from the entire thread block too.
```cpp
for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
    float val = __expf(logits[i] - qk_max);
    logits[i] = val;
    exp_sum += val;
}
...
exp_sum = block_sum<NUM_WARPS>(&red_smem[NUM_WARPS], exp_sum);
```
First, sum all exp values from each thread group, and meanwhile, convert each entry of `logits` from `qk` to `exp(qk - qk_max)`. Please note that the `qk_max` here is already the max `qk` across the whole thread block. Then we can perform the reduction for `exp_sum` across the whole thread block, just like for `qk_max`.
```cpp
const float inv_sum = __fdividef(1.f, exp_sum + 1e-6f);
for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
    logits[i] *= inv_sum;
}
```
Finally, with the reduced `qk_max` and `exp_sum`, we can obtain the final normalized softmax result as `logits`. This `logits` variable will be used for the dot product with the value data in later steps. Now, it should store the normalized softmax result of `qk` for all assigned context tokens.
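For reference, the three passes (max, exponentiate and sum, scale by the inverse sum) can be written as a sequential host-side function. This is a sketch for checking the arithmetic only: `softmax_reference` is a hypothetical name, the block-wide reductions are replaced by plain loops, and `std::exp` stands in for `__expf`. The `1e-6f` term is taken from the kernel snippet.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <vector>

// Sequential reference for the softmax steps: qk_max, then exp(qk - qk_max)
// with a running exp_sum, then multiplication by the inverse sum.
inline std::vector<float> softmax_reference(std::vector<float> logits) {
    assert(!logits.empty());
    const float qk_max = *std::max_element(logits.begin(), logits.end());
    float exp_sum = 0.0f;
    for (float& x : logits) {
        x = std::exp(x - qk_max);  // subtracting the max keeps exp() from overflowing
        exp_sum += x;
    }
    const float inv_sum = 1.0f / (exp_sum + 1e-6f);
    for (float& x : logits) {
        x *= inv_sum;
    }
    return logits;
}
```

Subtracting `qk_max` does not change the result mathematically, but it means the largest exponent is 0, so large `qk` values do not overflow.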
## Value

...

@@ -374,85 +372,85 @@ title: vLLM Paged Attention

  ![](../../assets/kernel/v_vec.png){ align="center" alt="v_vec" width="70%" }
</figure>
Now we need to retrieve the value data and perform the dot product with `logits`. Unlike query and key, there is no thread group concept for value data. As shown in the diagram, unlike the key token memory layout, elements from the same column correspond to the same value token. For one block of value data, there are `HEAD_SIZE` rows and `BLOCK_SIZE` columns that are split into multiple `v_vecs`.
Each thread always fetches `V_VEC_SIZE` elements from the same `V_VEC_SIZE` tokens at a time. As a result, a single thread retrieves multiple `v_vec`s from different rows and the same columns through multiple inner iterations. Each `v_vec` needs to be dot multiplied with the corresponding `logits_vec`, which is also `V_VEC_SIZE` elements from `logits`. Overall, with multiple inner iterations, each warp processes one block of value tokens. And with multiple outer iterations, the whole context value tokens are processed.
```cpp
float accs[NUM_ROWS_PER_THREAD];
for ... { // Iteration over different blocks.
    logits_vec = ...
    for ... { // Iteration over different rows.
        v_vec = ...
        ...
        accs[i] += dot(logits_vec, v_vec);
    }
}
```
As shown in the above pseudocode, in the outer loop, similar to `k_ptr`, `logits_vec` iterates over different blocks and reads `V_VEC_SIZE` elements from `logits`. In the inner loop, each thread reads `V_VEC_SIZE` elements from the same tokens as a `v_vec` and performs the dot product. It is important to note that in each inner iteration, the thread fetches different head position elements for the same tokens. The dot result is then accumulated in `accs`. Therefore, each entry of `accs` is mapped to a head position assigned to the current thread.
For example, if `BLOCK_SIZE` is 16 and `V_VEC_SIZE` is 8, each thread fetches 8 value elements for 8 tokens at a time. Each element is from a different token at the same head position. If `HEAD_SIZE` is 128 and `WARP_SIZE` is 32, for each inner loop, a warp needs to fetch `WARP_SIZE * V_VEC_SIZE = 256` elements. This means there are a total of 128 * 16 / 256 = 8 inner iterations for a warp to handle a whole block of value tokens. Each `accs` in each thread contains 8 elements accumulated at 8 different head positions. Since one inner iteration covers 256 / 16 = 16 rows, for thread 0 the `accs` variable will have 8 elements, which are the 0th, 16th, …, 112th elements of a value head, accumulated from all 8 assigned tokens.
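The iteration counts in this example follow from a few divisions, which the sketch below spells out. `ValueTiling` and `value_tiling` are hypothetical names used only for this illustration; the field names are chosen to echo the constants mentioned in the text, not copied from the kernel.

```cpp
#include <cassert>

// Hypothetical bookkeeping for how one warp tiles one block of value data.
struct ValueTiling {
    int num_v_vecs_per_row;   // v_vecs needed to cover one row (one head position) of a block
    int elems_per_warp_iter;  // elements a warp fetches in one inner iteration
    int rows_per_iter;        // rows of the block covered by one inner iteration
    int inner_iters;          // inner iterations needed to cover the whole block
};

inline ValueTiling value_tiling(int head_size, int block_size, int v_vec_size, int warp_size) {
    assert(v_vec_size > 0 && block_size % v_vec_size == 0);
    ValueTiling t{};
    t.num_v_vecs_per_row = block_size / v_vec_size;
    t.elems_per_warp_iter = warp_size * v_vec_size;
    assert(t.elems_per_warp_iter % block_size == 0);
    t.rows_per_iter = t.elems_per_warp_iter / block_size;
    assert(head_size % t.rows_per_iter == 0);
    t.inner_iters = head_size / t.rows_per_iter;
    return t;
}
```

For `HEAD_SIZE` 128, `BLOCK_SIZE` 16, `V_VEC_SIZE` 8, and `WARP_SIZE` 32, this gives 2 vecs per row, 256 elements and 16 rows per inner iteration, and 8 inner iterations, which is also the number of entries each thread keeps in `accs`.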
## LV
Now, we need to perform reduction for `accs` within each warp. This
process allows each thread to accumulate the `accs` for the
assigned head positions of all tokens in one block.

```cpp
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
    float acc = accs[i];
    for (int mask = NUM_V_VECS_PER_ROW / 2; mask >= 1; mask /= 2) {
        acc += VLLM_SHFL_XOR_SYNC(acc, mask);
    }
    accs[i] = acc;
}
```

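
To see why the XOR loop leaves every lane of a group holding the group total, the exchange can be imitated on the host. This is a rough sketch, not the kernel's code: `shfl_xor` and `reduce_groups` are hypothetical stand-ins that only mimic the data movement of a warp XOR shuffle, and `width` plays the role of `NUM_V_VECS_PER_ROW` (assumed to be a power of two).

```cpp
#include <array>
#include <cassert>

constexpr int WARP_SIZE = 32;
using Warp = std::array<float, WARP_SIZE>;

// Host-side stand-in for one XOR shuffle round: every lane reads the value
// currently held by lane (lane ^ mask).
Warp shfl_xor(const Warp& v, int mask) {
  Warp out{};
  for (int lane = 0; lane < WARP_SIZE; ++lane) {
    out[lane] = v[lane ^ mask];
  }
  return out;
}

// Butterfly reduction over groups of `width` adjacent lanes, mirroring the
// loop above with width = NUM_V_VECS_PER_ROW.
Warp reduce_groups(Warp acc, int width) {
  assert(width >= 1 && width <= WARP_SIZE && (width & (width - 1)) == 0);
  for (int mask = width / 2; mask >= 1; mask /= 2) {
    const Warp other = shfl_xor(acc, mask);
    for (int lane = 0; lane < WARP_SIZE; ++lane) {
      acc[lane] += other[lane];
    }
  }
  return acc;
}
```

With `width = 2` only `mask = 1` is used, so lanes 0 and 1 swap and add their partial sums; both end up with the total for their shared head position, covering all 16 tokens of the block in the earlier example.
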
Next, we perform reduction for `accs` across all warps, so that
each thread holds the accumulation of `accs` for its assigned
head positions over all context tokens. Note that the `accs`
in each thread only stores the accumulation for a portion of the
elements of the entire head. Taken together, all output results
have been calculated at this point; they are just spread across
the register memory of different threads.

```cpp
float* out_smem = reinterpret_cast<float*>(shared_mem);
for (int i = NUM_WARPS; i > 1; i /= 2) {
    // Upper warps write to shared memory.
    ...
    float* dst = &out_smem[(warp_idx - mid) * HEAD_SIZE];
    ...
    }

    // Write out the accs.
}
```

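
The halving loop above can be read as a tree reduction through a shared scratch buffer. The host-side sketch below illustrates that reading under stated assumptions: `mid = i / 2` and the step in which the lower warps add the stored values are not shown in the snippet and are assumed here, the threads of a warp are flattened into a single row of partial sums, and `reduce_across_warps` is a hypothetical name.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// accs[w][r] stands for warp w's partial sum at head position r.
std::vector<float> reduce_across_warps(std::vector<std::vector<float>> accs) {
  const int num_warps = static_cast<int>(accs.size());
  assert(num_warps >= 1 && (num_warps & (num_warps - 1)) == 0);
  const int head_size = static_cast<int>(accs[0].size());

  // Stand-in for the shared-memory buffer; at most num_warps / 2 slots are used.
  std::vector<float> out_smem(static_cast<std::size_t>(num_warps / 2) * head_size);

  for (int i = num_warps; i > 1; i /= 2) {
    const int mid = i / 2;  // assumed definition of `mid`
    // Upper warps [mid, i) write their partial sums to the scratch buffer.
    for (int w = mid; w < i; ++w) {
      for (int r = 0; r < head_size; ++r) {
        out_smem[(w - mid) * head_size + r] = accs[w][r];
      }
    }
    // On the GPU, a block-wide barrier would separate these two phases.
    // Assumed step: lower warps [0, mid) add the matching slot to their own sums.
    for (int w = 0; w < mid; ++w) {
      for (int r = 0; r < head_size; ++r) {
        accs[w][r] += out_smem[w * head_size + r];
      }
    }
  }
  return accs[0];  // warp 0 ends up holding the sum over all warps
}
```

Each round halves the number of active warps, so with 4 warps the reduction finishes in two rounds and only warp 0's values remain to be written out.
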
## Output
Now we can write all of the calculated results from local register
memory to the final output in global memory.

```cpp
scalar_t* out_ptr = out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE
                + head_idx * max_num_partitions * HEAD_SIZE
                + partition_idx * HEAD_SIZE;
```

First, we need to define the `out_ptr` variable, which points to
the start address of the output for the assigned sequence, head,
and partition.

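
The pointer arithmetic for `out_ptr` amounts to a row-major offset into a four-dimensional output buffer. As a small worked sketch, assuming the layout `[num_seqs, num_heads, max_num_partitions, HEAD_SIZE]` implied by the expression, the hypothetical helper `out_offset` below computes the same element offset:

```cpp
#include <cassert>
#include <cstdint>

// Flat element offset into an output buffer laid out row-major as
// [num_seqs, num_heads, max_num_partitions, head_size].
constexpr std::int64_t out_offset(int seq_idx, int head_idx, int partition_idx,
                                  int num_heads, int max_num_partitions,
                                  int head_size) {
  return static_cast<std::int64_t>(seq_idx) * num_heads * max_num_partitions * head_size
       + static_cast<std::int64_t>(head_idx) * max_num_partitions * head_size
       + static_cast<std::int64_t>(partition_idx) * head_size;
}
```

For instance, with 4 heads, 5 partitions, and a head size of 128, sequence 1, head 2, partition 3 starts at element 1 * 2560 + 2 * 640 + 3 * 128 = 4224.
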
```cpp
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
    const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
    if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
        from_float(*(out_ptr + row_idx), accs[i]);
    }
}
```

Finally, we need to iterate over different assigned head positions
and write out the corresponding accumulated result based on the
`out_ptr`.

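
One way to convince yourself that the guard `lane % NUM_V_VECS_PER_ROW == 0` writes every head position exactly once is to replay the loop on the host for a single warp. The sketch below does that; `count_writes` is a hypothetical helper, and the definitions of `num_rows_per_iter` and `num_rows_per_thread` are assumptions about how the kernel derives those constants.

```cpp
#include <cassert>
#include <vector>

// Counts how many times each head position is written by one warp.
std::vector<int> count_writes(int head_size, int warp_size, int num_v_vecs_per_row) {
  assert(num_v_vecs_per_row >= 1 && warp_size % num_v_vecs_per_row == 0);
  const int num_rows_per_iter = warp_size / num_v_vecs_per_row;  // assumed definition
  const int num_rows_per_thread =
      (head_size + num_rows_per_iter - 1) / num_rows_per_iter;   // assumed definition

  std::vector<int> writes(head_size, 0);
  for (int lane = 0; lane < warp_size; ++lane) {
    for (int i = 0; i < num_rows_per_thread; ++i) {
      const int row_idx = lane / num_v_vecs_per_row + i * num_rows_per_iter;
      // Same guard as the loop above: after the in-warp reduction all lanes of
      // a group hold the same sum, so only the first lane of the group writes.
      if (row_idx < head_size && lane % num_v_vecs_per_row == 0) {
        ++writes[row_idx];
      }
    }
  }
  return writes;
}
```

With `head_size = 128`, `warp_size = 32`, and `num_v_vecs_per_row = 2`, the 16 even lanes each write 8 head positions, which covers all 128 positions without overlap.
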