OpenDAS / text-generation-inference

Commit 3b71c385 (unverified)
feat(server): flash attention v2 (#624)
Authored Jul 18, 2023 by OlivierDehaene, committed via GitHub on Jul 18, 2023
Parent: 4d38a1c4
Showing 9 changed files with 173 additions and 112 deletions (+173 / -112).
Changed files:

  Dockerfile (+14, -1)
  server/Makefile (+1, -0)
  server/Makefile-flash-att-v2 (+13, -0)
  server/text_generation_server/models/__init__.py (+12, -42)
  server/text_generation_server/models/custom_modeling/flash_llama_modeling.py (+2, -10)
  server/text_generation_server/models/custom_modeling/flash_neox_modeling.py (+2, -12)
  server/text_generation_server/models/custom_modeling/flash_rw_modeling.py (+3, -32)
  server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py (+2, -15)
  server/text_generation_server/utils/flash_attn.py (+124, -0)
Dockerfile · View file @ 3b71c385

@@ -98,6 +98,16 @@ COPY server/Makefile-flash-att Makefile
 # Build specific version of flash attention
 RUN make build-flash-attention
 
+# Build Flash Attention v2 CUDA kernels
+FROM kernel-builder as flash-att-v2-builder
+
+WORKDIR /usr/src
+
+COPY server/Makefile-flash-att-v2 Makefile
+
+# Build specific version of flash attention v2
+RUN make build-flash-attention-v2
+
 # Build Transformers CUDA kernels
 FROM kernel-builder as custom-kernels-builder

@@ -146,8 +156,11 @@ COPY --from=flash-att-builder /usr/src/flash-attention/build/lib.linux-x86_64-cp
 COPY --from=flash-att-builder /usr/src/flash-attention/csrc/layer_norm/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages
 COPY --from=flash-att-builder /usr/src/flash-attention/csrc/rotary/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages
 
+# Copy build artifacts from flash attention v2 builder
+COPY --from=flash-att-v2-builder /usr/src/flash-attention-v2/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages
+
 # Copy build artifacts from custom kernels builder
-COPY --from=custom-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-39/custom_kernels /usr/src/custom-kernels/src/custom_kernels
+COPY --from=custom-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages
 # Copy builds artifacts from vllm builder
 COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages
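Note: the new flash-att-v2-builder stage only compiles the kernels; the final image picks them up through the COPY --from=flash-att-v2-builder line above. A quick way to confirm which attention extensions an image actually contains is to import them inside the container. The module names flash_attn_2_cuda and flash_attn_cuda come from this commit; the script itself is an illustrative sketch, not part of the repository:

# check_kernels.py - illustrative sketch, not part of this commit.
# Run inside the built image to see which flash attention extensions
# the Docker build installed into site-packages.
import importlib

for name in ("flash_attn_2_cuda", "flash_attn_cuda"):
    try:
        importlib.import_module(name)
        print(f"{name}: available")
    except ImportError as err:
        print(f"{name}: missing ({err})")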
server/Makefile · View file @ 3b71c385

 include Makefile-flash-att
+include Makefile-flash-att-v2
 include Makefile-vllm
 
 unit-tests:
server/Makefile-flash-att-v2 · new file (0 → 100644) · View file @ 3b71c385

flash_att_v2_commit := 4f285b354796fb17df8636485b9a04df3ebbb7dc

flash-attention-v2:
	# Clone flash attention
	pip install packaging
	git clone https://github.com/HazyResearch/flash-attention.git flash-attention-v2

build-flash-attention-v2: flash-attention-v2
	cd flash-attention-v2 && git fetch && git checkout $(flash_att_v2_commit)
	cd flash-attention-v2 && python setup.py build

install-flash-attention-v2: build-flash-attention-v2
	cd flash-attention-v2 && python setup.py install
\ No newline at end of file
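Outside of Docker, these targets are meant to be run from server/ (e.g. make install-flash-attention-v2, which the updated server/Makefile now includes). For environments without make, a rough Python equivalent of what the targets do might look like the sketch below; the pinned commit and commands mirror the Makefile, while the function name and use of subprocess are purely illustrative:

# Illustrative re-implementation of the Makefile-flash-att-v2 targets.
import subprocess

FLASH_ATT_V2_COMMIT = "4f285b354796fb17df8636485b9a04df3ebbb7dc"


def build_flash_attention_v2(install: bool = False) -> None:
    # flash-attention-v2 target: install packaging and clone the repo
    subprocess.run(["pip", "install", "packaging"], check=True)
    subprocess.run(
        ["git", "clone",
         "https://github.com/HazyResearch/flash-attention.git",
         "flash-attention-v2"],
        check=True,
    )
    # build-flash-attention-v2 target: pin the commit and build the extension
    subprocess.run(["git", "fetch"], cwd="flash-attention-v2", check=True)
    subprocess.run(["git", "checkout", FLASH_ATT_V2_COMMIT],
                   cwd="flash-attention-v2", check=True)
    # install-flash-attention-v2 target when install=True
    target = "install" if install else "build"
    subprocess.run(["python", "setup.py", target],
                   cwd="flash-attention-v2", check=True)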
server/text_generation_server/models/__init__.py · View file @ 3b71c385

@@ -42,51 +42,21 @@ __all__ = [
     "get_model",
 ]
 
-FLASH_ATT_ERROR_MESSAGE = (
-    "{} requires CUDA and Flash Attention kernels to be installed.\n"
-    "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
-    "or install flash attention with `cd server && make install install-flash-attention`"
-)
+FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention enabled models."
 
+FLASH_ATTENTION = True
 try:
-    if not os.getenv("USE_FLASH_ATTENTION", "").lower() == "false":
-        if not torch.cuda.is_available():
-            FLASH_ATT_ERROR_MESSAGE = (
-                "{} requires CUDA. No compatible CUDA devices found."
-            )
-            raise ImportError("CUDA is not available")
-
-        major, minor = torch.cuda.get_device_capability()
-        is_sm75 = major == 7 and minor == 5
-        is_sm8x = major == 8 and minor >= 0
-        is_sm90 = major == 9 and minor == 0
-
-        supported = is_sm75 or is_sm8x or is_sm90
-        if not supported:
-            FLASH_ATT_ERROR_MESSAGE = (
-                "{} requires a CUDA device with capability 7.5, > 8.0 or 9.0. "
-                "No compatible CUDA device found."
-            )
-            raise ImportError(
-                f"GPU with CUDA capability {major} {minor} is not supported"
-            )
-
-        from text_generation_server.models.flash_rw import FlashRWSharded
-        from text_generation_server.models.flash_neox import FlashNeoXSharded
-        from text_generation_server.models.flash_llama import (
-            FlashLlama,
-        )
-        from text_generation_server.models.flash_santacoder import (
-            FlashSantacoderSharded,
-        )
-
-        FLASH_ATTENTION = True
-    else:
-        FLASH_ATTENTION = False
-except ImportError:
-    logger.opt(exception=True).warning(
-        "Could not import Flash Attention enabled models"
-    )
+    from text_generation_server.models.flash_rw import FlashRWSharded
+    from text_generation_server.models.flash_neox import FlashNeoXSharded
+    from text_generation_server.models.flash_llama import (
+        FlashLlama,
+    )
+    from text_generation_server.models.flash_santacoder import (
+        FlashSantacoderSharded,
+    )
+except ImportError as e:
+    logger.warning(f"Could not import Flash Attention enabled models: {e}")
     FLASH_ATTENTION = False
 
 if FLASH_ATTENTION:
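The CUDA and compute-capability checks that used to live in this module now run when text_generation_server.utils.flash_attn is imported (see the new file below), so this module only records whether that import chain succeeded. Downstream, the FLASH_ATTENTION flag gates the flash-enabled architectures, and FLASH_ATT_ERROR_MESSAGE is formatted with a model description when one of them is requested anyway. A hedged sketch of that pattern; the helper name is hypothetical, and the real dispatch is get_model(), whose body is elided in this diff:

# Illustrative sketch of how the module-level flag and message template
# are typically consumed; the helper name is hypothetical.
from text_generation_server.models import FLASH_ATTENTION, FLASH_ATT_ERROR_MESSAGE


def require_flash_attention(model_description: str) -> None:
    if not FLASH_ATTENTION:
        # "{}" is filled with a human-readable model description.
        raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format(model_description))


require_flash_attention("Sharded Llama")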
server/text_generation_server/models/custom_modeling/flash_llama_modeling.py · View file @ 3b71c385

@@ -26,13 +26,13 @@ from transformers.activations import ACT2FN
 from typing import Optional, List, Tuple
 
 # Flash attention imports
-import flash_attn_cuda
 import dropout_layer_norm
 
 # vllm imports
 import vllm_cache_ops
 import vllm_attention_ops
 
+from text_generation_server.utils.flash_attn import attention
 from text_generation_server.utils.layers import (
     TensorParallelRowLinear,
     TensorParallelColumnLinear,

@@ -164,22 +164,14 @@ class FlashLlamaAttention(torch.nn.Module):
         # Prefill
         if cu_seqlen_prefill is not None:
             # flash attention
-            flash_attn_cuda.fwd(
+            attention(
                 qkv[:, 0],
                 qkv[:, 1],
                 qkv[:, 2],
                 attn_output,
                 cu_seqlen_prefill,
-                cu_seqlen_prefill,
                 max_s,
-                max_s,
-                0.0,
                 self.softmax_scale,
-                False,
-                True,
-                False,
-                0,
-                None,
             )
         # Decode
         else:
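The call sites keep the varlen ("packed") layout: all sequences in the batch are concatenated along the token dimension, cu_seqlen_prefill holds the cumulative sequence lengths, and max_s is the longest sequence. Below is a small sketch of that layout for a two-prompt batch; the tensor shapes are inferred from the qkv[:, 0] indexing above, and the concrete sizes are arbitrary:

# Illustrative varlen layout for the prefill path (CPU-side setup only).
import torch

num_heads, head_size = 32, 128
seq_lens = [5, 3]                      # two prompts in the batch
total_tokens = sum(seq_lens)           # 8 packed tokens

# Packed QKV: [total_tokens, 3, num_heads, head_size]
qkv = torch.randn(total_tokens, 3, num_heads, head_size, dtype=torch.float16)

# Cumulative sequence lengths: tensor([0, 5, 8])
cu_seqlen_prefill = torch.zeros(len(seq_lens) + 1, dtype=torch.int32)
cu_seqlen_prefill[1:] = torch.cumsum(torch.tensor(seq_lens, dtype=torch.int32), dim=0)
max_s = max(seq_lens)

q, k, v = qkv[:, 0], qkv[:, 1], qkv[:, 2]
attn_output = torch.empty_like(q)      # attention() writes its result into `out`
softmax_scale = head_size ** -0.5

# On a supported GPU with the kernels installed, the prefill call is then:
# attention(q, k, v, attn_output, cu_seqlen_prefill, max_s, softmax_scale)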
server/text_generation_server/models/custom_modeling/flash_neox_modeling.py · View file @ 3b71c385

@@ -27,13 +27,11 @@ from transformers.modeling_utils import PreTrainedModel
 from transformers.models.gpt_neox import GPTNeoXConfig
 from typing import Optional, List, Tuple
 
-# Flash attention imports
-import flash_attn_cuda
-
 # vllm imports
 import vllm_cache_ops
 import vllm_attention_ops
 
+from text_generation_server.utils.flash_attn import attention
 from text_generation_server.utils.layers import (
     TensorParallelRowLinear,
     TensorParallelColumnLinear,

@@ -153,22 +151,14 @@ class FlashNeoxAttention(torch.nn.Module):
         # Prefill
         if cu_seqlen_prefill is not None:
             # flash attention
-            flash_attn_cuda.fwd(
+            attention(
                 qkv[:, 0],
                 qkv[:, 1],
                 qkv[:, 2],
                 attn_output,
                 cu_seqlen_prefill,
-                cu_seqlen_prefill,
                 max_s,
-                max_s,
-                0.0,
                 self.softmax_scale,
-                False,
-                True,
-                False,
-                0,
-                None,
             )
         # Decode
         else:
server/text_generation_server/models/custom_modeling/flash_rw_modeling.py · View file @ 3b71c385

@@ -6,13 +6,11 @@ from transformers.modeling_utils import PreTrainedModel
 from transformers.configuration_utils import PretrainedConfig
 from typing import Optional, List, Tuple
 
-# Flash attention imports
-import flash_attn_cuda
-
 # vllm imports
 import vllm_cache_ops
 import vllm_attention_ops
 
+from text_generation_server.utils.flash_attn import attention
 from text_generation_server.utils.layers import (
     TensorParallelRowLinear,
     TensorParallelColumnLinear,

@@ -182,27 +180,15 @@ class FlashRWAttention(torch.nn.Module):
         # Prefill
         if cu_seqlen_prefill is not None:
-            if self.num_heads_kv == 1:
-                # Expand to query shape
-                kv = kv.expand(-1, 2, self.num_heads, self.head_size)
-
             # flash attention
-            flash_attn_cuda.fwd(
+            attention(
                 query,
                 torch.select(kv, dim=1, index=0),
                 torch.select(kv, dim=1, index=1),
                 attn_output,
                 cu_seqlen_prefill,
-                cu_seqlen_prefill,
                 max_s,
-                max_s,
-                0.0,
                 self.softmax_scale,
-                False,
-                True,
-                False,
-                0,
-                None,
             )
         # Decode
         else:

@@ -314,30 +300,15 @@ class FlashRWLargeAttention(torch.nn.Module):
         # Prefill
         if cu_seqlen_prefill is not None:
-            # Expand to query shape
-            kv = (
-                kv.unsqueeze(2)
-                .expand(-1, self.num_groups, self.num_heads, 2, self.head_size)
-                .reshape(-1, self.num_groups * self.num_heads, 2, self.head_size)
-            )
-
             # flash attention
-            flash_attn_cuda.fwd(
+            attention(
                 query,
                 torch.select(kv, dim=2, index=0),
                 torch.select(kv, dim=2, index=1),
                 attn_output,
                 cu_seqlen_prefill,
-                cu_seqlen_prefill,
                 max_s,
-                max_s,
-                0.0,
                 self.softmax_scale,
-                False,
-                True,
-                False,
-                0,
-                None,
             )
         # Decode
         else:
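The per-model key/value head expansions removed above (the kv.expand for multi-query attention and the unsqueeze/expand/reshape for grouped heads) now live in the shared attention() helper, which only applies them on the flash-attention v1 fallback; the v2 varlen kernel handles fewer KV heads than query heads on its own. A CPU-only sketch, with arbitrary shapes, showing that the grouped reshape used by the v1 fallback repeats each KV head across its group of query heads:

# CPU-only illustration of the head expansion attention() applies on the
# v1 fallback: each KV head is repeated so k/v match q's number of heads.
import torch

total_tokens, num_q_heads, num_kv_heads, head_size = 8, 8, 2, 4
k = torch.randn(total_tokens, num_kv_heads, head_size)

original_shape = k.shape
k_expanded = (
    k.unsqueeze(2)
    .expand(-1, -1, num_q_heads // num_kv_heads, -1)
    .reshape(original_shape[0], -1, original_shape[2])
)

assert k_expanded.shape == (total_tokens, num_q_heads, head_size)
# Query heads 0..3 now read from KV head 0, heads 4..7 from KV head 1.
assert torch.equal(k_expanded[:, 0], k[:, 0])
assert torch.equal(k_expanded[:, 3], k[:, 0])
assert torch.equal(k_expanded[:, 4], k[:, 1])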
server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py · View file @ 3b71c385

@@ -5,13 +5,11 @@ from torch import nn
 from transformers.activations import ACT2FN
 from typing import Optional, List, Tuple
 
-# Flash attention imports
-import flash_attn_cuda
-
 # vllm imports
 import vllm_cache_ops
 import vllm_attention_ops
 
+from text_generation_server.utils.flash_attn import attention
 from text_generation_server.utils.layers import (
     TensorParallelRowLinear,
     TensorParallelColumnLinear,

@@ -271,26 +269,15 @@ class FlashMQAttention(torch.nn.Module):
         # Prefill
         if cu_seqlen_prefill is not None:
-            # Expand from 1 to num_heads
-            key_value = key_value.expand(-1, 2, self.num_heads, self.head_size)
-
             # flash attention
-            flash_attn_cuda.fwd(
+            attention(
                 query,
                 torch.select(key_value, dim=1, index=0),
                 torch.select(key_value, dim=1, index=1),
                 attn_output,
                 cu_seqlen_prefill,
-                cu_seqlen_prefill,
                 max_s,
-                max_s,
-                0.0,
                 self.softmax_scale,
-                False,
-                True,
-                False,
-                0,
-                None,
             )
         # Decode
         else:
server/text_generation_server/utils/flash_attn.py · new file (0 → 100644) · View file @ 3b71c385

import os
import torch

from loguru import logger

if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false":
    raise ImportError("`USE_FLASH_ATTENTION` is false.")

if not torch.cuda.is_available():
    raise ImportError("CUDA is not available")

major, minor = torch.cuda.get_device_capability()
is_sm75 = major == 7 and minor == 5
is_sm8x = major == 8 and minor >= 0
is_sm90 = major == 9 and minor == 0

HAS_FLASH_ATTN = False
HAS_FLASH_ATTN_V2 = False
try:
    try:
        import flash_attn_2_cuda
    except ImportError:
        raise ImportError(
            "Flash Attention V2 is not installed.\n"
            "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
            "or install flash attention v2 with `cd server && make install install-flash-attention-v2`"
        )
    if not (is_sm8x or is_sm90):
        raise ImportError(
            f"GPU with CUDA capability {major} {minor} is not supported for "
            "Flash Attention V2"
        )
    HAS_FLASH_ATTN_V2 = True
except ImportError as e:
    try:
        import flash_attn_cuda
    except ImportError:
        raise ImportError(
            "Flash Attention is not installed.\n"
            "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
            "or install flash attention with `cd server && make install install-flash-attention`"
        ) from e

    if not (is_sm75 or is_sm8x or is_sm90):
        raise ImportError(
            f"GPU with CUDA capability {major} {minor} is not supported"
        ) from e
    logger.warning(f"Unable to use Flash Attention V2: {e}")
    HAS_FLASH_ATTN = True


def attention(
    q,
    k,
    v,
    out,
    cu_seqlens,
    max_s,
    softmax_scale,
):
    if HAS_FLASH_ATTN_V2:
        return flash_attn_2_cuda.varlen_fwd(
            q,
            k,
            v,
            out,
            cu_seqlens,
            cu_seqlens,
            max_s,
            max_s,
            0.0,
            softmax_scale,
            False,
            True,
            False,
            None,
        )

    if HAS_FLASH_ATTN:
        # Flash attention v1 requires q, k and v to have the same number of heads
        if k.shape[1] != q.shape[1]:
            # MQA expand
            if k.shape[1] == 1:
                k = k.expand(-1, q.shape[1], -1)
            # Grouped attention reshape
            else:
                original_shape = k.shape
                k = (
                    k.unsqueeze(2)
                    .expand(-1, -1, q.shape[1] // k.shape[1], -1)
                    .reshape(original_shape[0], -1, original_shape[2])
                )
        if v.shape[1] != q.shape[1]:
            # MQA expand
            if v.shape[1] == 1:
                v = v.expand(-1, q.shape[1], -1)
            # Grouped attention reshape
            else:
                original_shape = v.shape
                v = (
                    v.unsqueeze(2)
                    .expand(-1, -1, q.shape[1] // v.shape[1], -1)
                    .reshape(original_shape[0], -1, original_shape[2])
                )

        return flash_attn_cuda.fwd(
            q,
            k,
            v,
            out,
            cu_seqlens,
            cu_seqlens,
            max_s,
            max_s,
            0.0,
            softmax_scale,
            False,
            True,
            False,
            0,
            None,
        )

    raise NotImplementedError("flash attention is not installed")
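A usage sketch of the new unified entry point. It needs a CUDA GPU with either flash_attn_2_cuda or flash_attn_cuda installed (importing the module raises ImportError otherwise); the concrete sizes are arbitrary and only illustrate the packed-sequence calling convention used by the model files above:

# Illustrative call into text_generation_server.utils.flash_attn.attention.
import torch

from text_generation_server.utils.flash_attn import attention

num_heads, head_size = 16, 64
seq_lens = [7, 2]                       # two packed sequences
total_tokens = sum(seq_lens)

q = torch.randn(total_tokens, num_heads, head_size, dtype=torch.float16, device="cuda")
k = torch.randn_like(q)
v = torch.randn_like(q)
out = torch.empty_like(q)               # the kernel writes the result here

cu_seqlens = torch.tensor([0, 7, 9], dtype=torch.int32, device="cuda")
max_s = max(seq_lens)

attention(q, k, v, out, cu_seqlens, max_s, softmax_scale=head_size ** -0.5)
# `out` now holds the causal attention output for both packed sequences.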