Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
36729bac
"vscode:/vscode.git/clone" did not exist on "c4774eb8418864390341d35103aa747fc411b59c"
Unverified
Commit
36729bac
authored
Apr 13, 2024
by
SangBin Cho
Committed by
GitHub
Apr 12, 2024
Browse files
[Test] Test multiple attn backend for chunked prefill. (#4023)
parent
7fd3949a
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
13 additions
and
23 deletions
+13
-23
.buildkite/test-pipeline.yaml
.buildkite/test-pipeline.yaml
+7
-1
tests/basic_correctness/test_basic_correctness.py
tests/basic_correctness/test_basic_correctness.py
+0
-6
tests/basic_correctness/test_chunked_prefill.py
tests/basic_correctness/test_chunked_prefill.py
+0
-4
vllm/attention/backends/rocm_flash_attn.py
vllm/attention/backends/rocm_flash_attn.py
+6
-12
No files found.
.buildkite/test-pipeline.yaml
View file @
36729bac
...
@@ -12,7 +12,13 @@ steps:
...
@@ -12,7 +12,13 @@ steps:
command
:
pytest -v -s async_engine
command
:
pytest -v -s async_engine
-
label
:
Basic Correctness Test
-
label
:
Basic Correctness Test
command
:
pytest -v -s basic_correctness
commands
:
-
VLLM_ATTENTION_BACKEND=XFORMERS pytest -v -s basic_correctness/test_basic_correctness.py
-
VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s basic_correctness/test_basic_correctness.py
-
VLLM_ATTENTION_BACKEND=ROCM_FLASH pytest -v -s basic_correctness/test_basic_correctness.py
-
VLLM_ATTENTION_BACKEND=XFORMERS pytest -v -s basic_correctness/test_chunked_prefill.py
-
VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s basic_correctness/test_chunked_prefill.py
-
VLLM_ATTENTION_BACKEND=ROCM_FLASH pytest -v -s basic_correctness/test_chunked_prefill.py
-
label
:
Core Test
-
label
:
Core Test
command
:
pytest -v -s core
command
:
pytest -v -s core
...
...
tests/basic_correctness/test_basic_correctness.py
View file @
36729bac
...
@@ -4,8 +4,6 @@ Run `pytest tests/basic_correctness/test_basic_correctness.py`.
...
@@ -4,8 +4,6 @@ Run `pytest tests/basic_correctness/test_basic_correctness.py`.
"""
"""
import
pytest
import
pytest
from
vllm.attention.selector
import
VLLM_ATTENTION_BACKEND
MODELS
=
[
MODELS
=
[
"facebook/opt-125m"
,
"facebook/opt-125m"
,
"meta-llama/Llama-2-7b-hf"
,
"meta-llama/Llama-2-7b-hf"
,
...
@@ -16,7 +14,6 @@ MODELS = [
...
@@ -16,7 +14,6 @@ MODELS = [
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"half"
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"half"
])
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
5
])
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
5
])
@
pytest
.
mark
.
parametrize
(
"enforce_eager"
,
[
False
,
True
])
@
pytest
.
mark
.
parametrize
(
"enforce_eager"
,
[
False
,
True
])
@
pytest
.
mark
.
parametrize
(
"attn_backend"
,
[
"XFORMERS"
,
"FLASH_ATTN"
])
def
test_models
(
def
test_models
(
hf_runner
,
hf_runner
,
vllm_runner
,
vllm_runner
,
...
@@ -25,10 +22,7 @@ def test_models(
...
@@ -25,10 +22,7 @@ def test_models(
dtype
:
str
,
dtype
:
str
,
max_tokens
:
int
,
max_tokens
:
int
,
enforce_eager
:
bool
,
enforce_eager
:
bool
,
attn_backend
:
str
,
monkeypatch
,
)
->
None
:
)
->
None
:
monkeypatch
.
setenv
(
VLLM_ATTENTION_BACKEND
,
attn_backend
)
hf_model
=
hf_runner
(
model
,
dtype
=
dtype
)
hf_model
=
hf_runner
(
model
,
dtype
=
dtype
)
hf_outputs
=
hf_model
.
generate_greedy
(
example_prompts
,
max_tokens
)
hf_outputs
=
hf_model
.
generate_greedy
(
example_prompts
,
max_tokens
)
del
hf_model
del
hf_model
...
...
tests/basic_correctness/test_chunked_prefill.py
View file @
36729bac
...
@@ -33,10 +33,6 @@ def test_models(
...
@@ -33,10 +33,6 @@ def test_models(
enforce_eager
:
bool
,
enforce_eager
:
bool
,
tensor_parallel_size
:
int
,
tensor_parallel_size
:
int
,
)
->
None
:
)
->
None
:
if
(
tensor_parallel_size
==
2
and
chunked_prefill_token_size
!=
16
and
not
enforce_eager
):
pytest
.
skip
(
f
"Skip
{
chunked_prefill_token_size
=
}
and
{
enforce_eager
=
}
"
"for high TP to save testing time."
)
max_num_seqs
=
min
(
chunked_prefill_token_size
,
256
)
max_num_seqs
=
min
(
chunked_prefill_token_size
,
256
)
enable_chunked_prefill
=
False
enable_chunked_prefill
=
False
max_num_batched_tokens
=
None
max_num_batched_tokens
=
None
...
...
vllm/attention/backends/rocm_flash_attn.py
View file @
36729bac
...
@@ -162,7 +162,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
...
@@ -162,7 +162,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
# AMD Radeon 7900 series (gfx1100) currently does not support
# AMD Radeon 7900 series (gfx1100) currently does not support
# xFormers nor FlashAttention. As a temporary workaround, we use
# xFormers nor FlashAttention. As a temporary workaround, we use
# naive PyTorch implementation of attention.
# naive PyTorch implementation of attention.
self
.
attn_fuc
=
_naive_attention
()
self
.
attn_fuc
=
_naive_attention
logger
.
debug
(
"Using naive attention in ROCmBackend"
)
logger
.
debug
(
"Using naive attention in ROCmBackend"
)
elif
self
.
use_triton_flash_attn
:
elif
self
.
use_triton_flash_attn
:
from
vllm.attention.ops.triton_flash_attention
import
(
# noqa: F401
from
vllm.attention.ops.triton_flash_attention
import
(
# noqa: F401
...
@@ -334,26 +334,21 @@ def _naive_attention(
...
@@ -334,26 +334,21 @@ def _naive_attention(
prompt_lens
:
List
[
int
],
prompt_lens
:
List
[
int
],
scale
:
float
,
scale
:
float
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
num_tokens
=
query
.
shape
[
0
]
output
=
torch
.
empty_like
(
query
)
output
=
torch
.
empty_like
(
query
)
start
=
0
start
=
0
for
_
,
prompt_len
in
enumerate
(
prompt_lens
):
for
_
,
prompt_len
in
enumerate
(
prompt_lens
):
end
=
start
+
prompt_len
end
=
start
+
prompt_len
out
=
_naive_masked_attention
(
out
=
_naive_masked_attention
(
query
[
None
,
start
:
end
],
query
[
start
:
end
],
key
[
None
,
start
:
end
],
key
[
start
:
end
],
value
[
None
,
start
:
end
],
value
[
start
:
end
],
scale
,
scale
,
)
)
# TODO(woosuk): Unnecessary copy. Optimize.
# TODO(woosuk): Unnecessary copy. Optimize.
output
[
start
:
end
].
copy_
(
out
)
output
[
start
:
end
].
copy_
(
out
)
start
+=
prompt_len
start
+=
prompt_len
# Using view got RuntimeError: view size is not compatible
return
output
# with input tensor's size and stride (at least one
# dimension spans across two contiguous subspaces).
# Use reshape instead.
return
output
.
reshape
(
num_tokens
,
-
1
)
def
_naive_masked_attention
(
def
_naive_masked_attention
(
...
@@ -362,14 +357,13 @@ def _naive_masked_attention(
...
@@ -362,14 +357,13 @@ def _naive_masked_attention(
value
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
scale
:
float
,
scale
:
float
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
seq_len
,
_
,
_
=
query
.
shape
seq_len
,
head_size
,
head_dim
=
query
.
shape
attn_mask
=
torch
.
triu
(
torch
.
ones
(
seq_len
,
attn_mask
=
torch
.
triu
(
torch
.
ones
(
seq_len
,
seq_len
,
seq_len
,
dtype
=
query
.
dtype
,
dtype
=
query
.
dtype
,
device
=
query
.
device
),
device
=
query
.
device
),
diagonal
=
1
)
diagonal
=
1
)
attn_mask
=
attn_mask
*
torch
.
finfo
(
query
.
dtype
).
min
attn_mask
=
attn_mask
*
torch
.
finfo
(
query
.
dtype
).
min
attn_weights
=
scale
*
torch
.
einsum
(
"qhd,khd->hqk"
,
query
,
key
).
float
()
attn_weights
=
scale
*
torch
.
einsum
(
"qhd,khd->hqk"
,
query
,
key
).
float
()
attn_weights
=
attn_weights
+
attn_mask
.
float
()
attn_weights
=
attn_weights
+
attn_mask
.
float
()
attn_weights
=
torch
.
softmax
(
attn_weights
,
dim
=-
1
).
to
(
value
.
dtype
)
attn_weights
=
torch
.
softmax
(
attn_weights
,
dim
=-
1
).
to
(
value
.
dtype
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment