OpenDAS / text-generation-inference · Commit 12590fdc

feat: paged attention v2 (#1183)

Unverified commit 12590fdc, authored Oct 23, 2023 by OlivierDehaene, committed by GitHub on Oct 23, 2023. Parent: 63fa5346.

Showing 8 changed files with 126 additions and 61 deletions (+126, -61):
server/Makefile-flash-att-v2 (+1, -1)
server/Makefile-vllm (+2, -2)
server/text_generation_server/models/custom_modeling/flash_llama_modeling.py (+4, -11)
server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py (+4, -10)
server/text_generation_server/models/custom_modeling/flash_neox_modeling.py (+4, -10)
server/text_generation_server/models/custom_modeling/flash_rw_modeling.py (+7, -16)
server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py (+4, -11)
server/text_generation_server/utils/paged_attention.py (new file, +100, -0)
server/Makefile-flash-att-v2 (view file @ 12590fdc)

```diff
-flash_att_v2_commit := 601b4dc48dbe9d87c468daa2b4c0c8388b83753c
+flash_att_v2_commit := 02ac572f3ffc4f402e4183aaa6824b45859d3ed3
 
 flash-attention-v2:
 	# Clone flash attention
 	...
```
server/Makefile-vllm (view file @ 12590fdc)

```diff
-vllm_commit := 25dbff97d5a8f2ba331847237b458b2692e9ae78
+vllm_commit := f8a1e39fae05ca610be8d5a78be9d40f5274e5fc
 
 vllm:
 	# Clone vllm
-	git clone https://github.com/OlivierDehaene/vllm.git
+	git clone https://github.com/vllm-project/vllm.git
 
 build-vllm: vllm
 	cd vllm && git fetch && git checkout $(vllm_commit)
 	...
```
server/text_generation_server/models/custom_modeling/flash_llama_modeling.py (view file @ 12590fdc)

```diff
@@ -29,11 +29,7 @@ from typing import Optional, List, Tuple
 # Flash attention imports
 import dropout_layer_norm
 
-# vllm imports
-import vllm_cache_ops
-import vllm_attention_ops
-
-from text_generation_server.utils.flash_attn import attention
+from text_generation_server.utils import paged_attention, flash_attn
 from text_generation_server.utils.layers import (
     TensorParallelRowLinear,
     TensorParallelColumnLinear,
@@ -269,7 +265,7 @@ class FlashLlamaAttention(torch.nn.Module):
         self.rotary_emb(query, cos, sin)
         self.rotary_emb(torch.select(kv, dim=1, index=0), cos, sin)
 
-        vllm_cache_ops.reshape_and_cache(
+        paged_attention.reshape_and_cache(
             kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots
         )
@@ -279,7 +275,7 @@ class FlashLlamaAttention(torch.nn.Module):
         # Prefill
         if cu_seqlen_prefill is not None:
             # flash attention
-            attention(
+            flash_attn.attention(
                 query,
                 torch.select(kv, dim=1, index=0),
                 torch.select(kv, dim=1, index=1),
@@ -290,9 +286,7 @@ class FlashLlamaAttention(torch.nn.Module):
             )
         # Decode
         else:
-            # kv_cache[1] => [num_blocks, num_heads, head_size, block_size]
-            block_size = kv_cache[1].shape[3]
-            vllm_attention_ops.single_query_cached_kv_attention(
+            paged_attention.attention(
                 attn_output,
                 query,
                 kv_cache[0],
@@ -301,7 +295,6 @@ class FlashLlamaAttention(torch.nn.Module):
                 self.softmax_scale,
                 block_tables,
                 input_lengths,
-                block_size,
                 max_s,
             )
```
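All of the model files in this commit apply the same change that the Llama hunks above show: new key/value pairs are written into the paged cache, prefill still goes through flash attention, and decode now goes through the new paged_attention.attention wrapper, which derives block_size on its own. The sketch below reconstructs that control flow in one place. It is a minimal sketch, not the commit's code: the helper name `attend`, the `kv_cache[1]` and `kv_head_mapping` arguments in the decode call, and the prefill arguments hidden behind the diff's elisions are assumptions filled in from the new wrapper's signature.

```python
# Hedged sketch of the post-commit attention forward pass, modeled on the
# FlashLlamaAttention hunks above. Assumes the TGI server package and the
# vllm / flash-attention kernels are installed; names not shown in the diff
# are illustrative.
import torch

from text_generation_server.utils import paged_attention, flash_attn


def attend(query, kv, kv_cache, slots, cu_seqlen_prefill, attn_output,
           kv_head_mapping, softmax_scale, block_tables, input_lengths, max_s):
    # Write the freshly computed K and V into their paged-cache slots.
    paged_attention.reshape_and_cache(
        kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots
    )
    if cu_seqlen_prefill is not None:
        # Prefill: regular flash attention over the prompt tokens.
        flash_attn.attention(
            query,
            torch.select(kv, dim=1, index=0),
            torch.select(kv, dim=1, index=1),
            # ... remaining flash-attention arguments are elided in the diff
        )
    else:
        # Decode: one query token per sequence against the paged KV cache.
        # block_size is no longer passed; the wrapper reads it from the cache.
        paged_attention.attention(
            attn_output,
            query,
            kv_cache[0],
            kv_cache[1],
            kv_head_mapping,
            softmax_scale,
            block_tables,
            input_lengths,
            max_s,
        )
    return attn_output
```

The same pattern repeats, with different tensor layouts, in the Mistral, NeoX, Falcon (RW), and SantaCoder files below.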
server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py (view file @ 12590fdc)

```diff
@@ -29,10 +29,7 @@ from typing import Optional, List, Tuple
 # Flash attention imports
 import dropout_layer_norm
 
-# vllm imports
-import vllm_cache_ops
-import vllm_attention_ops
-
+from text_generation_server.utils import paged_attention, flash_attn
 from text_generation_server.utils.flash_attn import attention, HAS_FLASH_ATTN_V2
 from text_generation_server.utils.layers import (
     TensorParallelRowLinear,
@@ -272,7 +269,7 @@ class MistralAttention(torch.nn.Module):
         else:
             kv_to_cache = kv
 
-        vllm_cache_ops.reshape_and_cache(
+        paged_attention.reshape_and_cache(
             kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots
         )
@@ -282,7 +279,7 @@ class MistralAttention(torch.nn.Module):
         # Prefill
         if cu_seqlen_prefill is not None:
             # flash attention
-            attention(
+            flash_attn.attention(
                 query,
                 torch.select(kv, dim=1, index=0),
                 torch.select(kv, dim=1, index=1),
@@ -294,9 +291,7 @@ class MistralAttention(torch.nn.Module):
             )
         # Decode
         else:
-            # kv_cache[1] => [num_blocks, num_heads, head_size, block_size]
-            block_size = kv_cache[1].shape[3]
-            vllm_attention_ops.single_query_cached_kv_attention(
+            paged_attention.attention(
                 attn_output,
                 query,
                 kv_cache[0],
@@ -305,7 +300,6 @@ class MistralAttention(torch.nn.Module):
                 self.softmax_scale,
                 block_tables,
                 input_lengths,
-                block_size,
                 max_s,
             )
```
server/text_generation_server/models/custom_modeling/flash_neox_modeling.py (view file @ 12590fdc)

```diff
@@ -27,10 +27,7 @@ from transformers.modeling_utils import PreTrainedModel
 from transformers.models.gpt_neox import GPTNeoXConfig
 from typing import Optional, List, Tuple
 
-# vllm imports
-import vllm_cache_ops
-import vllm_attention_ops
-
+from text_generation_server.utils import paged_attention, flash_attn
 from text_generation_server.utils.flash_attn import attention
 from text_generation_server.utils.layers import (
     TensorParallelRowLinear,
@@ -141,7 +138,7 @@ class FlashNeoxAttention(torch.nn.Module):
         self.rotary_emb(qkv[:, 0], cos, sin)
         self.rotary_emb(qkv[:, 1], cos, sin)
 
-        vllm_cache_ops.reshape_and_cache(
+        paged_attention.reshape_and_cache(
             qkv[:, 1], qkv[:, 2], kv_cache[0], kv_cache[1], slots
         )
@@ -151,7 +148,7 @@ class FlashNeoxAttention(torch.nn.Module):
         # Prefill
         if cu_seqlen_prefill is not None:
             # flash attention
-            attention(
+            flash_attn.attention(
                 qkv[:, 0],
                 qkv[:, 1],
                 qkv[:, 2],
@@ -162,9 +159,7 @@ class FlashNeoxAttention(torch.nn.Module):
             )
         # Decode
         else:
-            # kv_cache[1] => [num_blocks, num_heads, head_size, block_size]
-            block_size = kv_cache[1].shape[3]
-            vllm_attention_ops.single_query_cached_kv_attention(
+            paged_attention.attention(
                 attn_output,
                 qkv[:, 0],
                 kv_cache[0],
@@ -173,7 +168,6 @@ class FlashNeoxAttention(torch.nn.Module):
                 self.softmax_scale,
                 block_tables,
                 input_lengths,
-                block_size,
                 max_s,
             )
```
server/text_generation_server/models/custom_modeling/flash_rw_modeling.py (view file @ 12590fdc)

```diff
@@ -6,10 +6,7 @@ from transformers.modeling_utils import PreTrainedModel
 from transformers.configuration_utils import PretrainedConfig
 from typing import Optional, List, Tuple
 
-# vllm imports
-import vllm_cache_ops
-import vllm_attention_ops
-
+from text_generation_server.utils import paged_attention, flash_attn
 from text_generation_server.utils.flash_attn import attention
 from text_generation_server.utils.layers import (
     TensorParallelRowLinear,
@@ -191,7 +188,7 @@ class FlashRWAttention(torch.nn.Module):
         self.rotary_emb(query, cos, sin)
         self.rotary_emb(torch.select(kv, dim=1, index=0), cos, sin)
 
-        vllm_cache_ops.reshape_and_cache(
+        paged_attention.reshape_and_cache(
             kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots
         )
@@ -201,7 +198,7 @@ class FlashRWAttention(torch.nn.Module):
         # Prefill
         if cu_seqlen_prefill is not None:
             # flash attention
-            attention(
+            flash_attn.attention(
                 query,
                 torch.select(kv, dim=1, index=0),
                 torch.select(kv, dim=1, index=1),
@@ -212,9 +209,7 @@ class FlashRWAttention(torch.nn.Module):
             )
         # Decode
         else:
-            # kv_cache[1] => [num_blocks, num_heads_kv, head_size, block_size]
-            block_size = kv_cache[1].shape[3]
-            vllm_attention_ops.single_query_cached_kv_attention(
+            paged_attention.attention(
                 attn_output,
                 query,
                 kv_cache[0],
@@ -223,7 +218,6 @@ class FlashRWAttention(torch.nn.Module):
                 self.softmax_scale,
                 block_tables,
                 input_lengths,
-                block_size,
                 max_s,
             )
@@ -310,7 +304,7 @@ class FlashRWLargeAttention(torch.nn.Module):
         self.rotary_emb(query, cos, sin)
         self.rotary_emb(torch.select(kv, dim=2, index=0), cos, sin)
 
-        vllm_cache_ops.reshape_and_cache(
+        paged_attention.reshape_and_cache(
             kv[:, :, 0].contiguous(),
             kv[:, :, 1].contiguous(),
             kv_cache[0],
@@ -324,7 +318,7 @@ class FlashRWLargeAttention(torch.nn.Module):
         # Prefill
         if cu_seqlen_prefill is not None:
             # flash attention
-            attention(
+            flash_attn.attention(
                 query,
                 torch.select(kv, dim=2, index=0),
                 torch.select(kv, dim=2, index=1),
@@ -335,9 +329,7 @@ class FlashRWLargeAttention(torch.nn.Module):
             )
         # Decode
         else:
-            # kv_cache[1] => [num_blocks, num_groups, head_size, block_size]
-            block_size = kv_cache[1].shape[3]
-            vllm_attention_ops.single_query_cached_kv_attention(
+            paged_attention.attention(
                 attn_output,
                 query,
                 kv_cache[0],
@@ -346,7 +338,6 @@ class FlashRWLargeAttention(torch.nn.Module):
                 self.softmax_scale,
                 block_tables,
                 input_lengths,
-                block_size,
                 max_s,
             )
```
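One detail specific to the FlashRWLargeAttention hunks above: the grouped KV tensor is sliced along dim 2 before caching, and that slice has to be made contiguous before it is handed to the cache kernel. A small standalone check of why, assuming the [num_tokens, num_groups, 2, head_size] layout implied by the indexing in the hunk (the concrete sizes below are made up):

```python
import torch

# Assumed layout, inferred from kv[:, :, 0] / torch.select(kv, dim=2, ...):
# [num_tokens, num_groups, 2, head_size]; the sizes are illustrative only.
kv = torch.randn(8, 4, 2, 64)

key_slice = kv[:, :, 0]
print(key_slice.is_contiguous())               # False: slicing dim 2 leaves gaps in memory
print(key_slice.contiguous().is_contiguous())  # True: .contiguous() copies to dense storage
```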
server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py (view file @ 12590fdc)

```diff
@@ -5,10 +5,7 @@ from torch import nn
 from transformers.activations import ACT2FN
 from typing import Optional, List, Tuple
 
-# vllm imports
-import vllm_cache_ops
-import vllm_attention_ops
-
+from text_generation_server.utils import paged_attention, flash_attn
 from text_generation_server.utils.flash_attn import attention
 from text_generation_server.utils.layers import (
     TensorParallelRowLinear,
@@ -18,7 +15,6 @@ from text_generation_server.utils.layers import (
     FastLayerNorm,
     get_linear,
 )
-from safetensors import SafetensorError
 
 
 def load_multi_mqa(
@@ -258,7 +254,7 @@ class FlashMQAttention(torch.nn.Module):
         query = query.view(-1, self.num_heads, self.head_size)
         key_value = key_value.view(-1, 2, 1, self.head_size)
 
-        vllm_cache_ops.reshape_and_cache(
+        paged_attention.reshape_and_cache(
             key_value[:, 0], key_value[:, 1], kv_cache[0], kv_cache[1], slots
         )
@@ -268,7 +264,7 @@ class FlashMQAttention(torch.nn.Module):
         # Prefill
         if cu_seqlen_prefill is not None:
             # flash attention
-            attention(
+            flash_attn.attention(
                 query,
                 torch.select(key_value, dim=1, index=0),
                 torch.select(key_value, dim=1, index=1),
@@ -279,9 +275,7 @@ class FlashMQAttention(torch.nn.Module):
             )
         # Decode
         else:
-            # kv_cache[1] => [num_blocks, 1, head_size, block_size]
-            block_size = kv_cache[1].shape[3]
-            vllm_attention_ops.single_query_cached_kv_attention(
+            paged_attention.attention(
                 attn_output,
                 query,
                 kv_cache[0],
@@ -290,7 +284,6 @@ class FlashMQAttention(torch.nn.Module):
                 self.softmax_scale,
                 block_tables,
                 input_lengths,
-                block_size,
                 max_s,
             )
```
server/text_generation_server/utils/paged_attention.py (new file, mode 0 → 100644, view file @ 12590fdc)

```python
import torch

# vllm imports
from vllm import cache_ops
from vllm import attention_ops

_PARTITION_SIZE = 512


def reshape_and_cache(
    key: torch.Tensor,
    value: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    slots: torch.Tensor,
):
    cache_ops.reshape_and_cache(key, value, key_cache, value_cache, slots)


def attention(
    out: torch.Tensor,
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    kv_head_mapping: torch.Tensor,
    softmax_scale: float,
    block_tables: torch.Tensor,
    input_lengths: torch.Tensor,
    max_s: int,
):
    # Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py
    # Copyright 2023 The vLLM team. All rights reserved.
    #
    # Licensed under the Apache License, Version 2.0 (the "License");
    # you may not use this file except in compliance with the License.
    # You may obtain a copy of the License at
    #
    #     http://www.apache.org/licenses/LICENSE-2.0
    #
    # Unless required by applicable law or agreed to in writing, software
    # distributed under the License is distributed on an "AS IS" BASIS,
    # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    # See the License for the specific language governing permissions and
    # limitations under the License.
    #

    # value_cache => [num_blocks, num_heads, head_size, block_size]
    block_size = value_cache.shape[3]
    num_seqs, num_heads, head_size = query.shape
    max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE

    # NOTE(woosuk): We use a simple heuristic to decide whether to use
    # PagedAttention V1 or V2. If the number of partitions is 1, we use
    # V1 to avoid the overhead of reduction. Also, if the number of
    # sequences or heads is large, we use V1 since there is enough work
    # to parallelize.
    use_v1 = max_num_partitions == 1 or num_seqs * num_heads > 512
    if use_v1:
        attention_ops.paged_attention_v1(
            out,
            query,
            key_cache,
            value_cache,
            kv_head_mapping,
            softmax_scale,
            block_tables,
            input_lengths,
            block_size,
            max_s,
            None,
        )
    else:
        # Run PagedAttention V2.
        assert _PARTITION_SIZE % block_size == 0
        tmp_output = torch.empty(
            size=(num_seqs, num_heads, max_num_partitions, head_size),
            dtype=out.dtype,
            device=out.device,
        )
        exp_sums = torch.empty(
            size=(num_seqs, num_heads, max_num_partitions),
            dtype=torch.float32,
            device=out.device,
        )
        max_logits = torch.empty_like(exp_sums)
        attention_ops.paged_attention_v2(
            out,
            exp_sums,
            max_logits,
            tmp_output,
            query,
            key_cache,
            value_cache,
            kv_head_mapping,
            softmax_scale,
            block_tables,
            input_lengths,
            block_size,
            max_s,
            None,
        )
```
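The NOTE in the new module explains when V1 versus V2 is chosen. A self-contained walk-through of that heuristic with a few illustrative decode shapes (the batch sizes are made up, not taken from the commit):

```python
# Standalone illustration of the V1/V2 dispatch heuristic used in
# paged_attention.attention above.
_PARTITION_SIZE = 512


def dispatch(max_s: int, num_seqs: int, num_heads: int) -> str:
    max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE
    use_v1 = max_num_partitions == 1 or num_seqs * num_heads > 512
    return "paged_attention_v1" if use_v1 else "paged_attention_v2"


# Short context: everything fits in one partition, so V1 skips the reduction pass.
print(dispatch(max_s=256, num_seqs=4, num_heads=32))    # -> paged_attention_v1
# Long context but large batch: plenty of parallel work already, stay on V1.
print(dispatch(max_s=4096, num_seqs=64, num_heads=32))  # -> paged_attention_v1
# Long context with a small batch: V2 splits the sequence into 512-token
# partitions and then reduces the partial results, keeping the GPU busy.
print(dispatch(max_s=4096, num_seqs=1, num_heads=32))   # -> paged_attention_v2
```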