Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
fc7db442
Commit
fc7db442
authored
Nov 19, 2024
by
zhuwenwen
Browse files
update fa interface tests
parent
aa389394
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
158 additions
and
122 deletions
+158
-122
tests/entrypoints/openai/test_oot_registration.py
tests/entrypoints/openai/test_oot_registration.py
+28
-12
tests/kernels/test_flash_attn.py
tests/kernels/test_flash_attn.py
+130
-110
No files found.
tests/entrypoints/openai/test_oot_registration.py
View file @
fc7db442
from
...utils
import
VLLM_PATH
,
RemoteOpenAIServer
import
vllm.envs
as
envs
chatml_jinja_path
=
VLLM_PATH
/
"examples/template_chatml.jinja"
assert
chatml_jinja_path
.
exists
()
...
...
@@ -6,11 +7,25 @@ assert chatml_jinja_path.exists()
def
run_and_test_dummy_opt_api_server
(
model
,
tp
=
1
):
# the model is registered through the plugin
if
envs
.
VLLM_USE_TRITON_FLASH_ATTN
:
server_args
=
[
"--gpu-memory-utilization"
,
"0.10"
,
"--dtype"
,
"float16"
,
# "float32",
"float32"
,
"--chat-template"
,
str
(
chatml_jinja_path
),
"--load-format"
,
"dummy"
,
"-tp"
,
f
"
{
tp
}
"
,
]
else
:
server_args
=
[
"--gpu-memory-utilization"
,
"0.10"
,
"--dtype"
,
"float16"
,
"--chat-template"
,
str
(
chatml_jinja_path
),
"--load-format"
,
...
...
@@ -33,10 +48,11 @@ def run_and_test_dummy_opt_api_server(model, tp=1):
)
generated_text
=
completion
.
choices
[
0
].
message
.
content
assert
generated_text
is
not
None
# make sure only the first token is generated
vim
# make sure only the first token is generated
rest
=
generated_text
.
replace
(
"<s>"
,
""
)
assert
rest
==
""
def
test_oot_registration_for_api_server
(
dummy_opt_path
:
str
):
dummy_opt_path
=
"facebook/opt-125m"
run_and_test_dummy_opt_api_server
(
dummy_opt_path
)
tests/kernels/test_flash_attn.py
View file @
fc7db442
...
...
@@ -3,7 +3,11 @@ from typing import List, Optional, Tuple
import
pytest
import
torch
import
vllm.attention.backends.flash_attn
# noqa: F401
from
vllm.utils
import
is_hip
if
is_hip
():
import
flash_attn
else
:
import
vllm.attention.backends.flash_attn
# noqa: F401
from
tests.kernels.utils
import
opcheck
from
vllm.utils
import
seed_everything
...
...
@@ -70,16 +74,16 @@ def ref_paged_attn(
return
torch
.
cat
(
outputs
,
dim
=
0
)
@
pytest
.
mark
.
parametrize
(
"kv_lens"
,
[[
1328
,
18
,
463
],
[
1
,
54
,
293
,
70
]])
@
pytest
.
mark
.
parametrize
(
"num_heads"
,
NUM_HEADS
)
@
pytest
.
mark
.
parametrize
(
"head_size"
,
HEAD_SIZES
)
@
pytest
.
mark
.
parametrize
(
"block_size"
,
BLOCK_SIZES
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPES
)
@
pytest
.
mark
.
parametrize
(
"soft_cap"
,
[
None
,
10.0
,
50.0
])
@
pytest
.
mark
.
parametrize
(
"num_blocks"
,
NUM_BLOCKS
)
@
torch
.
inference_mode
()
def
test_flash_attn_with_paged_kv
(
if
not
is_hip
():
@
pytest
.
mark
.
parametrize
(
"kv_lens"
,
[[
1328
,
18
,
463
],
[
1
,
54
,
293
,
70
]])
@
pytest
.
mark
.
parametrize
(
"num_heads"
,
NUM_HEADS
)
@
pytest
.
mark
.
parametrize
(
"head_size"
,
HEAD_SIZES
)
@
pytest
.
mark
.
parametrize
(
"block_size"
,
BLOCK_SIZES
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPES
)
@
pytest
.
mark
.
parametrize
(
"soft_cap"
,
[
None
,
10.0
,
50.0
])
@
pytest
.
mark
.
parametrize
(
"num_blocks"
,
NUM_BLOCKS
)
@
torch
.
inference_mode
()
def
test_flash_attn_with_paged_kv
(
kv_lens
:
List
[
int
],
num_heads
:
Tuple
[
int
,
int
],
head_size
:
int
,
...
...
@@ -87,7 +91,7 @@ def test_flash_attn_with_paged_kv(
block_size
:
int
,
soft_cap
:
Optional
[
float
],
num_blocks
:
int
,
)
->
None
:
)
->
None
:
torch
.
set_default_device
(
"cuda"
)
seed_everything
(
0
)
num_seqs
=
len
(
kv_lens
)
...
...
@@ -212,7 +216,22 @@ def test_varlen_with_paged_kv(
num_blocks
,
(
num_seqs
,
max_num_blocks_per_seq
),
dtype
=
torch
.
int32
)
if
is_hip
():
output
=
flash_attn
.
flash_attn_varlen_func
(
q
=
query
,
k
=
key_cache
,
v
=
value_cache
,
cu_seqlens_q
=
cu_query_lens
,
cu_seqlens_k
=
cu_kv_lens
,
max_seqlen_q
=
max_query_len
,
max_seqlen_k
=
max_kv_len
,
softmax_scale
=
scale
,
causal
=
True
,
window_size
=
window_size
,
block_table
=
block_tables
,
softcap
=
soft_cap
if
soft_cap
is
not
None
else
0
,
)
else
:
output
=
torch
.
ops
.
vllm
.
flash_attn_varlen_func
(
q
=
query
,
k
=
key_cache
,
...
...
@@ -233,6 +252,7 @@ def test_varlen_with_paged_kv(
else
:
test_utils
=
[
"test_faketensor"
]
if
not
is_hip
():
opcheck
(
torch
.
ops
.
vllm
.
flash_attn_varlen_func
,
args
=
tuple
(),
kwargs
=
dict
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment