Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
622f8abf
Unverified
Commit
622f8abf
authored
Aug 30, 2024
by
Pavani Majety
Committed by
GitHub
Aug 30, 2024
Browse files
[Bugfix] bugfix and add model test for flashinfer fp8 kv cache. (#8013)
parent
1248e850
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
109 additions
and
5 deletions
+109
-5
tests/models/test_fp8kv_flashinfer.py
tests/models/test_fp8kv_flashinfer.py
+96
-0
vllm/attention/backends/flashinfer.py
vllm/attention/backends/flashinfer.py
+13
-5
No files found.
tests/models/test_fp8kv_flashinfer.py
0 → 100644
View file @
622f8abf
# flake8: noqa
"""Tests fp8 models against ground truth generation
This verifies the flashinfer backend with fp8
quantization and fp8 KV Cache without scaling
factors Note: these tests will only pass on H100 GPU.
"""
import
os
from
typing
import
List
import
pytest
from
transformers
import
AutoTokenizer
from
tests.quantization.utils
import
is_quant_method_supported
from
vllm
import
LLM
,
SamplingParams
os
.
environ
[
"TOKENIZERS_PARALLELISM"
]
=
"true"
MAX_MODEL_LEN
=
1024
MODELS
=
[
"nm-testing/Meta-Llama-3-8B-Instruct-FP8"
,
]
EXPECTED_STRS_MAP
=
{
"nm-testing/Meta-Llama-3-8B-Instruct-FP8"
:
{
"auto"
:
[
'LLaMA is a high-throughput and memory-efficient inference and serving engine for Large Language Models ('
,
'Here are the major milestones in the development of artificial intelligence (AI) from 1950 to '
,
'Artificial intelligence (AI) and human intelligence (HI) differ significantly in how they process information.'
,
'A neural network is a complex system modeled after the human brain, consisting of interconnected nodes or "ne'
,
'In the sterile, metallic halls of the robotics lab, a peculiar phenomenon occurred. Zeta-5'
,
'The COVID-19 pandemic has had a profound impact on global economic structures and future business models. The'
,
'The Mona Lisa, painted by Leonardo da Vinci in the early 16th century, is one of'
,
'Here are the translations:
\n\n
**Japanese:** (Haya aki no tori, mushi o'
,
],
"fp8"
:
[
'LLM (Large Language Model) is a type of artificial intelligence (AI) model that is trained'
,
'Here are the major milestones in the development of artificial intelligence (AI) from 1950 to '
,
'Artificial intelligence (AI) and human intelligence (HI) differ significantly in how they process information.'
,
'A neural network is a complex system modeled after the human brain, composed of interconnected nodes or "ne'
,
'Zeta-5, a highly advanced robot designed for menial labor, whirred and beep'
,
'The COVID-19 pandemic has had a profound impact on global economic structures and future business models. Here'
,
'The Mona Lisa, painted by Leonardo da Vinci in the early 16th century, is one of'
,
'Here are the translations:
\n\n
**Japanese:** (Haya aki no tori, guri o'
,
]
}
}
# This test compares against golden strings for exact match since
# there is no baseline implementation to compare against
# and is unstable w.r.t specifics of the fp8 implementation or
# the hardware being run on.
# No assert to prevent it from breaking the build
@
pytest
.
mark
.
skipif
(
not
is_quant_method_supported
(
"fp8"
),
reason
=
"fp8 is not supported on this GPU type."
)
@
pytest
.
mark
.
parametrize
(
"model_name"
,
MODELS
)
@
pytest
.
mark
.
parametrize
(
"kv_cache_dtype"
,
[
"auto"
,
"fp8"
])
@
pytest
.
mark
.
parametrize
(
"backend"
,
[
"XFORMERS"
,
"FLASHINFER"
])
def
test_models
(
example_prompts
,
model_name
,
kv_cache_dtype
,
backend
)
->
None
:
# Note that the golden strings may not work for FLASHINFER Backend.
# The intention is to test the path
os
.
environ
[
"VLLM_ATTENTION_BACKEND"
]
=
backend
model
=
LLM
(
model
=
model_name
,
max_model_len
=
MAX_MODEL_LEN
,
trust_remote_code
=
True
,
quantization
=
"fp8"
,
kv_cache_dtype
=
kv_cache_dtype
)
tokenizer
=
AutoTokenizer
.
from_pretrained
(
model_name
)
formatted_prompts
=
[
tokenizer
.
apply_chat_template
([{
"role"
:
"user"
,
"content"
:
prompt
}],
tokenize
=
False
,
add_generation_prompt
=
True
)
for
prompt
in
example_prompts
]
params
=
SamplingParams
(
max_tokens
=
20
,
temperature
=
0
)
generations
:
List
[
str
]
=
[]
# Note: these need to be run 1 at a time due to numerical precision,
# since the expected strs were generated this way.
for
prompt
in
formatted_prompts
:
outputs
=
model
.
generate
(
prompt
,
params
)
generations
.
append
(
outputs
[
0
].
outputs
[
0
].
text
)
del
model
print
(
f
"Testing:
{
model_name
}
with kv_cache_dtype:
{
kv_cache_dtype
}
"
)
expected_strs
=
EXPECTED_STRS_MAP
[
model_name
][
kv_cache_dtype
]
for
i
in
range
(
len
(
example_prompts
)):
generated_str
=
generations
[
i
]
expected_str
=
expected_strs
[
i
]
print
(
f
"generated_str
\n
:
{
generated_str
}
"
)
print
(
f
"expected_str
\n
:
{
expected_str
}
"
)
vllm/attention/backends/flashinfer.py
View file @
622f8abf
...
@@ -186,9 +186,13 @@ class FlashInferState(AttentionState):
...
@@ -186,9 +186,13 @@ class FlashInferState(AttentionState):
self
.
_graph_decode_workspace_buffer
,
_indptr_buffer
,
self
.
_graph_decode_workspace_buffer
,
_indptr_buffer
,
self
.
_graph_indices_buffer
,
_last_page_len_buffer
,
"NHD"
,
self
.
_graph_indices_buffer
,
_last_page_len_buffer
,
"NHD"
,
use_tensor_cores
)
use_tensor_cores
)
if
self
.
runner
.
kv_cache_dtype
.
startswith
(
"fp8"
):
kv_cache_dtype
=
FlashInferBackend
.
get_fp8_dtype_for_flashinfer
(
self
.
runner
.
kv_cache_dtype
)
else
:
kv_cache_dtype
=
get_kv_cache_torch_dtype
(
self
.
runner
.
kv_cache_dtype
,
self
.
runner
.
model_config
.
dtype
)
kv_cache_dtype
=
FlashInferBackend
.
get_fp8_dtype_for_flashinfer
(
self
.
runner
.
kv_cache_dtype
)
paged_kv_indptr_tensor_host
=
torch
.
arange
(
0
,
paged_kv_indptr_tensor_host
=
torch
.
arange
(
0
,
batch_size
+
1
,
batch_size
+
1
,
dtype
=
torch
.
int32
)
dtype
=
torch
.
int32
)
...
@@ -349,7 +353,7 @@ class FlashInferMetadata(AttentionMetadata):
...
@@ -349,7 +353,7 @@ class FlashInferMetadata(AttentionMetadata):
self
.
page_size
,
self
.
page_size
,
# Disable flashinfer's pos encoding and use vllm's rope.
# Disable flashinfer's pos encoding and use vllm's rope.
pos_encoding_mode
=
"NONE"
,
pos_encoding_mode
=
"NONE"
,
)
data_type
=
self
.
data_type
)
def
asdict_zerocopy
(
self
,
def
asdict_zerocopy
(
self
,
skip_fields
:
Optional
[
Set
[
str
]]
=
None
skip_fields
:
Optional
[
Set
[
str
]]
=
None
...
@@ -586,8 +590,12 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
...
@@ -586,8 +590,12 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
paged_kv_indptr_tensor
=
None
paged_kv_indptr_tensor
=
None
paged_kv_last_page_len_tensor
=
None
paged_kv_last_page_len_tensor
=
None
kv_cache_dtype
=
get_kv_cache_torch_dtype
(
if
self
.
runner
.
kv_cache_dtype
.
startswith
(
"fp8"
):
self
.
runner
.
kv_cache_dtype
,
self
.
runner
.
model_config
.
dtype
)
kv_cache_dtype
=
FlashInferBackend
.
get_fp8_dtype_for_flashinfer
(
self
.
runner
.
kv_cache_dtype
)
else
:
kv_cache_dtype
=
get_kv_cache_torch_dtype
(
self
.
runner
.
kv_cache_dtype
,
self
.
runner
.
model_config
.
dtype
)
return
FlashInferMetadata
(
return
FlashInferMetadata
(
num_prefills
=
self
.
num_prefills
,
num_prefills
=
self
.
num_prefills
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment