Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
54bd9a03
"vllm/model_executor/models/exaone.py" did not exist on "69672f116cf83dbcfd2d470a959dfe123df4d301"
Unverified
Commit
54bd9a03
authored
Aug 15, 2024
by
youkaichao
Committed by
GitHub
Aug 15, 2024
Browse files
register custom op for flash attn and use from torch.ops (#7536)
parent
50b8d08d
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
220 additions
and
41 deletions
+220
-41
.buildkite/test-pipeline.yaml
.buildkite/test-pipeline.yaml
+7
-0
tests/compile/test_full_graph.py
tests/compile/test_full_graph.py
+20
-0
tests/kernels/test_flash_attn.py
tests/kernels/test_flash_attn.py
+61
-12
vllm/attention/backends/flash_attn.py
vllm/attention/backends/flash_attn.py
+129
-26
vllm/attention/backends/flashinfer.py
vllm/attention/backends/flashinfer.py
+3
-3
No files found.
.buildkite/test-pipeline.yaml
View file @
54bd9a03
...
...
@@ -163,6 +163,13 @@ steps:
-
pytest -v -s models/test_oot_registration.py
# it needs a clean process
-
pytest -v -s models -m \"not vlm\" --ignore=models/test_oot_registration.py
-
label
:
torch compile integration test
source_file_dependencies
:
-
vllm/
commands
:
-
pytest -v -s ./compile/test_full_graph.py
-
label
:
Vision Language Models Test
# 42min
mirror_hardwares
:
[
amd
]
source_file_dependencies
:
...
...
tests/compile/test_full_graph.py
0 → 100644
View file @
54bd9a03
import
os
import
pytest
@
pytest
.
mark
.
parametrize
(
"model"
,
[
"meta-llama/Meta-Llama-3-8B"
])
def
test_full_graph
(
model
):
# make sure these models can be captured in full graph mode
os
.
environ
[
"VLLM_TEST_DYNAMO_GRAPH_CAPTURE"
]
=
"1"
from
vllm
import
LLM
,
SamplingParams
prompts
=
[
"Hello, my name is"
,
"The president of the United States is"
,
"The capital of France is"
,
"The future of AI is"
,
]
sampling_params
=
SamplingParams
(
temperature
=
0
)
llm
=
LLM
(
model
=
"meta-llama/Meta-Llama-3-8B"
)
llm
.
generate
(
prompts
,
sampling_params
)
tests/kernels/test_flash_attn.py
View file @
54bd9a03
...
...
@@ -2,13 +2,16 @@ from typing import List, Optional, Tuple
import
pytest
import
torch
from
vllm_flash_attn
import
flash_attn_varlen_func
,
flash_attn_with_kvcache
NUM_HEADS
=
[(
16
,
16
),
(
32
,
8
),
(
64
,
8
)]
import
vllm.attention.backends.flash_attn
# noqa: F401
NUM_HEADS
=
[(
4
,
4
),
(
8
,
2
),
(
16
,
2
)]
HEAD_SIZES
=
[
128
,
256
]
BLOCK_SIZES
=
[
16
,
32
]
DTYPES
=
[
torch
.
float16
,
torch
.
bfloat16
]
NUM_BLOCKS
=
32768
# Large enough to test overflow in index calculation.
# one value large enough to test overflow in index calculation.
# one value small enough to test the schema op check
NUM_BLOCKS
=
[
32768
,
2048
]
def
ref_paged_attn
(
...
...
@@ -72,6 +75,7 @@ def ref_paged_attn(
@
pytest
.
mark
.
parametrize
(
"block_size"
,
BLOCK_SIZES
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPES
)
@
pytest
.
mark
.
parametrize
(
"soft_cap"
,
[
None
,
10.0
,
50.0
])
@
pytest
.
mark
.
parametrize
(
"num_blocks"
,
NUM_BLOCKS
)
@
torch
.
inference_mode
()
def
test_flash_attn_with_paged_kv
(
kv_lens
:
List
[
int
],
...
...
@@ -80,6 +84,7 @@ def test_flash_attn_with_paged_kv(
dtype
:
torch
.
dtype
,
block_size
:
int
,
soft_cap
:
Optional
[
float
],
num_blocks
:
int
,
)
->
None
:
torch
.
set_default_device
(
"cuda"
)
torch
.
cuda
.
manual_seed_all
(
0
)
...
...
@@ -91,7 +96,7 @@ def test_flash_attn_with_paged_kv(
scale
=
head_size
**-
0.5
query
=
torch
.
randn
(
num_seqs
,
num_query_heads
,
head_size
,
dtype
=
dtype
)
key_cache
=
torch
.
randn
(
NUM_BLOCKS
,
key_cache
=
torch
.
randn
(
num_blocks
,
block_size
,
num_kv_heads
,
head_size
,
...
...
@@ -101,14 +106,14 @@ def test_flash_attn_with_paged_kv(
max_num_blocks_per_seq
=
(
max_kv_len
+
block_size
-
1
)
//
block_size
block_tables
=
torch
.
randint
(
0
,
NUM_BLOCKS
,
num_blocks
,
(
num_seqs
,
max_num_blocks_per_seq
),
dtype
=
torch
.
int32
)
output
=
flash_attn_with_kvcache
(
q
=
query
.
unsqueeze
(
1
),
k_cache
=
key_cache
,
v_cache
=
value_cache
,
output
=
torch
.
ops
.
vllm
.
flash_attn_with_kvcache
(
decode_query
=
query
.
unsqueeze
(
1
),
k
ey
_cache
=
key_cache
,
v
alue
_cache
=
value_cache
,
softmax_scale
=
scale
,
causal
=
True
,
block_table
=
block_tables
,
...
...
@@ -116,6 +121,25 @@ def test_flash_attn_with_paged_kv(
softcap
=
soft_cap
if
soft_cap
is
not
None
else
0
,
).
squeeze
(
1
)
if
num_blocks
<=
2048
:
test_utils
=
[
"test_faketensor"
,
"test_schema"
]
else
:
test_utils
=
[
"test_faketensor"
]
torch
.
library
.
opcheck
(
torch
.
ops
.
vllm
.
flash_attn_with_kvcache
,
args
=
tuple
(),
kwargs
=
dict
(
decode_query
=
query
.
unsqueeze
(
1
),
key_cache
=
key_cache
,
value_cache
=
value_cache
,
softmax_scale
=
scale
,
causal
=
True
,
block_table
=
block_tables
,
cache_seqlens
=
kv_lens_tensor
,
softcap
=
soft_cap
if
soft_cap
is
not
None
else
0
,
),
test_utils
=
test_utils
)
ref_output
=
ref_paged_attn
(
query
=
query
,
key_cache
=
key_cache
,
...
...
@@ -137,6 +161,7 @@ def test_flash_attn_with_paged_kv(
@
pytest
.
mark
.
parametrize
(
"sliding_window"
,
[
None
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPES
)
@
pytest
.
mark
.
parametrize
(
"soft_cap"
,
[
None
,
10.0
,
50.0
])
@
pytest
.
mark
.
parametrize
(
"num_blocks"
,
NUM_BLOCKS
)
@
torch
.
inference_mode
()
def
test_varlen_with_paged_kv
(
seq_lens
:
List
[
Tuple
[
int
,
int
]],
...
...
@@ -146,6 +171,7 @@ def test_varlen_with_paged_kv(
dtype
:
torch
.
dtype
,
block_size
:
int
,
soft_cap
:
Optional
[
float
],
num_blocks
:
int
,
)
->
None
:
torch
.
set_default_device
(
"cuda"
)
torch
.
cuda
.
manual_seed_all
(
0
)
...
...
@@ -166,7 +192,7 @@ def test_varlen_with_paged_kv(
num_query_heads
,
head_size
,
dtype
=
dtype
)
key_cache
=
torch
.
randn
(
NUM_BLOCKS
,
key_cache
=
torch
.
randn
(
num_blocks
,
block_size
,
num_kv_heads
,
head_size
,
...
...
@@ -181,11 +207,11 @@ def test_varlen_with_paged_kv(
max_num_blocks_per_seq
=
(
max_kv_len
+
block_size
-
1
)
//
block_size
block_tables
=
torch
.
randint
(
0
,
NUM_BLOCKS
,
num_blocks
,
(
num_seqs
,
max_num_blocks_per_seq
),
dtype
=
torch
.
int32
)
output
=
flash_attn_varlen_func
(
output
=
torch
.
ops
.
vllm
.
flash_attn_varlen_func
(
q
=
query
,
k
=
key_cache
,
v
=
value_cache
,
...
...
@@ -200,6 +226,29 @@ def test_varlen_with_paged_kv(
softcap
=
soft_cap
if
soft_cap
is
not
None
else
0
,
)
if
num_blocks
<=
2048
:
test_utils
=
[
"test_faketensor"
,
"test_schema"
]
else
:
test_utils
=
[
"test_faketensor"
]
torch
.
library
.
opcheck
(
torch
.
ops
.
vllm
.
flash_attn_varlen_func
,
args
=
tuple
(),
kwargs
=
dict
(
q
=
query
,
k
=
key_cache
,
v
=
value_cache
,
cu_seqlens_q
=
cu_query_lens
,
cu_seqlens_k
=
cu_kv_lens
,
max_seqlen_q
=
max_query_len
,
max_seqlen_k
=
max_kv_len
,
softmax_scale
=
scale
,
causal
=
True
,
window_size
=
window_size
,
block_table
=
block_tables
,
softcap
=
soft_cap
if
soft_cap
is
not
None
else
0
,
),
test_utils
=
test_utils
)
ref_output
=
ref_paged_attn
(
query
=
query
,
key_cache
=
key_cache
,
...
...
vllm/attention/backends/flash_attn.py
View file @
54bd9a03
...
...
@@ -3,7 +3,6 @@ from dataclasses import dataclass
from
typing
import
TYPE_CHECKING
,
Any
,
Dict
,
List
,
Optional
,
Tuple
,
Type
import
torch
from
vllm_flash_attn
import
flash_attn_varlen_func
,
flash_attn_with_kvcache
from
vllm
import
_custom_ops
as
ops
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
...
...
@@ -18,6 +17,108 @@ from vllm.utils import async_tensor_h2d, make_tensor_with_pad
if
TYPE_CHECKING
:
from
vllm.worker.model_runner
import
ModelInputForGPUBuilder
from
vllm_flash_attn
import
flash_attn_varlen_func
as
_flash_attn_varlen_func
from
vllm_flash_attn
import
flash_attn_with_kvcache
as
_flash_attn_with_kvcache
@
torch
.
library
.
custom_op
(
"vllm::flash_attn_varlen_func"
,
mutates_args
=
[])
def
flash_attn_varlen_func
(
q
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
v
:
torch
.
Tensor
,
cu_seqlens_q
:
torch
.
Tensor
,
cu_seqlens_k
:
torch
.
Tensor
,
max_seqlen_q
:
int
,
max_seqlen_k
:
int
,
softmax_scale
:
Optional
[
float
]
=
None
,
causal
:
bool
=
False
,
window_size
:
Optional
[
List
[
int
]]
=
None
,
softcap
:
float
=
0.0
,
alibi_slopes
:
Optional
[
torch
.
Tensor
]
=
None
,
block_table
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
# custom op does not support tuple input
real_window_size
:
Tuple
[
int
,
int
]
if
window_size
is
None
:
real_window_size
=
(
-
1
,
-
1
)
else
:
assert
len
(
window_size
)
==
2
real_window_size
=
(
window_size
[
0
],
window_size
[
1
])
return
_flash_attn_varlen_func
(
q
=
q
,
k
=
k
,
v
=
v
,
cu_seqlens_q
=
cu_seqlens_q
,
cu_seqlens_k
=
cu_seqlens_k
,
max_seqlen_q
=
max_seqlen_q
,
max_seqlen_k
=
max_seqlen_k
,
softmax_scale
=
softmax_scale
,
causal
=
causal
,
window_size
=
real_window_size
,
softcap
=
softcap
,
alibi_slopes
=
alibi_slopes
,
block_table
=
block_table
,
)
@
flash_attn_varlen_func
.
register_fake
# type: ignore
def
_
(
q
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
v
:
torch
.
Tensor
,
cu_seqlens_q
:
torch
.
Tensor
,
cu_seqlens_k
:
torch
.
Tensor
,
max_seqlen_q
:
int
,
max_seqlen_k
:
int
,
softmax_scale
:
Optional
[
float
]
=
None
,
causal
:
bool
=
False
,
window_size
:
Optional
[
List
[
int
]]
=
None
,
softcap
:
float
=
0.0
,
alibi_slopes
:
Optional
[
torch
.
Tensor
]
=
None
,
block_table
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
return
torch
.
empty_like
(
q
)
@
torch
.
library
.
custom_op
(
"vllm::flash_attn_with_kvcache"
,
mutates_args
=
[])
def
flash_attn_with_kvcache
(
decode_query
:
torch
.
Tensor
,
key_cache
:
torch
.
Tensor
,
value_cache
:
torch
.
Tensor
,
cache_seqlens
:
Optional
[
torch
.
Tensor
]
=
None
,
block_table
:
Optional
[
torch
.
Tensor
]
=
None
,
softmax_scale
:
Optional
[
float
]
=
None
,
causal
:
bool
=
False
,
alibi_slopes
:
Optional
[
torch
.
Tensor
]
=
None
,
softcap
:
float
=
0.0
,
)
->
torch
.
Tensor
:
return
_flash_attn_with_kvcache
(
decode_query
,
key_cache
,
value_cache
,
cache_seqlens
=
cache_seqlens
,
block_table
=
block_table
,
softmax_scale
=
softmax_scale
,
causal
=
causal
,
alibi_slopes
=
alibi_slopes
,
softcap
=
softcap
,
)
@
flash_attn_with_kvcache
.
register_fake
# type: ignore
def
_
(
decode_query
:
torch
.
Tensor
,
key_cache
:
torch
.
Tensor
,
value_cache
:
torch
.
Tensor
,
cache_seqlens
:
Optional
[
torch
.
Tensor
]
=
None
,
block_table
:
Optional
[
torch
.
Tensor
]
=
None
,
softmax_scale
:
Optional
[
float
]
=
None
,
causal
:
bool
=
False
,
alibi_slopes
:
Optional
[
torch
.
Tensor
]
=
None
,
softcap
:
float
=
0.0
,
)
->
torch
.
Tensor
:
return
torch
.
empty_like
(
decode_query
)
class
FlashAttentionBackend
(
AttentionBackend
):
...
...
@@ -517,7 +618,7 @@ class FlashAttentionImpl(AttentionImpl):
# normal attention
# When block_tables are not filled, it means q and k are the
# prompt, and they have the same length.
out
=
flash_attn_varlen_func
(
out
=
torch
.
ops
.
vllm
.
flash_attn_varlen_func
(
q
=
query
,
k
=
key
,
v
=
value
,
...
...
@@ -537,7 +638,8 @@ class FlashAttentionImpl(AttentionImpl):
# prefix-enabled attention
assert
prefill_meta
.
seq_lens
is
not
None
max_seq_len
=
max
(
prefill_meta
.
seq_lens
)
output
[:
num_prefill_tokens
]
=
flash_attn_varlen_func
(
output
[:
num_prefill_tokens
]
=
torch
.
ops
.
vllm
.
flash_attn_varlen_func
(
# noqa
q
=
query
,
k
=
key_cache
,
v
=
value_cache
,
...
...
@@ -554,7 +656,8 @@ class FlashAttentionImpl(AttentionImpl):
if
decode_meta
:
=
attn_metadata
.
decode_metadata
:
# Decoding run.
output
[
num_prefill_tokens
:]
=
flash_attn_with_kvcache
(
output
[
num_prefill_tokens
:]
=
torch
.
ops
.
vllm
.
flash_attn_with_kvcache
(
decode_query
.
unsqueeze
(
1
),
key_cache
,
value_cache
,
...
...
vllm/attention/backends/flashinfer.py
View file @
54bd9a03
...
...
@@ -4,9 +4,9 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Type
try
:
from
flashinfer
import
BatchDecodeWithPagedKVCacheWrapper
from
flashinfer.prefill
import
BatchPrefillWithPagedKVCacheWrapper
from
vllm_flash_attn
import
flash_attn_varlen_func
import
vllm.attention.backends.flash_attn
# noqa
except
ImportError
:
flash_attn_varlen_func
=
None
BatchDecodeWithPagedKVCacheWrapper
=
None
BatchPrefillWithPagedKVCacheWrapper
=
None
...
...
@@ -520,7 +520,7 @@ class FlashInferImpl(AttentionImpl):
# This happens when vllm runs the profiling to
# determine the number of blocks.
if
kv_cache
is
None
:
output
=
flash_attn_varlen_func
(
output
=
torch
.
ops
.
vllm
.
flash_attn_varlen_func
(
q
=
query
,
k
=
key
,
v
=
value
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment