Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
e7c1b7f3
Commit
e7c1b7f3
authored
Sep 06, 2024
by
zhuwenwen
Browse files
Merge branch 'v0.5.4-dtk24.04.1'
parents
7462218e
04c62b93
Changes
442
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
3125 additions
and
270 deletions
+3125
-270
tests/entrypoints/openai/test_serving_chat.py
tests/entrypoints/openai/test_serving_chat.py
+44
-7
tests/entrypoints/openai/test_tokenization.py
tests/entrypoints/openai/test_tokenization.py
+152
-0
tests/entrypoints/openai/test_vision.py
tests/entrypoints/openai/test_vision.py
+14
-38
tests/kernels/quant_utils.py
tests/kernels/quant_utils.py
+72
-0
tests/kernels/test_attention.py
tests/kernels/test_attention.py
+25
-20
tests/kernels/test_attention_selector.py
tests/kernels/test_attention_selector.py
+15
-10
tests/kernels/test_blocksparse_attention.py
tests/kernels/test_blocksparse_attention.py
+16
-14
tests/kernels/test_cache.py
tests/kernels/test_cache.py
+60
-26
tests/kernels/test_cutlass.py
tests/kernels/test_cutlass.py
+102
-46
tests/kernels/test_encoder_decoder_attn.py
tests/kernels/test_encoder_decoder_attn.py
+953
-0
tests/kernels/test_flash_attn.py
tests/kernels/test_flash_attn.py
+15
-8
tests/kernels/test_flashinfer.py
tests/kernels/test_flashinfer.py
+248
-0
tests/kernels/test_fp8_quant.py
tests/kernels/test_fp8_quant.py
+87
-0
tests/kernels/test_int8_quant.py
tests/kernels/test_int8_quant.py
+10
-18
tests/kernels/test_marlin_gemm.py
tests/kernels/test_marlin_gemm.py
+316
-35
tests/kernels/test_moe.py
tests/kernels/test_moe.py
+3
-3
tests/kernels/test_pos_encoding.py
tests/kernels/test_pos_encoding.py
+13
-17
tests/kernels/test_sampler.py
tests/kernels/test_sampler.py
+33
-20
tests/kernels/utils.py
tests/kernels/utils.py
+920
-0
tests/lora/conftest.py
tests/lora/conftest.py
+27
-8
No files found.
Too many changes to show.
To preserve performance only
442 of 442+
files are displayed.
Plain diff
Email patch
tests/entrypoints/openai/test_serving_chat.py
View file @
e7c1b7f3
import
asyncio
from
contextlib
import
suppress
from
dataclasses
import
dataclass
from
unittest.mock
import
MagicMock
import
pytest
from
vllm.engine.async_llm_engine
import
AsyncLLMEngine
from
vllm.entrypoints.openai.protocol
import
ChatCompletionRequest
from
vllm.entrypoints.openai.serving_chat
import
OpenAIServingChat
from
vllm.transformers_utils.tokenizer
import
get_tokenizer
MODEL_NAME
=
"openai-community/gpt2"
CHAT_TEMPLATE
=
"Dummy chat template for testing {}"
pytestmark
=
pytest
.
mark
.
openai
@
dataclass
class
MockModelConfig
:
...
...
@@ -36,11 +37,47 @@ async def _async_serving_chat_init():
model_config
,
served_model_names
=
[
MODEL_NAME
],
response_role
=
"assistant"
,
chat_template
=
CHAT_TEMPLATE
)
chat_template
=
CHAT_TEMPLATE
,
lora_modules
=
None
,
prompt_adapters
=
None
,
request_logger
=
None
)
return
serving_completion
def
test_async_serving_chat_init
():
serving_completion
=
asyncio
.
run
(
_async_serving_chat_init
())
assert
serving_completion
.
tokenizer
is
not
None
assert
serving_completion
.
tokenizer
.
chat_template
==
CHAT_TEMPLATE
assert
serving_completion
.
chat_template
==
CHAT_TEMPLATE
def test_serving_chat_should_set_correct_max_tokens():
    """Check the ``max_tokens`` value forwarded to the engine.

    The engine is a ``MagicMock``, so ``create_chat_completion`` is not
    expected to finish cleanly; any exception it raises is suppressed and
    only the arguments recorded on ``generate`` are inspected.
    """
    engine = MagicMock(spec=AsyncLLMEngine)
    engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)

    serving_chat = OpenAIServingChat(
        engine,
        MockModelConfig(),
        served_model_names=[MODEL_NAME],
        response_role="assistant",
        chat_template=CHAT_TEMPLATE,
        lora_modules=None,
        prompt_adapters=None,
        request_logger=None,
    )

    user_message = {"role": "user", "content": "what is 1+1?"}
    req = ChatCompletionRequest(
        model=MODEL_NAME,
        messages=[user_message],
        guided_decoding_backend="outlines",
    )

    def forwarded_max_tokens(request):
        # Best effort: the mocked engine cannot produce a real completion.
        with suppress(Exception):
            asyncio.run(serving_chat.create_chat_completion(request))
        # AsyncLLMEngine.generate(inputs, sampling_params, ...)
        sampling_params = engine.generate.call_args.args[1]
        return sampling_params.max_tokens

    # No explicit max_tokens on the request.
    # NOTE(review): 93 presumably derives from MockModelConfig's context
    # length minus the prompt length -- confirm against MockModelConfig.
    assert forwarded_max_tokens(req) == 93

    # An explicit max_tokens must be passed through unchanged.
    req.max_tokens = 10
    assert forwarded_max_tokens(req) == 10
tests/entrypoints/openai/test_tokenization.py
0 → 100644
View file @
e7c1b7f3
import
openai
# use the official client for correctness check
import
pytest
import
requests
from
vllm.transformers_utils.tokenizer
import
get_tokenizer
from
...utils
import
RemoteOpenAIServer
from
.test_completion
import
zephyr_lora_added_tokens_files
# noqa: F401
from
.test_completion
import
zephyr_lora_files
# noqa: F401
# any model with a chat template should work here
MODEL_NAME
=
"HuggingFaceH4/zephyr-7b-beta"
@pytest.fixture(scope="module")
def server(zephyr_lora_added_tokens_files: str):  # noqa: F811
    """Start a ``RemoteOpenAIServer`` for ``MODEL_NAME`` and yield it.

    The server is launched with one LoRA module registered under the name
    ``zephyr-lora2`` and is torn down when the context manager exits.
    """
    # use half precision for speed and memory savings in CI environment
    runtime_args = [
        "--dtype",
        "bfloat16",
        "--max-model-len",
        "8192",
        "--enforce-eager",
        "--max-num-seqs",
        "128",
    ]
    # lora config
    lora_args = [
        "--enable-lora",
        "--lora-modules",
        f"zephyr-lora2={zephyr_lora_added_tokens_files}",
        "--max-lora-rank",
        "64",
    ]
    args = runtime_args + lora_args

    with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
        yield remote_server
@pytest.fixture(scope="module")
def tokenizer_name(model_name: str,
                   zephyr_lora_added_tokens_files: str):  # noqa: F811
    """Return the tokenizer identifier to load for ``model_name``.

    For the ``zephyr-lora2`` model the value of the
    ``zephyr_lora_added_tokens_files`` fixture is returned; any other model
    name is returned unchanged.
    """
    if model_name == "zephyr-lora2":
        return zephyr_lora_added_tokens_files
    return model_name
@pytest.fixture(scope="module")
def client(server):
    """Return the async client obtained from the ``server`` fixture."""
    async_client = server.get_async_client()
    return async_client
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "model_name,tokenizer_name",
    [(MODEL_NAME, MODEL_NAME), ("zephyr-lora2", "zephyr-lora2")],
    indirect=["tokenizer_name"],
)
async def test_tokenize_completions(client: openai.AsyncOpenAI,
                                    model_name: str, tokenizer_name: str):
    """Compare ``/tokenize`` output for a plain prompt with a local tokenizer.

    The request is issued with ``add_special_tokens`` both off and on, and
    the JSON response must equal the locally encoded tokens, their count and
    a ``max_model_len`` of 8192.
    """
    # Drop the last three characters of the client URL and any surrounding
    # slashes (presumably the "v1/" API suffix -- confirm).
    root_url = str(client.base_url)[:-3].strip("/")
    reference_tokenizer = get_tokenizer(tokenizer_name=tokenizer_name,
                                        tokenizer_mode="fast")
    prompt = "vllm1 This is a test prompt."

    for add_special in (False, True):
        tokens = reference_tokenizer.encode(prompt,
                                            add_special_tokens=add_special)

        payload = {
            "add_special_tokens": add_special,
            "model": model_name,
            "prompt": prompt,
        }
        response = requests.post(root_url + "/tokenize", json=payload)
        response.raise_for_status()

        expected_body = {
            "tokens": tokens,
            "count": len(tokens),
            "max_model_len": 8192,
        }
        assert response.json() == expected_body
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "model_name,tokenizer_name",
    [(MODEL_NAME, MODEL_NAME), ("zephyr-lora2", "zephyr-lora2")],
    indirect=["tokenizer_name"],
)
async def test_tokenize_chat(client: openai.AsyncOpenAI, model_name: str,
                             tokenizer_name: str):
    """Compare ``/tokenize`` output for a chat conversation with a local one.

    Every combination of ``add_generation_prompt`` and ``add_special_tokens``
    is exercised. The expected tokens come from applying the tokenizer's chat
    template locally and encoding the resulting prompt.
    """
    # Drop the last three characters of the client URL and any surrounding
    # slashes (presumably the "v1/" API suffix -- confirm).
    root_url = str(client.base_url)[:-3].strip("/")
    reference_tokenizer = get_tokenizer(tokenizer_name=tokenizer_name,
                                        tokenizer_mode="fast")

    turns = (
        ("user", "Hi there!"),
        ("assistant", "Nice to meet you!"),
        ("user", "Can I ask a question? vllm1"),
    )
    # Same iteration order as two nested loops: generation flag outermost.
    flag_pairs = [(gen, special) for gen in (False, True)
                  for special in (False, True)]

    for add_generation, add_special in flag_pairs:
        # A fresh message list is built for every request.
        conversation = [{
            "role": role,
            "content": content
        } for role, content in turns]

        prompt = reference_tokenizer.apply_chat_template(
            add_generation_prompt=add_generation,
            conversation=conversation,
            tokenize=False)
        tokens = reference_tokenizer.encode(prompt,
                                            add_special_tokens=add_special)

        payload = {
            "add_generation_prompt": add_generation,
            "add_special_tokens": add_special,
            "messages": conversation,
            "model": model_name,
        }
        response = requests.post(root_url + "/tokenize", json=payload)
        response.raise_for_status()

        expected_body = {
            "tokens": tokens,
            "count": len(tokens),
            "max_model_len": 8192,
        }
        assert response.json() == expected_body
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "model_name,tokenizer_name",
    [(MODEL_NAME, MODEL_NAME), ("zephyr-lora2", "zephyr-lora2")],
    indirect=["tokenizer_name"],
)
async def test_detokenize(client: openai.AsyncOpenAI, model_name: str,
                          tokenizer_name: str):
    """Check that ``/detokenize`` turns locally encoded tokens back to text.

    The prompt is encoded locally without special tokens, posted to the
    server, and the response must contain exactly the original prompt.
    """
    # Drop the last three characters of the client URL and any surrounding
    # slashes (presumably the "v1/" API suffix -- confirm).
    root_url = str(client.base_url)[:-3].strip("/")
    reference_tokenizer = get_tokenizer(tokenizer_name=tokenizer_name,
                                        tokenizer_mode="fast")

    prompt = "This is a test prompt. vllm1"
    tokens = reference_tokenizer.encode(prompt, add_special_tokens=False)

    print(f"CALLING {root_url} FOR {model_name}")
    payload = {"model": model_name, "tokens": tokens}
    response = requests.post(root_url + "/detokenize", json=payload)
    response.raise_for_status()

    expected_body = {"prompt": prompt}
    assert response.json() == expected_body
tests/entrypoints/
test_
openai_vision.py
→
tests/entrypoints/openai
/test
_vision.py
View file @
e7c1b7f3
from
pathlib
import
Path
from
typing
import
Dict
from
typing
import
Dict
,
List
import
openai
import
pytest
import
pytest_asyncio
import
ray
from
vllm.multimodal.utils
import
ImageFetchAiohttp
,
encode_image_base64
from
vllm.multimodal.utils
import
encode_image_base64
,
fetch_image
from
..utils
import
VLLM_PATH
,
RemoteOpenAIServer
from
..
.
utils
import
VLLM_PATH
,
RemoteOpenAIServer
MODEL_NAME
=
"llava-hf/llava-1.5-7b-hf"
LLAVA_CHAT_TEMPLATE
=
(
Path
(
__file__
).
parent
.
parent
.
parent
/
"examples/template_llava.jinja"
)
LLAVA_CHAT_TEMPLATE
=
VLLM_PATH
/
"examples/template_llava.jinja"
assert
LLAVA_CHAT_TEMPLATE
.
exists
()
# Test different image extensions (JPG/PNG) and formats (gray/RGB/RGBA)
TEST_IMAGE_URLS
=
[
"https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
,
...
...
@@ -22,37 +19,21 @@ TEST_IMAGE_URLS = [
"https://upload.wikimedia.org/wikipedia/commons/0/0b/RGBA_comp.png"
,
]
pytestmark
=
pytest
.
mark
.
openai
@
pytest
.
fixture
(
scope
=
"module"
)
def
ray_ctx
():
ray
.
init
(
runtime_env
=
{
"working_dir"
:
VLLM_PATH
})
yield
ray
.
shutdown
()
@
pytest
.
fixture
(
scope
=
"module"
)
def
server
():
return
RemoteOpenAIServer
([
"--model"
,
MODEL_NAME
,
args
=
[
"--dtype"
,
"bfloat16"
,
"--max-model-len"
,
"4096"
,
"--enforce-eager"
,
"--image-input-type"
,
"pixel_values"
,
"--image-token-id"
,
"32000"
,
"--image-input-shape"
,
"1,3,336,336"
,
"--image-feature-size"
,
"576"
,
"--chat-template"
,
str
(
LLAVA_CHAT_TEMPLATE
),
])
]
with
RemoteOpenAIServer
(
MODEL_NAME
,
args
)
as
remote_server
:
yield
remote_server
@
pytest
.
fixture
(
scope
=
"module"
)
...
...
@@ -60,11 +41,10 @@ def client(server):
return
server
.
get_async_client
()
@
pytest
_asyncio
.
fixture
(
scope
=
"session"
)
async
def
base64_encoded_image
()
->
Dict
[
str
,
str
]:
@
pytest
.
fixture
(
scope
=
"session"
)
def
base64_encoded_image
()
->
Dict
[
str
,
str
]:
return
{
image_url
:
encode_image_base64
(
await
ImageFetchAiohttp
.
fetch_image
(
image_url
))
image_url
:
encode_image_base64
(
fetch_image
(
image_url
))
for
image_url
in
TEST_IMAGE_URLS
}
...
...
@@ -216,7 +196,7 @@ async def test_chat_streaming_image(client: openai.AsyncOpenAI,
temperature
=
0.0
,
stream
=
True
,
)
chunks
=
[]
chunks
:
List
[
str
]
=
[]
finish_reason_count
=
0
async
for
chunk
in
stream
:
delta
=
chunk
.
choices
[
0
].
delta
...
...
@@ -279,7 +259,3 @@ async def test_multi_image_input(client: openai.AsyncOpenAI, model_name: str,
)
completion
=
completion
.
choices
[
0
].
text
assert
completion
is
not
None
and
len
(
completion
)
>=
0
if
__name__
==
"__main__"
:
pytest
.
main
([
__file__
])
tests/kernels/quant_utils.py
0 → 100644
View file @
e7c1b7f3
from
typing
import
Optional
,
Tuple
,
Union
import
torch
def as_float32_tensor(
        x: Union[float, torch.Tensor],
        device: Union[str, torch.device] = 'cuda') -> torch.Tensor:
    """Return ``x`` as a float32 tensor on ``device``.

    Args:
        x: A Python scalar or an existing tensor.
        device: Device the result is placed on. Defaults to ``'cuda'``,
            which is what every call without this argument gets.

    Returns:
        The value converted with ``torch.as_tensor`` to ``torch.float32``.
    """
    return torch.as_tensor(x, dtype=torch.float32, device=device)
def ref_dynamic_per_token_quant(x: torch.Tensor,
                                quant_dtype: torch.dtype,
                                scale_ub: Optional[torch.Tensor] = None) \
        -> Tuple[torch.Tensor, torch.Tensor]:
    """Reference dynamic quantization with one scale per row of ``x``.

    Args:
        x: Input tensor. The scale is taken from the absolute maximum along
            the last dimension.
            NOTE(review): the ``[:, None]`` indexing suggests ``x`` is 2-D
            (tokens, hidden) -- confirm against callers.
        quant_dtype: ``torch.int8`` or ``torch.float8_e4m3fn``.
        scale_ub: Optional upper bound applied to the per-row maximum before
            the scale is computed. Only accepted for fp8.

    Returns:
        ``(quantized, scales)``: ``quantized`` has dtype ``quant_dtype`` and
        ``scales`` is float32 with a trailing singleton dimension.

    Raises:
        AssertionError: If ``quant_dtype`` is unsupported, or ``scale_ub``
            is given for a dtype other than fp8.
    """
    assert quant_dtype in [torch.int8, torch.float8_e4m3fn]
    if scale_ub is not None:
        assert quant_dtype == torch.float8_e4m3fn

    qtype_traits = torch.iinfo(quant_dtype) if quant_dtype == torch.int8 \
        else torch.finfo(quant_dtype)
    qtype_max = as_float32_tensor(qtype_traits.max)
    s_1 = as_float32_tensor(1.0)
    s_512 = as_float32_tensor(512.0)

    # For fp8, in order to match the cuda kernel output, we have to do exactly
    # the same operations as in the corresponding fp8 kernel to prevent
    # rounding errors.

    # Compute scales
    x_token_max, _ = x.abs().max(dim=-1)
    x_token_max = as_float32_tensor(x_token_max)
    if scale_ub is not None:
        x_token_max = x_token_max.clamp(max=scale_ub)
    scales = (x_token_max / qtype_max)[:, None]

    # Quant
    if quant_dtype == torch.int8:
        # int8: multiply by the reciprocal scale, round, then saturate.
        iscales = as_float32_tensor(s_1 / scales)
        torch_out = as_float32_tensor(x) * iscales
        torch_out = torch_out.round()
        torch_out = torch_out.clamp(qtype_traits.min,
                                    qtype_traits.max).to(quant_dtype)
    else:
        assert quant_dtype == torch.float8_e4m3fn
        # fp8: scales are floored at 1 / (fp8_max * 512) before dividing.
        min_scaling_factor = s_1 / (qtype_max * s_512)
        scales = scales.clamp(min=min_scaling_factor)
        torch_out = as_float32_tensor(x) / scales
        torch_out = torch_out.clamp(qtype_traits.min,
                                    qtype_traits.max).to(quant_dtype)

    return torch_out, scales
# The int8 version is very similar. Incorporate the int8 version, like in
# ref_dynamic_per_token_quant, when we have a dynamic_per_tensor int8 quant
# kernel
def ref_dynamic_per_tensor_fp8_quant(x: torch.Tensor) \
        -> Tuple[torch.Tensor, torch.Tensor]:
    """Reference dynamic fp8 (e4m3fn) quantization with one global scale.

    Args:
        x: Input tensor; its absolute maximum over all elements sets the
            scale.

    Returns:
        ``(ref_out, ref_scale)``: ``ref_out`` is ``x`` quantized to
        ``torch.float8_e4m3fn`` and ``ref_scale`` is the float32 scale
        ``max(|x|) / fp8_max``.
    """
    fp8_traits = torch.finfo(torch.float8_e4m3fn)
    fp8_max = as_float32_tensor(fp8_traits.max)
    one = as_float32_tensor(1.0)

    # For fp8, in order to match the cuda kernel output, we have to do exactly
    # the same operations as in the corresponding fp8 kernel to prevent
    # rounding errors.

    x_max = as_float32_tensor(x.abs().max())
    ref_scale = x_max / fp8_max
    # Multiply by the reciprocal rather than dividing; the operation order is
    # kept exactly as is for the reason given above.
    ref_iscale = one / ref_scale
    ref_out = (as_float32_tensor(x) * ref_iscale).clamp(
        fp8_traits.min, fp8_traits.max).to(dtype=torch.float8_e4m3fn)
    return ref_out, ref_scale
tests/kernels/test_attention.py
View file @
e7c1b7f3
...
...
@@ -29,7 +29,7 @@ NUM_HEADS = [(40, 40), (64, 8)] # Arbitrary values for testing
# FlashAttention forward only supports head dimension at most 128
# https://github.com/ROCmSoftwarePlatform/flash-attention/blob/3d2b6f5d037782cc2c906909a46fb7e2e1b48b25/csrc/flash_attn_rocm/flash_api.cpp#L62
HEAD_SIZES
=
[
64
,
80
,
96
,
112
,
128
,
192
,
256
HEAD_SIZES
=
[
64
,
80
,
96
,
112
,
120
,
128
,
192
,
256
]
if
not
is_hip
()
else
[
64
,
80
,
96
,
112
,
128
]
BLOCK_SIZES
=
[
16
,
32
]
...
...
@@ -73,27 +73,27 @@ def ref_single_query_cached_kv_attention(
block_size
=
value_cache
.
shape
[
3
]
num_seqs
=
query
.
shape
[
0
]
block_tables
=
block_tables
.
cpu
().
tolist
()
seq_lens
=
seq_lens
.
cpu
().
tolist
()
block_tables
_lst
=
block_tables
.
cpu
().
tolist
()
seq_lens
_lst
=
seq_lens
.
cpu
().
tolist
()
for
i
in
range
(
num_seqs
):
q
=
query
[
i
].
unsqueeze
(
0
)
block_table
=
block_tables
[
i
]
seq_len
=
int
(
seq_lens
[
i
])
block_table
=
block_tables
_lst
[
i
]
seq_len
=
int
(
seq_lens
_lst
[
i
])
keys
=
[]
values
=
[]
keys
_lst
:
List
[
torch
.
Tensor
]
=
[]
values
_lst
:
List
[
torch
.
Tensor
]
=
[]
for
j
in
range
(
seq_len
):
block_number
=
int
(
block_table
[
j
//
block_size
])
block_offset
=
j
%
block_size
k
=
key_cache
[
block_number
,
:,
:,
block_offset
,
:]
k
=
k
.
reshape
(
num_kv_heads
,
head_size
)
keys
.
append
(
k
)
keys
_lst
.
append
(
k
)
v
=
value_cache
[
block_number
,
:,
:,
block_offset
]
values
.
append
(
v
)
keys
=
torch
.
stack
(
keys
,
dim
=
0
)
values
=
torch
.
stack
(
values
,
dim
=
0
)
values
_lst
.
append
(
v
)
keys
=
torch
.
stack
(
keys
_lst
,
dim
=
0
)
values
=
torch
.
stack
(
values
_lst
,
dim
=
0
)
if
num_queries_per_kv
>
1
:
# Handle MQA and GQA
keys
=
torch
.
repeat_interleave
(
keys
,
num_queries_per_kv
,
dim
=
1
)
...
...
@@ -135,6 +135,8 @@ def test_paged_attention(
seed
:
int
,
device
:
str
,
)
->
None
:
if
kv_cache_dtype
==
"fp8"
and
head_size
%
16
:
pytest
.
skip
()
random
.
seed
(
seed
)
torch
.
random
.
manual_seed
(
seed
)
if
torch
.
cuda
.
is_available
():
...
...
@@ -158,14 +160,15 @@ def test_paged_attention(
# Create the block tables.
max_num_blocks_per_seq
=
(
max_seq_len
+
block_size
-
1
)
//
block_size
block_tables
=
[]
block_tables
_lst
:
List
[
List
[
int
]]
=
[]
for
_
in
range
(
num_seqs
):
block_table
=
[
random
.
randint
(
0
,
NUM_BLOCKS
-
1
)
for
_
in
range
(
max_num_blocks_per_seq
)
]
block_tables
.
append
(
block_table
)
block_tables
=
torch
.
tensor
(
block_tables
,
dtype
=
torch
.
int
)
block_tables_lst
.
append
(
block_table
)
block_tables
=
torch
.
tensor
(
block_tables_lst
,
dtype
=
torch
.
int
)
# Create the KV caches.
key_caches
,
value_caches
=
kv_cache_factory
(
NUM_BLOCKS
,
block_size
,
1
,
...
...
@@ -175,7 +178,7 @@ def test_paged_attention(
key_cache
,
value_cache
=
key_caches
[
0
],
value_caches
[
0
]
# Using default kv_scale
kv_scale
=
1.0
k
_scale
=
v_scale
=
1.0
# Call the paged attention kernel.
output
=
torch
.
empty_like
(
query
)
...
...
@@ -193,7 +196,8 @@ def test_paged_attention(
max_seq_len
,
alibi_slopes
,
kv_cache_dtype
,
kv_scale
,
k_scale
,
v_scale
,
)
elif
version
==
"v2"
:
num_partitions
=
((
max_seq_len
+
PARTITION_SIZE
-
1
)
//
PARTITION_SIZE
)
...
...
@@ -224,7 +228,8 @@ def test_paged_attention(
max_seq_len
,
alibi_slopes
,
kv_cache_dtype
,
kv_scale
,
k_scale
,
v_scale
,
)
else
:
raise
AssertionError
(
f
"Unknown version:
{
version
}
"
)
...
...
@@ -284,7 +289,7 @@ def ref_multi_query_kv_attention(
dtype
:
torch
.
dtype
,
)
->
torch
.
Tensor
:
num_seqs
=
len
(
cu_seq_lens
)
-
1
ref_outputs
=
[]
ref_outputs
:
List
[
torch
.
Tensor
]
=
[]
for
i
in
range
(
num_seqs
):
start_idx
=
cu_seq_lens
[
i
]
end_idx
=
cu_seq_lens
[
i
+
1
]
...
...
@@ -304,8 +309,8 @@ def ref_multi_query_kv_attention(
attn_mask
=
attn_mask
,
)
ref_outputs
.
append
(
ref_output
)
ref_output
=
torch
.
cat
(
ref_outputs
,
dim
=
0
)
return
ref_output
return
torch
.
cat
(
ref_output
s
,
dim
=
0
)
# TODO(woosuk): Add tests for USE_ALIBI=True.
...
...
tests/kernels/test_attention_selector.py
View file @
e7c1b7f3
...
...
@@ -9,8 +9,8 @@ from vllm.attention.selector import which_attn_to_use
@
pytest
.
mark
.
parametrize
(
"name"
,
[
"TORCH_SDPA"
,
"ROCM_FLASH"
,
"XFORMERS"
,
"FLASHINFER"
])
@
pytest
.
mark
.
parametrize
(
"device"
,
[
"cpu"
,
"hip"
,
"cuda"
])
"name"
,
[
"TORCH_SDPA"
,
"ROCM_FLASH"
,
"XFORMERS"
,
"FLASHINFER"
,
"OPENVINO"
])
@
pytest
.
mark
.
parametrize
(
"device"
,
[
"cpu"
,
"openvino"
,
"hip"
,
"cuda"
])
def
test_env
(
name
:
str
,
device
:
str
,
monkeypatch
):
"""Test that the attention selector can be set via environment variable.
Note that we do not test FlashAttn because it is the default backend.
...
...
@@ -28,6 +28,11 @@ def test_env(name: str, device: str, monkeypatch):
backend
=
which_attn_to_use
(
8
,
16
,
8
,
None
,
torch
.
float16
,
torch
.
float16
,
16
)
assert
backend
.
name
==
"ROCM_FLASH"
elif
device
==
"openvino"
:
with
patch
(
"vllm.attention.selector.is_openvino"
,
return_value
=
True
):
backend
=
which_attn_to_use
(
8
,
16
,
8
,
None
,
torch
.
float16
,
torch
.
float16
,
16
)
assert
backend
.
name
==
"OPENVINO"
else
:
backend
=
which_attn_to_use
(
8
,
16
,
8
,
None
,
torch
.
float16
,
torch
.
float16
,
16
)
...
...
@@ -42,36 +47,36 @@ def test_flash_attn(monkeypatch):
# Unsupported CUDA arch
with
patch
(
"torch.cuda.get_device_capability"
,
return_value
=
[
7
,
5
]):
backend
=
which_attn_to_use
(
8
,
16
,
8
,
None
,
torch
.
float16
,
None
,
16
)
assert
backend
.
name
!=
"
FLASH_ATTN
"
assert
backend
.
name
!=
STR_
FLASH_ATTN
_VAL
# Unsupported data type
backend
=
which_attn_to_use
(
8
,
16
,
8
,
None
,
torch
.
float8_e4m3fn
,
None
,
16
)
assert
backend
.
name
!=
"
FLASH_ATTN
"
assert
backend
.
name
!=
STR_
FLASH_ATTN
_VAL
# Unsupported kv cache data type
backend
=
which_attn_to_use
(
8
,
16
,
8
,
None
,
torch
.
float16
,
"fp8"
,
16
)
assert
backend
.
name
!=
"
FLASH_ATTN
"
assert
backend
.
name
!=
STR_
FLASH_ATTN
_VAL
# Unsupported block size
backend
=
which_attn_to_use
(
8
,
16
,
8
,
None
,
torch
.
float16
,
None
,
8
)
assert
backend
.
name
!=
"
FLASH_ATTN
"
assert
backend
.
name
!=
STR_
FLASH_ATTN
_VAL
# Unsupported sliding window
backend
=
which_attn_to_use
(
8
,
16
,
8
,
1
,
torch
.
float16
,
None
,
16
)
assert
backend
.
name
!=
"
FLASH_ATTN
"
assert
backend
.
name
!=
STR_
FLASH_ATTN
_VAL
# flash-attn is not installed
with
patch
.
dict
(
'sys.modules'
,
{
'vllm_flash_attn'
:
None
}):
backend
=
which_attn_to_use
(
8
,
16
,
8
,
None
,
torch
.
float16
,
None
,
16
)
assert
backend
.
name
!=
"
FLASH_ATTN
"
assert
backend
.
name
!=
STR_
FLASH_ATTN
_VAL
# Unsupported head size
backend
=
which_attn_to_use
(
8
,
17
,
8
,
None
,
torch
.
float16
,
None
,
16
)
assert
backend
.
name
!=
"
FLASH_ATTN
"
assert
backend
.
name
!=
STR_
FLASH_ATTN
_VAL
def
test_invalid_env
(
monkeypatch
):
"""Throw an exception if the backend name is invalid."""
override_backend_env_variable
(
monkeypatch
,
STR_INVALID_VAL
)
with
pytest
.
raises
(
ValueError
):
which_attn_to_use
(
8
,
16
,
8
,
None
,
torch
.
float16
,
None
,
16
)
which_attn_to_use
(
8
,
16
,
8
,
None
,
torch
.
float16
,
None
,
16
)
\ No newline at end of file
tests/kernels/test_blocksparse_attention.py
View file @
e7c1b7f3
...
...
@@ -77,27 +77,27 @@ def ref_single_query_cached_kv_attention(
block_size
=
value_cache
.
shape
[
3
]
num_seqs
=
query
.
shape
[
0
]
block_tables
=
block_tables
.
cpu
().
tolist
()
seq_lens
=
seq_lens
.
cpu
().
tolist
()
block_tables
_lst
=
block_tables
.
cpu
().
tolist
()
seq_lens
_lst
=
seq_lens
.
cpu
().
tolist
()
for
i
in
range
(
num_seqs
):
q
=
query
[
i
].
unsqueeze
(
0
)
block_table
=
block_tables
[
i
]
seq_len
=
int
(
seq_lens
[
i
])
block_table
=
block_tables
_lst
[
i
]
seq_len
=
int
(
seq_lens
_lst
[
i
])
keys
=
[]
values
=
[]
keys
_lst
:
List
[
torch
.
Tensor
]
=
[]
values
_lst
:
List
[
torch
.
Tensor
]
=
[]
for
j
in
range
(
seq_len
):
block_number
=
int
(
block_table
[
j
//
block_size
])
block_offset
=
j
%
block_size
k
=
key_cache
[
block_number
,
:,
:,
block_offset
,
:]
k
=
k
.
reshape
(
num_kv_heads
,
head_size
)
keys
.
append
(
k
)
keys
_lst
.
append
(
k
)
v
=
value_cache
[
block_number
,
:,
:,
block_offset
]
values
.
append
(
v
)
keys
=
torch
.
stack
(
keys
,
dim
=
0
)
values
=
torch
.
stack
(
values
,
dim
=
0
)
values
_lst
.
append
(
v
)
keys
=
torch
.
stack
(
keys
_lst
,
dim
=
0
)
values
=
torch
.
stack
(
values
_lst
,
dim
=
0
)
if
num_queries_per_kv
>
1
:
# Handle MQA and GQA
keys
=
torch
.
repeat_interleave
(
keys
,
num_queries_per_kv
,
dim
=
1
)
...
...
@@ -212,7 +212,7 @@ def test_paged_attention(
key_cache
,
value_cache
=
key_caches
[
0
],
value_caches
[
0
]
# Using default kv_scale
kv_scale
=
1.0
k
_scale
=
v_scale
=
1.0
tp_rank
=
0
# Call the paged attention kernel.
...
...
@@ -231,7 +231,8 @@ def test_paged_attention(
max_seq_len
,
alibi_slopes
,
kv_cache_dtype
,
kv_scale
,
k_scale
,
v_scale
,
tp_rank
=
tp_rank
,
blocksparse_local_blocks
=
blocksparse_local_blocks
,
blocksparse_vert_stride
=
blocksparse_vert_stride
,
...
...
@@ -267,7 +268,8 @@ def test_paged_attention(
max_seq_len
,
alibi_slopes
,
kv_cache_dtype
,
kv_scale
,
k_scale
,
v_scale
,
tp_rank
=
tp_rank
,
blocksparse_local_blocks
=
blocksparse_local_blocks
,
blocksparse_vert_stride
=
blocksparse_vert_stride
,
...
...
@@ -432,7 +434,7 @@ def test_varlen_blocksparse_attention_prefill(
value
=
torch
.
repeat_interleave
(
value
,
num_queries_per_kv
,
dim
=
1
)
ref_output
=
ref_multi_query_kv_attention
(
cu_seq_lens
,
cu_seq_lens
.
tolist
()
,
query
,
key
,
value
,
...
...
tests/kernels/test_cache.py
View file @
e7c1b7f3
import
random
from
typing
import
Tuple
from
typing
import
List
,
Tuple
import
pytest
import
torch
...
...
@@ -11,7 +11,7 @@ DTYPES = [torch.half, torch.bfloat16, torch.float]
NUM_TOKENS
=
[
42
]
# Arbitrary values for testing
NUM_LAYERS
=
[
1
]
# Arbitrary values for testing
NUM_HEADS
=
[
8
]
# Arbitrary values for testing
HEAD_SIZES
=
[
64
,
80
,
96
,
112
,
128
,
192
,
256
]
HEAD_SIZES
=
[
64
,
80
,
96
,
112
,
120
,
128
,
192
,
256
]
BLOCK_SIZES
=
[
8
,
16
,
32
]
# Arbitrary values for testing
...
...
@@ -53,6 +53,8 @@ def test_copy_blocks(
kv_cache_dtype
:
str
,
device
:
str
,
)
->
None
:
if
kv_cache_dtype
==
"fp8"
and
head_size
%
16
:
pytest
.
skip
()
random
.
seed
(
seed
)
torch
.
random
.
manual_seed
(
seed
)
if
torch
.
cuda
.
is_available
():
...
...
@@ -64,7 +66,7 @@ def test_copy_blocks(
src_blocks
=
random
.
sample
(
range
(
num_blocks
),
num_mappings
)
remainig_blocks
=
list
(
set
(
range
(
num_blocks
))
-
set
(
src_blocks
))
dst_blocks
=
random
.
sample
(
remainig_blocks
,
2
*
num_mappings
)
block_mapping
=
[]
block_mapping
:
List
[
Tuple
[
int
,
int
]]
=
[]
for
i
in
range
(
num_mappings
):
src
=
src_blocks
[
i
]
dst1
=
dst_blocks
[
2
*
i
]
...
...
@@ -125,6 +127,8 @@ def test_reshape_and_cache(
device
:
str
,
kv_cache_dtype
:
str
,
)
->
None
:
if
kv_cache_dtype
==
"fp8"
and
head_size
%
16
:
pytest
.
skip
()
random
.
seed
(
seed
)
torch
.
random
.
manual_seed
(
seed
)
if
torch
.
cuda
.
is_available
():
...
...
@@ -132,8 +136,8 @@ def test_reshape_and_cache(
torch
.
set_default_device
(
device
)
# Create a random slot mapping.
num_slots
=
block_size
*
num_blocks
slot_mapping
=
random
.
sample
(
range
(
num_slots
),
num_tokens
)
slot_mapping
=
torch
.
tensor
(
slot_mapping
,
dtype
=
torch
.
long
)
slot_mapping
_lst
=
random
.
sample
(
range
(
num_slots
),
num_tokens
)
slot_mapping
=
torch
.
tensor
(
slot_mapping
_lst
,
dtype
=
torch
.
long
)
qkv
=
torch
.
randn
(
num_tokens
,
3
,
num_heads
,
head_size
,
dtype
=
dtype
)
_
,
key
,
value
=
qkv
.
unbind
(
dim
=
1
)
...
...
@@ -156,11 +160,11 @@ def test_reshape_and_cache(
cloned_value_cache
=
value_cache
.
clone
()
# Using default kv_scale
kv_scale
=
1.0
k
_scale
=
v_scale
=
1.0
# Call the reshape_and_cache kernel.
ops
.
reshape_and_cache
(
key
,
value
,
key_cache
,
value_cache
,
slot_mapping
,
kv_cache_dtype
,
kv_scale
)
kv_cache_dtype
,
k
_scale
,
v_scale
)
if
kv_cache_dtype
==
"fp8"
:
result_key_cache
=
torch
.
empty_like
(
key_cache
,
dtype
=
torch
.
float16
)
...
...
@@ -171,12 +175,12 @@ def test_reshape_and_cache(
# Run the reference implementation.
reshaped_key
=
key
.
reshape
(
num_tokens
,
*
key_cache
[
0
,
:,
:,
0
,
:].
shape
)
block_indicies
=
torch
.
div
(
slot_mapping
,
block_size
,
rounding_mode
=
"floor"
)
block_indicies
=
block_indicies
.
cpu
().
tolist
()
block_indicies
_lst
=
block_indicies
.
cpu
().
tolist
()
block_offsets
=
slot_mapping
%
block_size
block_offsets
=
block_offsets
.
cpu
().
tolist
()
block_offsets
_lst
=
block_offsets
.
cpu
().
tolist
()
for
i
in
range
(
num_tokens
):
block_idx
=
block_indicies
[
i
]
block_offset
=
block_offsets
[
i
]
block_idx
=
block_indicies
_lst
[
i
]
block_offset
=
block_offsets
_lst
[
i
]
cloned_key_cache
[
block_idx
,
:,
:,
block_offset
,
:]
=
reshaped_key
[
i
]
cloned_value_cache
[
block_idx
,
:,
:,
block_offset
]
=
value
[
i
]
...
...
@@ -216,8 +220,6 @@ def test_reshape_and_cache_flash(
device
:
str
,
kv_cache_dtype
:
str
,
)
->
None
:
if
kv_cache_dtype
==
"fp8"
:
pytest
.
skip
()
random
.
seed
(
seed
)
torch
.
random
.
manual_seed
(
seed
)
torch
.
cuda
.
manual_seed
(
seed
)
...
...
@@ -225,8 +227,10 @@ def test_reshape_and_cache_flash(
# Create a random slot mapping.
num_slots
=
block_size
*
num_blocks
slot_mapping
=
random
.
sample
(
range
(
num_slots
),
num_tokens
)
slot_mapping
=
torch
.
tensor
(
slot_mapping
,
dtype
=
torch
.
long
,
device
=
device
)
slot_mapping_lst
=
random
.
sample
(
range
(
num_slots
),
num_tokens
)
slot_mapping
=
torch
.
tensor
(
slot_mapping_lst
,
dtype
=
torch
.
long
,
device
=
device
)
qkv
=
torch
.
randn
(
num_tokens
,
3
,
...
...
@@ -247,29 +251,57 @@ def test_reshape_and_cache_flash(
dtype
,
device
=
device
,
)
key_cache
,
value_cache
=
key_caches
[
0
],
value_caches
[
0
]
key_cache
,
value_cache
=
key_caches
[
0
].
contiguous
(
),
value_caches
[
0
].
contiguous
()
del
key_caches
del
value_caches
# Clone the KV caches.
cloned_key_cache
=
key_cache
.
clone
()
cloned_value_cache
=
value_cache
.
clone
()
if
kv_cache_dtype
==
"fp8"
:
cloned_key_cache
=
torch
.
empty_like
(
key_cache
,
dtype
=
torch
.
float16
)
ops
.
convert_fp8
(
cloned_key_cache
,
key_cache
)
cloned_value_cache
=
torch
.
empty_like
(
value_cache
,
dtype
=
torch
.
float16
)
ops
.
convert_fp8
(
cloned_value_cache
,
value_cache
)
else
:
cloned_key_cache
=
key_cache
.
clone
()
cloned_value_cache
=
value_cache
.
clone
()
# Using default kv_scale
k_scale
=
v_scale
=
1.0
# Call the reshape_and_cache kernel.
ops
.
reshape_and_cache_flash
(
key
,
value
,
key_cache
,
value_cache
,
slot_mapping
,
kv_cache_dtype
)
slot_mapping
,
kv_cache_dtype
,
k_scale
,
v_scale
)
if
kv_cache_dtype
==
"fp8"
:
result_key_cache
=
torch
.
empty_like
(
key_cache
,
dtype
=
torch
.
float16
)
ops
.
convert_fp8
(
result_key_cache
,
key_cache
)
result_value_cache
=
torch
.
empty_like
(
value_cache
,
dtype
=
torch
.
float16
)
ops
.
convert_fp8
(
result_value_cache
,
value_cache
)
# Run the reference implementation.
block_indicies
=
torch
.
div
(
slot_mapping
,
block_size
,
rounding_mode
=
'
floor
'
)
block_indicies
=
block_indicies
.
cpu
().
tolist
()
block_indicies
=
torch
.
div
(
slot_mapping
,
block_size
,
rounding_mode
=
"
floor
"
)
block_indicies
_lst
=
block_indicies
.
cpu
().
tolist
()
block_offsets
=
slot_mapping
%
block_size
block_offsets
=
block_offsets
.
cpu
().
tolist
()
block_offsets
_lst
=
block_offsets
.
cpu
().
tolist
()
for
i
in
range
(
num_tokens
):
block_idx
=
block_indicies
[
i
]
block_offset
=
block_offsets
[
i
]
block_idx
=
block_indicies
_lst
[
i
]
block_offset
=
block_offsets
_lst
[
i
]
cloned_key_cache
[
block_idx
,
block_offset
,
:,
:]
=
key
[
i
]
cloned_value_cache
[
block_idx
,
block_offset
,
:,
:]
=
value
[
i
]
assert
torch
.
allclose
(
key_cache
,
cloned_key_cache
)
assert
torch
.
allclose
(
value_cache
,
cloned_value_cache
)
if
kv_cache_dtype
==
"fp8"
:
assert
torch
.
allclose
(
result_key_cache
,
cloned_key_cache
,
atol
=
0.001
,
rtol
=
0.1
)
assert
torch
.
allclose
(
result_value_cache
,
cloned_value_cache
,
atol
=
0.001
,
rtol
=
0.1
)
else
:
assert
torch
.
allclose
(
key_cache
,
cloned_key_cache
)
assert
torch
.
allclose
(
value_cache
,
cloned_value_cache
)
@
pytest
.
mark
.
parametrize
(
"direction"
,
COPYING_DIRECTION
)
...
...
@@ -298,6 +330,8 @@ def test_swap_blocks(
)
->
None
:
if
kv_cache_dtype
==
"fp8"
and
"cpu"
in
direction
:
pytest
.
skip
()
if
kv_cache_dtype
==
"fp8"
and
head_size
%
16
:
pytest
.
skip
()
random
.
seed
(
seed
)
torch
.
random
.
manual_seed
(
seed
)
if
torch
.
cuda
.
is_available
():
...
...
tests/kernels/test_cutlass.py
View file @
e7c1b7f3
...
...
@@ -2,36 +2,53 @@
Run `pytest tests/kernels/test_cutlass.py`.
"""
from
typing
import
Type
from
typing
import
Optional
,
Type
import
pytest
import
torch
from
vllm
import
_custom_ops
as
ops
from
vllm.platforms
import
current_platform
CUDA_DEVICES
=
[
f
"cuda:
{
i
}
"
for
i
in
range
(
1
if
torch
.
cuda
.
device_count
()
==
1
else
2
)
]
capability
=
torch
.
cuda
.
get_device_capability
()
capability
=
current_platform
.
get_device_capability
()
capability
=
capability
[
0
]
*
10
+
capability
[
1
]
def
to_fp8
(
tensor
:
torch
.
t
ensor
):
def
to_fp8
(
tensor
:
torch
.
T
ensor
):
finfo
=
torch
.
finfo
(
torch
.
float8_e4m3fn
)
return
torch
.
round
(
tensor
.
clamp
(
min
=
finfo
.
min
,
max
=
finfo
.
max
)).
to
(
dtype
=
torch
.
float8_e4m3fn
)
def
to_int8
(
tensor
:
torch
.
t
ensor
):
def
to_int8
(
tensor
:
torch
.
T
ensor
):
return
torch
.
round
(
tensor
.
clamp
(
min
=-
128
,
max
=
127
)).
to
(
dtype
=
torch
.
int8
)
def baseline_scaled_mm(a: torch.Tensor,
                       b: torch.Tensor,
                       scale_a: torch.Tensor,
                       scale_b: torch.Tensor,
                       out_dtype: torch.dtype,
                       bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Reference scaled matrix multiply used as the expected kernel output.

    Computes ``scale_a * (scale_b * (a @ b))`` in float32, casts the result
    to ``out_dtype`` and then, if given, adds ``bias``.

    Args:
        a: Left operand; converted to float32 before the multiply.
        b: Right operand; converted to float32 before the multiply.
        scale_a: Scale multiplied into the product last.
        scale_b: Scale multiplied into the product first.
        out_dtype: dtype the scaled product is cast to (a ``torch.dtype``
            instance such as ``torch.bfloat16``).
        bias: Optional tensor added after the cast to ``out_dtype``.

    Returns:
        The scaled product, plus ``bias`` when provided.
    """
    product = torch.mm(a.to(dtype=torch.float32), b.to(dtype=torch.float32))
    output = (scale_a * (scale_b * product)).to(out_dtype)

    # The bias is added after the cast, so the sum follows normal type
    # promotion between ``out_dtype`` and the bias dtype.
    if bias is not None:
        output = output + bias

    return output
def
cutlass_fp8_gemm_helper
(
m
:
int
,
n
:
int
,
k
:
int
,
per_token_act_quant
:
bool
,
per_out_channel_weight_quant
:
bool
,
use_bias
:
bool
,
out_dtype
:
Type
[
torch
.
dtype
]
=
torch
.
bfloat16
,
device
:
str
=
"cuda"
):
# Test for a cutlass kernel with per-token activation quantization
...
...
@@ -42,16 +59,19 @@ def cutlass_fp8_gemm_helper(m: int,
m_a_scales
=
m
if
per_token_act_quant
else
1
n_b_scales
=
n
if
per_out_channel_weight_quant
else
1
scale_a
=
(
torch
.
randn
(
(
m_a_scales
,
1
),
device
=
device
,
dtype
=
torch
.
float32
)
/
10
)
scale_b
=
(
torch
.
randn
(
(
1
,
n_b_scales
),
device
=
device
,
dtype
=
torch
.
float32
)
/
10
)
scale_a
=
(
torch
.
randn
((
m_a_scales
,
1
),
device
=
device
,
dtype
=
torch
.
float32
))
scale_b
=
(
torch
.
randn
((
1
,
n_b_scales
),
device
=
device
,
dtype
=
torch
.
float32
))
if
use_bias
:
bias
=
torch
.
rand
((
n
,
),
device
=
device
,
dtype
=
out_dtype
)
*
10
else
:
bias
=
None
out
=
ops
.
cutlass_scaled_mm
(
a
,
b
,
scale_a
,
scale_b
,
out_dtype
)
baseline
=
torch
.
mm
(
scale_a
*
a
.
to
(
dtype
=
torch
.
float32
),
scale_b
*
b
.
to
(
dtype
=
torch
.
float32
)).
to
(
out_dtype
)
out
=
ops
.
cutlass_scaled_mm
(
a
,
b
,
scale_a
,
scale_b
,
out_dtype
,
bias
)
baseline
=
baseline_scaled_mm
(
a
,
b
,
scale_a
,
scale_b
,
out_dtype
,
bias
)
assert
torch
.
allclose
(
out
,
baseline
,
rtol
=
1e-2
,
atol
=
1
e-
1
)
assert
torch
.
allclose
(
out
,
baseline
,
rtol
=
1e-2
,
atol
=
5
e-
2
)
def
cutlass_int8_gemm_helper
(
m
:
int
,
...
...
@@ -59,6 +79,7 @@ def cutlass_int8_gemm_helper(m: int,
k
:
int
,
per_token_act_quant
:
bool
,
per_out_channel_weight_quant
:
bool
,
use_bias
:
bool
,
out_dtype
:
Type
[
torch
.
dtype
]
=
torch
.
bfloat16
,
device
:
str
=
"cuda"
):
# Test for a cutlass kernel with per-token activation quantization
...
...
@@ -69,79 +90,106 @@ def cutlass_int8_gemm_helper(m: int,
m_a_scales
=
m
if
per_token_act_quant
else
1
n_b_scales
=
n
if
per_out_channel_weight_quant
else
1
scale_a
=
(
torch
.
randn
(
(
m_a_scales
,
1
),
device
=
device
,
dtype
=
torch
.
float32
)
/
10
)
scale_b
=
(
torch
.
randn
(
(
1
,
n_b_scales
),
device
=
device
,
dtype
=
torch
.
float32
)
/
10
)
scale_a
=
(
torch
.
randn
(
(
m_a_scales
,
1
),
device
=
device
,
dtype
=
torch
.
float32
))
scale_b
=
(
torch
.
randn
(
(
1
,
n_b_scales
),
device
=
device
,
dtype
=
torch
.
float32
))
out
=
ops
.
cutlass_scaled_mm
(
a
,
b
,
scale_a
,
scale_b
,
out_dtype
)
baseline
=
torch
.
mm
(
scale_a
*
a
.
to
(
dtype
=
torch
.
float32
),
scale_b
*
b
.
to
(
dtype
=
torch
.
float32
)).
to
(
dtype
=
out_dtype
)
if
use_bias
:
bias
=
torch
.
rand
((
n
,
),
device
=
device
,
dtype
=
out_dtype
)
*
10
else
:
bias
=
None
out
=
ops
.
cutlass_scaled_mm
(
a
,
b
,
scale_a
,
scale_b
,
out_dtype
,
bias
)
baseline
=
baseline_scaled_mm
(
a
,
b
,
scale_a
,
scale_b
,
out_dtype
,
bias
)
assert
torch
.
allclose
(
out
,
baseline
,
rtol
=
1e-1
,
atol
=
1e0
)
@pytest.mark.parametrize("m", [1, 16, 32, 64, 128, 256, 512, 222, 100, 33])
@pytest.mark.parametrize("n", [2048, 4096, 8192, 16384, 24576, 256, 1024])
@pytest.mark.parametrize("k", [128, 496, 1024])
@pytest.mark.parametrize("per_act_token", [True, False])
@pytest.mark.parametrize("per_out_ch", [True, False])
@pytest.mark.parametrize("use_bias", [True, False])
@pytest.mark.skipif(capability < 89,
                    reason="FP8 is not supported on this GPU type.")
def test_cutlass_fp8_gemm(m: int, n: int, k: int, per_act_token: bool,
                          per_out_ch: bool, use_bias: bool):
    """Sweep FP8 GEMM problem sizes, scale granularities and bias on/off.

    Skipped when the detected device capability is below 89.
    """
    cutlass_fp8_gemm_helper(m,
                            n,
                            k,
                            per_token_act_quant=per_act_token,
                            per_out_channel_weight_quant=per_out_ch,
                            use_bias=use_bias)
# NOTE: every value of "m" appears exactly once; a repeated value would only
# generate (and run) duplicate test cases.
@pytest.mark.parametrize("m", [1, 16, 32, 64, 128, 256, 512, 222, 33])
@pytest.mark.parametrize("n", [2048, 8192, 16384, 256, 1024])
@pytest.mark.parametrize("k", [128, 496, 1024])
@pytest.mark.parametrize("per_act_token", [True, False])
@pytest.mark.parametrize("per_out_ch", [True, False])
@pytest.mark.parametrize("use_bias", [True, False])
def test_cutlass_int8_gemm(m: int, n: int, k: int, per_act_token: bool,
                           per_out_ch: bool, use_bias: bool):
    """Sweep INT8 GEMM problem sizes, scale granularities and bias on/off.

    Arguments:

    * m, n, k: GEMM problem dimensions forwarded to the helper
    * per_act_token: forwarded as the helper's per-token activation
                     quantization flag
    * per_out_ch: forwarded as the helper's per-output-channel weight
                  quantization flag
    * use_bias: whether the helper should also exercise the bias path
    """
    cutlass_int8_gemm_helper(m, n, k, per_act_token, per_out_ch, use_bias)
@pytest.mark.parametrize("per_act_token", [True, False])
@pytest.mark.parametrize("per_out_ch", [True, False])
@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize("use_bias", [True, False])
def test_cutlass_int8_gemm_output_dtype(per_act_token: bool, per_out_ch: bool,
                                        out_dtype: Type[torch.dtype],
                                        use_bias: bool):
    """Check the INT8 GEMM for each output dtype on a 512x512x512 problem."""
    cutlass_int8_gemm_helper(512,
                             512,
                             512,
                             per_token_act_quant=per_act_token,
                             per_out_channel_weight_quant=per_out_ch,
                             use_bias=use_bias,
                             out_dtype=out_dtype)
@pytest.mark.parametrize("per_act_token", [True, False])
@pytest.mark.parametrize("per_out_ch", [True, False])
@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize("use_bias", [True, False])
@pytest.mark.skipif(capability < 89,
                    reason="FP8 is not supported on this GPU type.")
def test_cutlass_fp8_gemm_output_dtype(per_act_token: bool, per_out_ch: bool,
                                       out_dtype: Type[torch.dtype],
                                       use_bias: bool):
    """Check the FP8 GEMM for each output dtype on a 512x512x512 problem.

    Skipped when the detected device capability is below 89.
    """
    cutlass_fp8_gemm_helper(512,
                            512,
                            512,
                            per_token_act_quant=per_act_token,
                            per_out_channel_weight_quant=per_out_ch,
                            use_bias=use_bias,
                            out_dtype=out_dtype)
@pytest.mark.parametrize("per_act_token", [True, False])
@pytest.mark.parametrize("per_out_ch", [True, False])
@pytest.mark.parametrize("use_bias", [True, False])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.skipif(capability < 89,
                    reason="FP8 is not supported on this GPU type.")
def test_cutlass_fp8_gemm_devices(per_act_token: bool, per_out_ch: bool,
                                  use_bias: bool, device: str):
    """Run the FP8 GEMM check on each device listed in CUDA_DEVICES.

    Skipped when the detected device capability is below 89.
    """
    cutlass_fp8_gemm_helper(512,
                            512,
                            512,
                            per_token_act_quant=per_act_token,
                            per_out_channel_weight_quant=per_out_ch,
                            use_bias=use_bias,
                            out_dtype=torch.bfloat16,
                            device=device)
@pytest.mark.parametrize("per_act_token", [True, False])
@pytest.mark.parametrize("per_out_ch", [True, False])
@pytest.mark.parametrize("use_bias", [True, False])
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_cutlass_int8_gemm_devices(per_act_token: bool, per_out_ch: bool,
                                   use_bias: bool, device: str):
    """Run the INT8 GEMM check on each device listed in CUDA_DEVICES."""
    helper_kwargs = dict(per_token_act_quant=per_act_token,
                         per_out_channel_weight_quant=per_out_ch,
                         use_bias=use_bias,
                         out_dtype=torch.bfloat16,
                         device=device)
    cutlass_int8_gemm_helper(512, 512, 512, **helper_kwargs)
# For the following two tests:
...
...
@@ -151,20 +199,26 @@ def test_cutlass_int8_gemm_devices(per_act_token: bool, per_out_ch: bool,
# kernel must handle any M thrown at it.
@pytest.mark.parametrize("per_act_token", [True, False])
@pytest.mark.parametrize("per_out_ch", [True, False])
@pytest.mark.parametrize("use_bias", [True, False])
@pytest.mark.skipif(capability < 89,
                    reason="FP8 is not supported on this GPU type.")
def test_cutlass_fp8_gemm_m_sweep(per_act_token: bool, per_out_ch: bool,
                                  use_bias: bool):
    """Run the FP8 GEMM check for every M in [1, 127] with N = K in
    {32, 64, 96}.

    Skipped when the detected device capability is below 89.
    """
    m_values = range(1, 128)
    for nk in (32, 64, 96):
        for m in m_values:
            cutlass_fp8_gemm_helper(m,
                                    nk,
                                    nk,
                                    per_token_act_quant=per_act_token,
                                    per_out_channel_weight_quant=per_out_ch,
                                    use_bias=use_bias)
@pytest.mark.parametrize("per_act_token", [True, False])
@pytest.mark.parametrize("per_out_ch", [True, False])
@pytest.mark.parametrize("use_bias", [True, False])
def test_cutlass_int8_gemm_m_sweep(per_act_token: bool, per_out_ch: bool,
                                   use_bias: bool):
    """Run the INT8 GEMM check for every M in [1, 127] with N = K in
    {32, 64, 96}."""
    m_values = range(1, 128)
    for nk in (32, 64, 96):
        for m in m_values:
            cutlass_int8_gemm_helper(m,
                                     nk,
                                     nk,
                                     per_token_act_quant=per_act_token,
                                     per_out_channel_weight_quant=per_out_ch,
                                     use_bias=use_bias)
# Test working with a subset of A and B
...
...
@@ -185,9 +239,11 @@ def test_cutlass_subset():
scale_a
,
scale_b
,
out_dtype
=
torch
.
bfloat16
)
baseline
=
torch
.
mm
(
scale_a
*
a
.
to
(
dtype
=
torch
.
float32
),
scale_b
*
b
.
to
(
dtype
=
torch
.
float32
)).
to
(
dtype
=
torch
.
bfloat16
)
baseline
=
baseline_scaled_mm
(
a
,
b
,
scale_a
,
scale_b
,
out_dtype
=
torch
.
bfloat16
)
assert
torch
.
allclose
(
out
,
baseline
,
rtol
=
1e-1
,
atol
=
1e0
)
...
...
tests/kernels/test_encoder_decoder_attn.py
0 → 100644
View file @
e7c1b7f3
"""
Tests:
* E2E test of Encoder attention + Decoder self-attention +
Encoder/decoder cross-attention (collectively
"encoder/decoder attention")
* Confirm enc/dec models will fail for chunked prefill
* Confirm enc/dec models will fail for prefix caching
"""
from
typing
import
NamedTuple
,
Optional
import
pytest
import
torch
from
tests.kernels.utils
import
*
from
tests.kernels.utils
import
make_causal_mask
,
maybe_make_long_tensor
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.attention.backends.abstract
import
AttentionBackend
,
AttentionType
from
vllm.attention.backends.utils
import
STR_NOT_IMPL_ENC_DEC_ROCM_HIP
from
vllm.utils
import
is_hip
# Value lists below feed the pytest parametrizations of the tests in this
# file (one parametrized argument per list).
HEAD_SIZES = [64, 256]

NUM_HEADS = [1, 16]

BATCH_SIZES = [1, 16]
BLOCK_SIZES = [16]
# Attention backends under test
BACKEND_NAMES = [STR_XFORMERS_ATTN_VAL]
# Device passed to every tensor-construction helper in this file
CUDA_DEVICE = "cuda:0"

MAX_DEC_SEQ_LENS = [128]
MAX_ENC_SEQ_LENS = [128]

# Narrow test-cases for unsupported-scenario
# tests
HEAD_SIZES_FOR_UNSUPP = [HEAD_SIZES[0]]
class TestPoint(NamedTuple):
    """
    Encapsulates the attributes which define a single invocation
    of the test_e2e_enc_dec_attn() test

    The field order is part of the interface: callers unpack a TestPoint
    positionally.

    Attributes:
        num_heads: The number of heads in the model.
        head_size: Head dimension
        backend_name: Name of the backend framework used.
        batch_size: Number of samples per batch.
        block_size: Size of each block of data processed.
        max_dec_seq_len: Maximum sequence length for the decoder.
        max_enc_seq_len: Maximum sequence length for the encoder.
        num_blocks: Number of blocks in the model; may be None when the
                    caller does not require a KV cache.
    """

    # The "Test" name prefix would otherwise make pytest try to collect
    # this data container as a test class (and emit a collection warning).
    __test__ = False

    num_heads: int
    head_size: int
    backend_name: str
    batch_size: int
    block_size: int
    max_dec_seq_len: int
    max_enc_seq_len: int
    num_blocks: Optional[int]
class TestResources(NamedTuple):
    '''
    Encapsulates key components for performing an
    encoder/decoder attention test

    Note that
    (1) attn automatically selects an attention backend
        based on platform info & a set of canned
        heuristics
    (2) attn_backend is thus *not the same backend
        instance* used by attn, but rather it is
        intended to be a
        *different instance* of the *same backend class*;
        it is assumed that the user of TestResources
        will leverage attn_backend for the purpose of
        constructing backend-compatible attention
        metadata instances

    Attributes:

    * scale: 1/sqrt(d) scale factor for attn
    * attn_backend: implementation of abstraction
                    attention interface using
                    a particular kernel library
                    i.e. XFormers
    * attn: Attention layer instance
    * kv_cache: shared key/value cache for all attention;
                None when no KV cache was requested
    '''

    # The "Test" name prefix would otherwise make pytest try to collect
    # this data container as a test class (and emit a collection warning).
    __test__ = False

    scale: float
    attn_backend: AttentionBackend
    attn: Attention
    kv_cache: Optional[torch.Tensor]
def _make_test_resources(test_pt: TestPoint, ) -> TestResources:
    '''
    Assemble the shared components of an encoder/decoder attention test.

    The Attention wrapper built here picks its own backend class from
    platform info & canned heuristics. The backend object returned next to
    it is therefore a *separate instance*, intended to be of the *same
    backend class*, which means test_pt.backend_name must name the backend
    class that Attention selects on its own.

    Arguments:

    * test_pt: TestPoint data structure; the fields read are num_heads,
               head_size, num_blocks, block_size and backend_name

    Returns:

    * TestResources data structure; its kv_cache entry is None when
      test_pt.num_blocks or test_pt.num_heads is None
    '''

    scale = float(1.0 / (test_pt.head_size**0.5))
    attn_backend = make_backend(test_pt.backend_name)
    attn = Attention(test_pt.num_heads, test_pt.head_size, scale=scale)

    wants_kv_cache = (test_pt.num_blocks is not None
                      and test_pt.num_heads is not None)

    kv_cache = None
    if wants_kv_cache:
        kv_cache = make_kv_cache(test_pt.num_blocks,
                                 test_pt.num_heads,
                                 test_pt.head_size,
                                 test_pt.block_size,
                                 device=CUDA_DEVICE)

    return TestResources(scale, attn_backend, attn, kv_cache)
def _encoder_attn_setup(
    test_pt: TestPoint,
    test_rsrcs: TestResources,
) -> PhaseTestParameters:
    '''
    Build the inputs and expected output of the encoder attention test.

    Synthetic query/key/value tensors are generated with key/value
    sequences as long as the queries (encoder attention). They are run
    through the naive reference attention - without a causal mask - to
    obtain the ideal output, and both inputs and ideal output are packed.

    Encoder inference does not populate the KV cache, so no KV cache
    memory mapping is produced.

    Arguments:

    * test_pt: TestPoint data structure; the fields read are batch_size,
               num_heads, head_size and max_enc_seq_len
    * test_rsrcs: TestResources data structure; only scale is read

    Returns:

    * PhaseTestParameters holding (1) the packed query/key/value tensors
      together with (2) the packed ideal attention output, and (3) None in
      place of a KV cache memory map
    '''

    # Encoder attention: queries, keys and values share the encoder length
    max_q_seq_len = test_pt.max_enc_seq_len
    max_kv_seq_len = max_q_seq_len

    # Synthetic inputs; the prefill/decode splits are not needed here
    qkv_in, _, _ = make_qkv(test_pt.batch_size,
                            max_q_seq_len,
                            max_kv_seq_len,
                            test_pt.num_heads,
                            test_pt.head_size,
                            attn_type=AttentionType.ENCODER,
                            device=CUDA_DEVICE)

    # Expected answer from the naive non-causal reference implementation
    ideal_output = ref_masked_attention(qkv_in.query,
                                        qkv_in.key,
                                        qkv_in.value,
                                        scale=test_rsrcs.scale,
                                        q_seq_lens=qkv_in.q_seq_lens,
                                        kv_seq_lens=qkv_in.kv_seq_lens)

    packed_ideal_output, _ = pack_tensor(ideal_output,
                                         qkv_in.q_seq_lens,
                                         device=CUDA_DEVICE)
    packed_qkv = pack_qkv(qkv_in, device=CUDA_DEVICE)

    packed_qkvo = PackedQKVO(packed_qkv, packed_ideal_output)
    no_kv_mmap = None  # No KV cache

    return PhaseTestParameters(packed_qkvo, no_kv_mmap)
def _decoder_attn_setup(
    test_pt: TestPoint,
    test_rsrcs: TestResources,
    block_base_addr: int = 0,
) -> Tuple[QKVInputs, PhaseTestParameters, PhaseTestParameters, int]:
    '''
    Build the inputs and expected outputs of the decoder self-attention
    test.

    "Baseline" synthetic query/key/value tensors are generated with
    key/value sequences as long as the queries (self-attention). Alongside
    them make_qkv() returns:

    * "Prefill" tensors - per this file's original description, the
      baseline with the last value of every sequence masked out; they
      drive the prefill test and populate the KV cache for the decode test
    * "Decode" tensors - *only* the last value of every baseline sequence
      (the complement of the prefill tensors); they drive the decode test,
      which relies on the KV cache filled during prefill

    The baseline tensors are run through the naive reference attention
    with a causal mask. The resulting ideal output is split per sequence
    into a "Prefill" part (the first prefill-length positions) and a
    "Decode" part (the single position after them), which validate the
    prefill and decode results respectively.

    The self-attention KV cache memory mapping (slot mapping and block
    table) is built as well, with the block table starting at
    block_base_addr.

    Arguments:

    * test_pt: TestPoint data structure; the fields read are batch_size,
               num_heads, head_size, block_size and max_dec_seq_len
    * test_rsrcs: TestResources data structure; only scale is read
    * block_base_addr: decoder self-attention block-table base address

    Returns:

    * qkv: Unpacked (batch_size x padded_seq_len x num_heads x
           head_size) query/key/value tensors
    * Prefill-phase decoder self-attention PhaseTestParameters: packed
      (number_of_tokens x num_heads x head_size) query/key/value tensors,
      the packed ideal output, and prefill-phase memory-mapping structures
    * Decode-phase decoder self-attention PhaseTestParameters: packed
      (number_of_tokens x num_heads x head_size) query/key/value tensors,
      the packed ideal output, and decode-phase memory-mapping structures
    * max_block_idx: max physical address in the decoder self-attention
                     block-table (meant to serve as base address for the
                     encoder/decoder cross-attention block-table, which
                     is not built here)
    '''

    batch_size = test_pt.batch_size

    # Self-attention: keys/values are as long as the decoder queries
    max_q_seq_len = test_pt.max_dec_seq_len
    max_kv_seq_len = max_q_seq_len

    # Baseline, prefill-phase and decode-phase test tensors
    qkv, prefill_qkv, decode_qkv = make_qkv(batch_size,
                                            max_q_seq_len,
                                            max_kv_seq_len,
                                            test_pt.num_heads,
                                            test_pt.head_size,
                                            attn_type=AttentionType.DECODER,
                                            device=CUDA_DEVICE)

    # Expected answer from the naive reference implementation, restricted
    # by a causal attention mask
    causal_mask = make_causal_mask(max_q_seq_len,
                                   max_kv_seq_len).to(CUDA_DEVICE)
    ideal_output = ref_masked_attention(qkv.query,
                                        qkv.key,
                                        qkv.value,
                                        scale=test_rsrcs.scale,
                                        custom_mask=causal_mask,
                                        q_seq_lens=qkv.q_seq_lens,
                                        kv_seq_lens=qkv.kv_seq_lens)

    # Carve the ideal answer into its prefill- and decode-phase parts
    prefill_ideal_output = torch.zeros_like(ideal_output)
    decode_ideal_output = torch.zeros_like(ideal_output[:, 0:1])
    for seq_idx, num_prefill in enumerate(prefill_qkv.q_seq_lens):
        prefill_ideal_output[seq_idx, :num_prefill] = ideal_output[
            seq_idx, :num_prefill]
        # The decode token is the one right after the prefill tokens
        decode_ideal_output[seq_idx, :] = ideal_output[
            seq_idx, num_prefill:(num_prefill + 1)]

    prefill_packed_ideal_output, _ = pack_tensor(prefill_ideal_output,
                                                 prefill_qkv.q_seq_lens,
                                                 device=CUDA_DEVICE)
    # Exactly one decoded token per sequence
    decode_packed_ideal_output, _ = pack_tensor(decode_ideal_output,
                                                [1] * batch_size,
                                                device=CUDA_DEVICE)

    # Memory-mapping structures for decoder self-attention, in the format
    # expected by KV caching & the attention kernels (mirroring what
    # ModelRunner produces):
    #
    # Prefill-phase:
    # * Empty block-tables tensor
    # * Slot-mapping with entries for the prompt tokens
    #
    # Decode-phase:
    # * Block-tables tensor with the minimum number of blocks required by
    #   the total number of tokens of all sequences (prefill & decode)
    # * Slot-mapping with entries for the tokens decoded in the current
    #   decode iteration

    prefill_block_tables = make_empty_block_tables_tensor(device=CUDA_DEVICE)

    (
        decode_block_tables,
        slot_mapping_list,
        max_block_idx,
    ) = make_block_tables_slot_mapping(block_size=test_pt.block_size,
                                       seq_lens=qkv.q_seq_lens,
                                       device=CUDA_DEVICE,
                                       block_base_addr=block_base_addr) \
        if False else make_block_tables_slot_mapping(
            test_pt.block_size,
            qkv.q_seq_lens,
            device=CUDA_DEVICE,
            block_base_addr=block_base_addr)

    prefill_slot_mapping, decode_slot_mapping = split_slot_mapping(
        slot_mapping_list, qkv.q_seq_lens, device=CUDA_DEVICE)

    prefill_pckd_qkv = pack_qkv(prefill_qkv, device=CUDA_DEVICE)
    decode_pckd_qkv = pack_qkv(decode_qkv, device=CUDA_DEVICE)

    prefill_test_params = PhaseTestParameters(
        PackedQKVO(prefill_pckd_qkv, prefill_packed_ideal_output),
        KVMemoryMap(prefill_block_tables, prefill_slot_mapping))
    decode_test_params = PhaseTestParameters(
        PackedQKVO(decode_pckd_qkv, decode_packed_ideal_output),
        KVMemoryMap(decode_block_tables, decode_slot_mapping))

    return (qkv, prefill_test_params, decode_test_params, max_block_idx)
def _enc_dec_cross_attn_setup_reuses_query(
    decoder_qkv: QKVInputs,
    encoder_test_params: PhaseTestParameters,
    prefill_decoder_phase_test_params: PhaseTestParameters,
    test_pt: TestPoint,
    test_rsrcs: TestResources,
    block_base_addr: int = 0,
) -> Tuple[PhaseTestParameters, PhaseTestParameters]:
    '''
    Build the inputs and expected outputs of the cross-attention test.

    Only synthetic cross-attention key/value tensors are generated here
    ("baseline" key/value). The queries are assumed to have been
    synthesized for an earlier self-attention test and are reused. The
    key & value sequence lengths are forced to the encoder sequence
    lengths, so they may differ from the query lengths.

    Cross-attention keys & values do not grow during autoregressive
    inference, hence one key/value pair serves both prefill and decode.

    The baseline query (received as an argument) and the new key/value
    tensors are run through the naive reference attention. The resulting
    ideal output is split per sequence into a "Prefill" part (the first
    prefill-length positions) and a "Decode" part (the single position
    after them), which validate the prefill and decode results
    respectively.

    The cross-attention KV cache memory mapping (slot mapping and block
    table) is built as well, with the block table starting at
    block_base_addr.

    Arguments:

    * decoder_qkv: pre-existing unpacked (batch_size x padded_seq_len x
                   num_heads x head_size) decoder self-attention inputs;
                   the query and q_seq_lens fields are read
    * encoder_test_params: PhaseTestParameters used for encoder inference;
                           only the packed q_seq_lens are read, the KV
                           cache field is not used
    * prefill_decoder_phase_test_params: PhaseTestParameters used for
                                         prefill-phase decoder
                                         self-attention; the packed
                                         q_seq_lens are read
    * test_pt: TestPoint data structure; the fields read are batch_size,
               num_heads, head_size, block_size, max_dec_seq_len and
               max_enc_seq_len
    * test_rsrcs: TestResources data structure; only scale is read
    * block_base_addr: base address of the cross-attention block-table
                       built by this function

    Returns:

    * Prefill-phase encoder/decoder cross-attention PhaseTestParameters:
      packed (number_of_tokens x num_heads x head_size) key/value tensors,
      the packed ideal output, and prefill-phase memory-mapping structures
    * Decode-phase encoder/decoder cross-attention PhaseTestParameters:
      no packed tensors (None), the packed ideal output, and decode-phase
      memory-mapping structures
    '''

    encoder_pckd_qkv = encoder_test_params.packed_qkvo.packed_qkv
    assert encoder_pckd_qkv is not None
    prefill_decoder_pckd_qkv = (
        prefill_decoder_phase_test_params.packed_qkvo.packed_qkv)
    assert prefill_decoder_pckd_qkv is not None

    batch_size = test_pt.batch_size

    encoder_seq_lens = encoder_pckd_qkv.q_seq_lens
    prefill_q_seq_lens = prefill_decoder_pckd_qkv.q_seq_lens
    assert prefill_q_seq_lens is not None

    # Cross-attention key/value; their lengths follow the encoder sequences
    cross_kv, _, _ = make_qkv(batch_size,
                              test_pt.max_dec_seq_len,
                              test_pt.max_enc_seq_len,
                              test_pt.num_heads,
                              test_pt.head_size,
                              force_kv_seq_lens=encoder_seq_lens,
                              attn_type=AttentionType.ENCODER_DECODER,
                              device=CUDA_DEVICE)

    # Expected answer: reused decoder queries against the cross key/value
    ideal_output = ref_masked_attention(decoder_qkv.query,
                                        cross_kv.key,
                                        cross_kv.value,
                                        scale=test_rsrcs.scale,
                                        q_seq_lens=decoder_qkv.q_seq_lens,
                                        kv_seq_lens=cross_kv.kv_seq_lens)

    # Carve the ideal answer into its prefill- and decode-phase parts
    prefill_ideal_output = torch.zeros_like(ideal_output)
    decode_ideal_output = torch.zeros_like(ideal_output[:, 0:1])
    for seq_idx, num_prefill in enumerate(prefill_q_seq_lens):
        prefill_ideal_output[seq_idx, :num_prefill] = ideal_output[
            seq_idx, :num_prefill]
        # The decode token is the one right after the prefill tokens
        decode_ideal_output[seq_idx, :] = ideal_output[
            seq_idx, num_prefill:(num_prefill + 1)]

    prefill_packed_ideal_output, _ = pack_tensor(prefill_ideal_output,
                                                 prefill_q_seq_lens,
                                                 device=CUDA_DEVICE)
    # Exactly one decoded token per sequence
    decode_packed_ideal_output, _ = pack_tensor(decode_ideal_output,
                                                [1] * batch_size,
                                                device=CUDA_DEVICE)

    # Memory-mapping structures for encoder/decoder cross-attention, in
    # the format expected by KV caching & the attention kernels (an
    # extension of what ModelRunner produces for decoder-only models).
    #
    # Decoder self-attention relates equal-length Q/K/V sequences that all
    # grow with each decoded token. Cross-attention instead relates the
    # growing Q sequence to fixed-length K and V sequences derived from
    # the encoder hidden states.
    #
    # Prefill-phase:
    # * Empty block-tables tensor
    # * Slot-mapping with one entry per encoder prompt token
    #
    # Decode-phase:
    # * Block-tables tensor with the minimum number of blocks needed for
    #   K & V tensors equal in length to the encoder prompt
    # * Empty slot-mapping tensor (K & V are fixed in size, so newly
    #   decoded tokens are not KV-cached and need no slot-mapping)

    prefill_block_tables = make_empty_block_tables_tensor(device=CUDA_DEVICE)
    decode_slot_mapping = make_empty_slot_mapping_tensor(device=CUDA_DEVICE)

    decode_block_tables, prefill_slot_mapping_list, _ = (
        make_block_tables_slot_mapping(test_pt.block_size,
                                       cross_kv.kv_seq_lens,
                                       block_base_addr=block_base_addr,
                                       device=CUDA_DEVICE))

    prefill_slot_mapping = maybe_make_long_tensor(prefill_slot_mapping_list,
                                                  device=CUDA_DEVICE)

    # Only key/value are packed here; the query is supplied by the caller
    packed_cross_kv = pack_qkv(cross_kv, device=CUDA_DEVICE)

    prefill_test_params = PhaseTestParameters(
        PackedQKVO(packed_cross_kv, prefill_packed_ideal_output),
        KVMemoryMap(prefill_block_tables, prefill_slot_mapping))
    decode_test_params = PhaseTestParameters(
        PackedQKVO(None, decode_packed_ideal_output),
        KVMemoryMap(decode_block_tables, decode_slot_mapping))

    return (prefill_test_params, decode_test_params)
def _run_encoder_attention_test(
    attn: Attention,
    encoder_test_params: PhaseTestParameters,
    attn_metadata: AttentionMetadata,
) -> torch.Tensor:
    '''
    Invoke the Attention wrapper in encoder mode.

    attn.forward() receives attn_type=AttentionType.ENCODER and no KV
    cache (None).

    attn_metadata.num_decode_tokens must be 0, since the encoder does not
    execute during the decode-phase.

    Arguments:

    * attn: Attention wrapper instance
    * encoder_test_params: encoder PhaseTestParameters data structure;
                           the packed (number_of_tokens x num_heads x
                           head_size) query/key/value fields are read
    * attn_metadata: attention metadata for encoder/decoder-self attention

    Returns:

    * Result of Attention.forward() on the packed {query,key,value}
      & attn_metadata
    '''
    assert attn_metadata.num_decode_tokens == 0

    enc_pckd_qkv = encoder_test_params.packed_qkvo.packed_qkv
    assert enc_pckd_qkv is not None

    no_kv_cache = None
    return attn.forward(enc_pckd_qkv.query,
                        enc_pckd_qkv.key,
                        enc_pckd_qkv.value,
                        no_kv_cache,
                        attn_metadata,
                        attn_type=AttentionType.ENCODER)
def _run_decoder_self_attention_test(
    test_rsrcs: TestResources,
    decoder_test_params: PhaseTestParameters,
    attn_metadata: AttentionMetadata,
) -> torch.Tensor:
    '''
    Invoke the Attention wrapper in decoder self-attention mode.

    attn.forward() receives attn_type=AttentionType.DECODER together with
    the shared KV cache.

    Arguments:

    * test_rsrcs: TestResources instance; the kv_cache and attn
                  (Attention wrapper instance) fields are read
    * decoder_test_params: decoder PhaseTestParameters data structure;
                           the packed (number_of_tokens x num_heads x
                           head_size) query/key/value fields are read
    * attn_metadata: attention metadata for decoder-self attention
                     (contains KV cache memory-mapping)

    Returns:

    * Result of Attention.forward() on the packed {query,key,value},
      kv_cache & attn_metadata
    '''
    dec_pckd_qkv = decoder_test_params.packed_qkvo.packed_qkv
    assert dec_pckd_qkv is not None

    return test_rsrcs.attn.forward(dec_pckd_qkv.query,
                                   dec_pckd_qkv.key,
                                   dec_pckd_qkv.value,
                                   test_rsrcs.kv_cache,
                                   attn_metadata,
                                   attn_type=AttentionType.DECODER)
def _run_encoder_decoder_cross_attention_test(
    test_rsrcs: TestResources,
    decoder_test_params: PhaseTestParameters,
    cross_test_params: Optional[PhaseTestParameters],
    attn_metadata: AttentionMetadata,
) -> torch.Tensor:
    '''
    Invoke the Attention wrapper in encoder/decoder cross-attention mode.

    The query is the one used for decoder self-attention; key/value are
    specific to cross-attention.

    When cross_test_params is None, or its packed_qkvo.packed_qkv is None,
    key and value are passed as None. This reflects that the key and value
    tensors do not grow in decode-phase cross-attention.

    attn.forward() receives attn_type=AttentionType.ENCODER_DECODER
    together with the shared KV cache.

    Arguments:

    * test_rsrcs: TestResources instance; the kv_cache and attn
                  (Attention wrapper instance) fields are read
    * decoder_test_params: decoder PhaseTestParameters data structure;
                           only the packed (number_of_tokens x num_heads
                           x head_size) query field is read
    * cross_test_params: encoder/decoder PhaseTestParameters data
                         structure; only the packed (number_of_tokens x
                         num_heads x head_size) key/value fields are read
    * attn_metadata: attention metadata for encoder/decoder-self attention

    Returns:

    * Result of Attention.forward() on the packed {query,key,value},
      kv_cache & attn_metadata
    '''
    dec_pckd_qkv = decoder_test_params.packed_qkvo.packed_qkv
    assert dec_pckd_qkv is not None

    attn = test_rsrcs.attn
    kv_cache = test_rsrcs.kv_cache

    # Resolve the optional cross-attention key/value pair
    cross_pckd_qkv = None
    if cross_test_params is not None:
        cross_pckd_qkv = cross_test_params.packed_qkvo.packed_qkv

    if cross_pckd_qkv is None:
        key = None
        value = None
    else:
        key = cross_pckd_qkv.key
        value = cross_pckd_qkv.value

    return attn.forward(dec_pckd_qkv.query,
                        key,
                        value,
                        kv_cache,
                        attn_metadata,
                        attn_type=AttentionType.ENCODER_DECODER)
@pytest.mark.skipif(is_hip(), reason=STR_NOT_IMPL_ENC_DEC_ROCM_HIP)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("backend_name", BACKEND_NAMES)
@pytest.mark.parametrize("batch_size", BATCH_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("max_dec_seq_len", MAX_DEC_SEQ_LENS)
@pytest.mark.parametrize("max_enc_seq_len", MAX_ENC_SEQ_LENS)
def test_encoder_only(num_heads: int, head_size: int, backend_name: str,
                      batch_size: int, block_size: int, max_dec_seq_len: int,
                      max_enc_seq_len: int, monkeypatch):
    '''
    Encoder attention test: run prefill-phase encoder attention through
    the Attention wrapper and compare the packed result against the ideal
    output of the naive reference implementation.

    Skipped when is_hip() is true.
    '''

    # Make the Attention wrapper use the backend under test
    override_backend_env_variable(monkeypatch, backend_name)

    # The KV cache size of 4096 blocks is arbitrary & deliberately larger
    # than needed; exceeding the KV cache is outside the scope of this test
    test_pt = TestPoint(num_heads=num_heads,
                        head_size=head_size,
                        backend_name=backend_name,
                        batch_size=batch_size,
                        block_size=block_size,
                        max_dec_seq_len=max_dec_seq_len,
                        max_enc_seq_len=max_enc_seq_len,
                        num_blocks=4096)

    # Scale factor, backend instance, Attention wrapper & KV cache
    test_rsrcs = _make_test_resources(test_pt)

    # Encoder inputs & ideal output (encoder runs during prefill only)
    enc_test_params = _encoder_attn_setup(test_pt, test_rsrcs)

    # Prefill-phase metadata built from the encoder parameters alone
    prephase_attn_metadata: AttentionMetadata = make_test_metadata(
        test_rsrcs.attn_backend,
        True,
        None,
        decoder_test_params=None,
        encoder_test_params=enc_test_params,
        cross_test_params=None,
        device=CUDA_DEVICE)

    # PREFILL: encoder attention
    enc_pckd_act_out: torch.Tensor = _run_encoder_attention_test(
        test_rsrcs.attn, enc_test_params, prephase_attn_metadata)

    # Does the actual encoder output match the ideal one?
    assert_actual_matches_ideal(enc_test_params, enc_pckd_act_out)
@pytest.mark.skipif(is_hip(), reason=STR_NOT_IMPL_ENC_DEC_ROCM_HIP)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("backend_name", BACKEND_NAMES)
@pytest.mark.parametrize("batch_size", BATCH_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("max_dec_seq_len", MAX_DEC_SEQ_LENS)
@pytest.mark.parametrize("max_enc_seq_len", MAX_ENC_SEQ_LENS)
def test_e2e_enc_dec_attn(
    num_heads: int,
    head_size: int,
    backend_name: str,
    batch_size: int,
    block_size: int,
    max_dec_seq_len: int,
    max_enc_seq_len: int,
    monkeypatch,
) -> None:
    '''
    End-to-end encoder/decoder attention test.

    Steps:

    * Build fake test vectors for (1) encoder attention, (2) decoder
      self-attention, and (3) encoder/decoder cross-attention
    * Build one attention metadata structure with self- and cross-attention
      attributes for the prefill phase, and an analogous one for the
      decode phase
    * Run the attention steps in this order:

      * Encoder attention
      * Prefill self-attention
      * Prefill cross-attention
      * Decode self-attention
      * Decode cross-attention

      This order mirrors realistic use and would also exacerbate any
      accidental overlap between the self- and cross-attention block
      tables, which one hopes to avoid
    * Check every output against the ideal reference attention
      implementation

    Block tables are constructed such that the cross-attention KV cache
    lives in a higher, non-intersecting address-space than the
    self-attention KV cache.

    Self- and cross-attention share the same query tensor but not the K/V
    tensors. Self-attention K/Vs must have the same seq len as Q, while
    cross-attention K/Vs may differ in seq len, as is often the case for
    cross-attention.

    The attention backend is forced through an environment variable using
    PyTest monkey patching.

    Note on ROCm/HIP: encoder/decoder models are currently not supported on
    AMD GPUs, so this test is skipped if is_hip().

    Note on metadata: a single attention metadata structure is shared by
    all prefill-phase attention operations (encoder, decoder, enc/dec
    cross), and a single one by all decode-phase operations (decoder &
    enc/dec cross). This reflects the behavior of ModelRunner, which
    constructs one attention metadata structure per prefill or decode run.
    A realistic scenario relies on the attention backend to pick the
    appropriate metadata fields according to
    attn_metadata.attention_type, so this test confirms that the
    backend-under-test can handle a shared prefill metadata structure and a
    shared decode metadata structure.
    '''

    # Pin the Attention wrapper to the backend under test
    override_backend_env_variable(monkeypatch, backend_name)

    # Note: KV cache size of 4096 is arbitrary & chosen intentionally
    # to be more than necessary, since exceeding the kv cache size
    # is not part of this test
    kv_cache_size = 4096
    point = TestPoint(num_heads, head_size, backend_name, batch_size,
                      block_size, max_dec_seq_len, max_enc_seq_len,
                      kv_cache_size)

    # Attention scale factor, attention backend instance, attention wrapper
    # instance, KV cache init
    resources = _make_test_resources(point)

    # Encoder attention test params (only used during prefill)
    encoder_params = _encoder_attn_setup(point, resources)

    # Decoder self-attention prefill-phase & decode-phase test params,
    # including query/key/value tensors and the decoder self-attention
    # memory-mapping. cross_base_addr is the uppermost address in the
    # decoder self-attention block-table, i.e. a base address which the
    # encoder/decoder cross-attention block-table may build downward toward.
    decoder_setup = _decoder_attn_setup(point, resources)
    (
        decoder_qkv,
        prefill_dec_params,
        decode_dec_params,
        cross_base_addr,
    ) = decoder_setup

    # Encoder/decoder cross-attention prefill-phase & decode-phase test
    # params, including key/value tensors and cross-attention memory-mapping
    cross_setup = _enc_dec_cross_attn_setup_reuses_query(
        decoder_qkv,
        encoder_params,
        prefill_dec_params,
        point,
        resources,
        block_base_addr=cross_base_addr)
    (
        prefill_cross_params,
        decode_cross_params,
    ) = cross_setup

    # Shared prefill metadata structure
    assert prefill_dec_params.packed_qkvo.packed_qkv is not None
    prefill_metadata: AttentionMetadata = make_test_metadata(
        resources.attn_backend,
        True,
        prefill_dec_params.packed_qkvo.packed_qkv.q_seq_lens,
        decoder_test_params=prefill_dec_params,
        encoder_test_params=encoder_params,
        cross_test_params=prefill_cross_params,
        device=CUDA_DEVICE)

    # PREFILL: encoder attention
    encoder_actual = _run_encoder_attention_test(resources.attn,
                                                 encoder_params,
                                                 prefill_metadata)

    # - Is encoder attention result correct?
    assert_actual_matches_ideal(encoder_params, encoder_actual)

    # PREFILL: decoder self-attention test
    prefill_dec_actual = _run_decoder_self_attention_test(
        resources, prefill_dec_params, prefill_metadata)

    # - Is prefill decoder self-attention correct?
    assert_actual_matches_ideal(prefill_dec_params, prefill_dec_actual)

    # PREFILL: encoder/decoder cross-attention test
    prefill_cross_actual = _run_encoder_decoder_cross_attention_test(
        resources, prefill_dec_params, prefill_cross_params,
        prefill_metadata)

    # - Is prefill encoder/decoder cross-attention correct?
    assert_actual_matches_ideal(prefill_cross_params, prefill_cross_actual)

    # DECODE: build decode-phase attention metadata
    decode_metadata: AttentionMetadata = make_test_metadata(
        resources.attn_backend,
        False,
        decoder_qkv.q_seq_lens,
        decoder_test_params=decode_dec_params,
        encoder_test_params=encoder_params,
        cross_test_params=decode_cross_params,
        device=CUDA_DEVICE)

    # DECODE: decoder self-attention test
    decode_dec_actual = _run_decoder_self_attention_test(
        resources, decode_dec_params, decode_metadata)

    # - Is decode-phase decoder self-attention correct?
    assert_actual_matches_ideal(decode_dec_params, decode_dec_actual)

    # DECODE: encoder/decoder cross-attention test. No cross-attention test
    # params are passed here (None), only the decode-phase metadata.
    decode_cross_actual = _run_encoder_decoder_cross_attention_test(
        resources, decode_dec_params, None, decode_metadata)

    # - Is decode-phase encoder/decoder cross-attention correct?
    assert_actual_matches_ideal(decode_cross_params, decode_cross_actual)
tests/kernels/test_flash_attn.py
View file @
e7c1b7f3
...
...
@@ -20,12 +20,13 @@ def ref_paged_attn(
block_tables
:
torch
.
Tensor
,
scale
:
float
,
sliding_window
:
Optional
[
int
]
=
None
,
soft_cap
:
Optional
[
float
]
=
None
,
)
->
torch
.
Tensor
:
num_seqs
=
len
(
query_lens
)
block_tables
=
block_tables
.
cpu
().
numpy
()
_
,
block_size
,
num_kv_heads
,
head_size
=
key_cache
.
shape
outputs
=
[]
outputs
:
List
[
torch
.
Tensor
]
=
[]
start_idx
=
0
for
i
in
range
(
num_seqs
):
query_len
=
query_lens
[
i
]
...
...
@@ -53,6 +54,8 @@ def ref_paged_attn(
(
query_len
+
sliding_window
)
+
1
).
bool
().
logical_not
()
mask
|=
sliding_window_mask
if
soft_cap
is
not
None
:
attn
=
soft_cap
*
torch
.
tanh
(
attn
/
soft_cap
)
attn
.
masked_fill_
(
mask
,
float
(
"-inf"
))
attn
=
torch
.
softmax
(
attn
,
dim
=-
1
).
to
(
v
.
dtype
)
out
=
torch
.
einsum
(
"hqk,khd->qhd"
,
attn
,
v
)
...
...
@@ -68,13 +71,15 @@ def ref_paged_attn(
@
pytest
.
mark
.
parametrize
(
"head_size"
,
HEAD_SIZES
)
@
pytest
.
mark
.
parametrize
(
"block_size"
,
BLOCK_SIZES
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPES
)
@
torch
.
inference_mode
@
pytest
.
mark
.
parametrize
(
"soft_cap"
,
[
None
,
10.0
,
50.0
])
@
torch
.
inference_mode
()
def
test_flash_attn_with_paged_kv
(
kv_lens
:
List
[
Tuple
[
int
,
int
]
],
kv_lens
:
List
[
int
],
num_heads
:
Tuple
[
int
,
int
],
head_size
:
int
,
dtype
:
torch
.
dtype
,
block_size
:
int
,
soft_cap
:
Optional
[
float
],
)
->
None
:
torch
.
set_default_device
(
"cuda"
)
torch
.
cuda
.
manual_seed_all
(
0
)
...
...
@@ -108,6 +113,7 @@ def test_flash_attn_with_paged_kv(
causal
=
True
,
block_table
=
block_tables
,
cache_seqlens
=
kv_lens_tensor
,
softcap
=
soft_cap
if
soft_cap
is
not
None
else
0
,
).
squeeze
(
1
)
ref_output
=
ref_paged_attn
(
...
...
@@ -118,6 +124,7 @@ def test_flash_attn_with_paged_kv(
kv_lens
=
kv_lens
,
block_tables
=
block_tables
,
scale
=
scale
,
soft_cap
=
soft_cap
,
)
assert
torch
.
allclose
(
output
,
ref_output
,
atol
=
1e-2
,
rtol
=
1e-2
),
\
f
"
{
torch
.
max
(
torch
.
abs
(
output
-
ref_output
))
}
"
...
...
@@ -129,7 +136,8 @@ def test_flash_attn_with_paged_kv(
@
pytest
.
mark
.
parametrize
(
"block_size"
,
BLOCK_SIZES
)
@
pytest
.
mark
.
parametrize
(
"sliding_window"
,
[
None
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPES
)
@
torch
.
inference_mode
@
pytest
.
mark
.
parametrize
(
"soft_cap"
,
[
None
,
10.0
,
50.0
])
@
torch
.
inference_mode
()
def
test_varlen_with_paged_kv
(
seq_lens
:
List
[
Tuple
[
int
,
int
]],
num_heads
:
Tuple
[
int
,
int
],
...
...
@@ -137,6 +145,7 @@ def test_varlen_with_paged_kv(
sliding_window
:
Optional
[
int
],
dtype
:
torch
.
dtype
,
block_size
:
int
,
soft_cap
:
Optional
[
float
],
)
->
None
:
torch
.
set_default_device
(
"cuda"
)
torch
.
cuda
.
manual_seed_all
(
0
)
...
...
@@ -163,10 +172,6 @@ def test_varlen_with_paged_kv(
head_size
,
dtype
=
dtype
)
value_cache
=
torch
.
randn_like
(
key_cache
)
# Normalize the scale of the key and value caches to mitigate
# numerical instability.
key_cache
/=
head_size
**
0.5
value_cache
/=
head_size
**
0.5
cu_query_lens
=
torch
.
tensor
([
0
]
+
query_lens
,
dtype
=
torch
.
int32
).
cumsum
(
dim
=
0
,
dtype
=
torch
.
int32
)
...
...
@@ -192,6 +197,7 @@ def test_varlen_with_paged_kv(
causal
=
True
,
window_size
=
window_size
,
block_table
=
block_tables
,
softcap
=
soft_cap
if
soft_cap
is
not
None
else
0
,
)
ref_output
=
ref_paged_attn
(
...
...
@@ -203,6 +209,7 @@ def test_varlen_with_paged_kv(
block_tables
=
block_tables
,
scale
=
scale
,
sliding_window
=
sliding_window
,
soft_cap
=
soft_cap
,
)
assert
torch
.
allclose
(
output
,
ref_output
,
atol
=
1e-2
,
rtol
=
1e-2
),
\
f
"
{
torch
.
max
(
torch
.
abs
(
output
-
ref_output
))
}
"
tests/kernels/test_flashinfer.py
0 → 100644
View file @
e7c1b7f3
from
typing
import
List
,
Optional
,
Tuple
import
flashinfer
import
pytest
import
torch
# Parametrization values for the FlashInfer paged-KV tests in this file.
# Each NUM_HEADS entry is a (num_query_heads, num_kv_heads) pair; the tests
# assert that the first is a multiple of the second.
NUM_HEADS = [(16, 16), (32, 8), (64, 8)]
HEAD_SIZES = [128, 256]
BLOCK_SIZES = [16, 32]
DTYPES = [torch.float16, torch.bfloat16]
NUM_BLOCKS = 32768  # Large enough to test overflow in index calculation.
def ref_paged_attn(
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    query_lens: List[int],
    kv_lens: List[int],
    block_tables: torch.Tensor,
    scale: float,
    sliding_window: Optional[int] = None,
    soft_cap: Optional[float] = None,
) -> torch.Tensor:
    """Reference (pure PyTorch) attention over a paged KV cache.

    Sequences are processed one at a time. For each sequence the keys and
    values are gathered from the cache blocks listed in its block-table row,
    a causal mask (plus an optional sliding-window mask) is applied, and a
    softmax-weighted sum of the values is computed.

    Args:
        query: packed queries, indexed as (token, head, head_dim); the
            tokens of all sequences are concatenated along dim 0.
        key_cache: key blocks, unpacked as
            (num_blocks, block_size, num_kv_heads, head_size).
        value_cache: value blocks, gathered and reshaped like ``key_cache``.
        query_lens: number of query tokens of each sequence.
        kv_lens: number of cached key/value tokens of each sequence.
        block_tables: 2-D tensor; row ``i`` holds the cache block indices
            of sequence ``i``.
        scale: factor the queries are multiplied by before the Q.K product.
        sliding_window: if given, keys lying too far before a query
            position are masked out as well.
        soft_cap: if given, logits become ``soft_cap * tanh(logits /
            soft_cap)`` before masking.

    Returns:
        The per-sequence outputs concatenated along dim 0.

    The caller's ``query`` tensor is left unmodified.
    """
    block_tables = block_tables.cpu().numpy()
    _, block_size, num_kv_heads, head_size = key_cache.shape

    outputs: List[torch.Tensor] = []
    start_idx = 0
    for i, query_len in enumerate(query_lens):
        kv_len = kv_lens[i]

        # Out-of-place multiply: the slice is a view of `query`, so an
        # in-place `*=` would silently rescale the caller's tensor.
        q = query[start_idx:start_idx + query_len] * scale

        # Gather just enough blocks to cover kv_len tokens, flatten them to
        # (tokens, kv_heads, head_size) and drop the padding in the last
        # block.
        num_kv_blocks = (kv_len + block_size - 1) // block_size
        block_indices = block_tables[i, :num_kv_blocks]

        k = key_cache[block_indices].view(-1, num_kv_heads, head_size)
        k = k[:kv_len]
        v = value_cache[block_indices].view(-1, num_kv_heads, head_size)
        v = v[:kv_len]

        # Fewer KV heads than query heads: repeat each KV head so the head
        # dimensions line up.
        if q.shape[1] != k.shape[1]:
            k = torch.repeat_interleave(k, q.shape[1] // k.shape[1], dim=1)
            v = torch.repeat_interleave(v, q.shape[1] // v.shape[1], dim=1)

        # Logits are computed in float32, laid out as (head, query, key).
        attn = torch.einsum("qhd,khd->hqk", q, k).float()

        # Causal mask: the queries are aligned with the END of the KV
        # sequence, so True (masked) marks keys after the query position.
        empty_mask = torch.ones(query_len, kv_len)
        mask = torch.triu(empty_mask, diagonal=kv_len - query_len + 1).bool()
        if sliding_window is not None:
            # Additionally mask keys that fall before the window start.
            sliding_window_mask = torch.triu(
                empty_mask,
                diagonal=kv_len - (query_len + sliding_window) +
                1).bool().logical_not()
            mask |= sliding_window_mask
        if soft_cap is not None:
            attn = soft_cap * torch.tanh(attn / soft_cap)
        attn.masked_fill_(mask, float("-inf"))
        attn = torch.softmax(attn, dim=-1).to(v.dtype)
        out = torch.einsum("hqk,khd->qhd", attn, v)

        outputs.append(out)
        start_idx += query_len

    return torch.cat(outputs, dim=0)
@pytest.mark.parametrize("kv_lens", [[1328, 18, 463], [1, 54, 293, 70]])
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("soft_cap", [None, 30.0, 50.0])
@torch.inference_mode()
def test_flashinfer_decode_with_paged_kv(kv_lens: List[int],
                                         num_heads: Tuple[int, int],
                                         head_size: int, dtype: torch.dtype,
                                         block_size: int,
                                         soft_cap: Optional[float]) -> None:
    """Compare FlashInfer batched decode against ``ref_paged_attn``.

    One query token is generated per sequence, random block tables are
    drawn over a combined key/value cache, and the output of
    ``flashinfer.BatchDecodeWithPagedKVCacheWrapper`` is compared with the
    reference at ``atol = rtol = 1e-2``.

    Args:
        kv_lens: cached key/value length of each sequence (must be > 0).
        num_heads: (num_query_heads, num_kv_heads); the first must be a
            multiple of the second.
        head_size: per-head hidden size.
        dtype: dtype of the query and of the KV cache.
        block_size: number of tokens per cache block.
        soft_cap: optional logits soft cap forwarded to both
            implementations.
    """
    torch.set_default_device("cuda")
    torch.cuda.manual_seed_all(0)
    num_seqs = len(kv_lens)
    num_query_heads, num_kv_heads = num_heads
    assert num_query_heads % num_kv_heads == 0
    max_kv_len = max(kv_lens)
    scale = head_size**-0.5

    query = torch.randn(num_seqs, num_query_heads, head_size, dtype=dtype)
    # Dim 1 selects keys (index 0) or values (index 1) inside one shared
    # buffer; key_cache / value_cache below are views into it.
    key_value_cache = torch.randn(NUM_BLOCKS,
                                  2,
                                  block_size,
                                  num_kv_heads,
                                  head_size,
                                  dtype=dtype)
    key_cache = key_value_cache[:, 0, :, :, :].squeeze(1)
    value_cache = key_value_cache[:, 1, :, :, :].squeeze(1)

    max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size
    block_tables = torch.randint(0,
                                 NUM_BLOCKS,
                                 (num_seqs, max_num_blocks_per_seq),
                                 dtype=torch.int32)

    # Flatten the per-sequence block lists into an (indptr, indices) pair
    # and record how many slots of each sequence's last block are in use.
    kv_indptr_list = [0]
    kv_indices_list: List[int] = []
    kv_last_page_len_list: List[int] = []
    for i, seq_len in enumerate(kv_lens):
        assert seq_len > 0
        num_blocks = (seq_len + block_size - 1) // block_size
        # .tolist() moves the row to plain Python ints in one transfer
        # instead of collecting 0-dim device tensors one by one.
        kv_indices_list.extend(block_tables[i, :num_blocks].tolist())
        kv_indptr_list.append(kv_indptr_list[-1] + num_blocks)
        # A completely filled last block is reported as block_size, not 0.
        kv_last_page_len_list.append(seq_len % block_size or block_size)

    kv_indptr = torch.tensor(kv_indptr_list, dtype=torch.int32)
    kv_indices = torch.tensor(kv_indices_list, dtype=torch.int32)
    kv_last_page_lens = torch.tensor(kv_last_page_len_list,
                                     dtype=torch.int32)

    workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8)
    wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
        workspace_buffer, "NHD")
    wrapper.begin_forward(kv_indptr,
                          kv_indices,
                          kv_last_page_lens,
                          num_query_heads,
                          num_kv_heads,
                          head_size,
                          block_size,
                          "NONE",
                          data_type=dtype)

    output = wrapper.forward(query, key_value_cache, logits_soft_cap=soft_cap)

    ref_output = ref_paged_attn(query=query,
                                key_cache=key_cache,
                                value_cache=value_cache,
                                query_lens=[1] * num_seqs,
                                kv_lens=kv_lens,
                                block_tables=block_tables,
                                scale=scale,
                                soft_cap=soft_cap)
    assert torch.allclose(output, ref_output, atol=1e-2, rtol=1e-2), \
        f"{torch.max(torch.abs(output - ref_output))}"
@pytest.mark.parametrize("seq_lens", [[(1, 1328), (5, 18), (129, 463)]])
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("soft_cap", [None, 30.0, 50.0])
@torch.inference_mode()
def test_flashinfer_prefill_with_paged_kv(seq_lens: List[Tuple[int, int]],
                                          num_heads: Tuple[int, int],
                                          head_size: int, dtype: torch.dtype,
                                          block_size: int,
                                          soft_cap: Optional[float]) -> None:
    """Compare FlashInfer batched prefill against ``ref_paged_attn``.

    Variable-length queries are packed along dim 0, random block tables are
    drawn over a combined key/value cache, and the output of
    ``flashinfer.BatchPrefillWithPagedKVCacheWrapper`` is compared with the
    reference at ``atol = rtol = 1e-2``.

    Args:
        seq_lens: one (query_len, kv_len) pair per sequence; every kv_len
            must be > 0.
        num_heads: (num_query_heads, num_kv_heads); the first must be a
            multiple of the second.
        head_size: per-head hidden size.
        dtype: dtype of the query and of the KV cache.
        block_size: number of tokens per cache block.
        soft_cap: optional logits soft cap forwarded to both
            implementations.
    """
    torch.set_default_device("cuda")
    torch.cuda.manual_seed_all(0)
    num_seqs = len(seq_lens)
    query_lens = [pair[0] for pair in seq_lens]
    kv_lens = [pair[1] for pair in seq_lens]
    num_query_heads, num_kv_heads = num_heads
    assert num_query_heads % num_kv_heads == 0
    max_kv_len = max(kv_lens)
    scale = head_size**-0.5

    query = torch.randn(sum(query_lens),
                        num_query_heads,
                        head_size,
                        dtype=dtype)
    # Dim 1 selects keys (index 0) or values (index 1) inside one shared
    # buffer; key_cache / value_cache below are views into it.
    key_value_cache = torch.randn(NUM_BLOCKS,
                                  2,
                                  block_size,
                                  num_kv_heads,
                                  head_size,
                                  dtype=dtype)
    key_cache = key_value_cache[:, 0, :, :, :].squeeze(1)
    value_cache = key_value_cache[:, 1, :, :, :].squeeze(1)

    # Normalize the scale of the key and value caches to mitigate
    # numerical instability. The division is deliberately in place: the
    # views write through to key_value_cache, which the kernel reads.
    key_cache /= head_size**0.5
    value_cache /= head_size**0.5

    max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size
    block_tables = torch.randint(0,
                                 NUM_BLOCKS,
                                 (num_seqs, max_num_blocks_per_seq),
                                 dtype=torch.int32)

    # Flatten the per-sequence block lists into an (indptr, indices) pair,
    # record how many slots of each last block are in use, and build the
    # matching indptr over the packed query tokens.
    qo_indptr_list = [0]
    kv_indptr_list = [0]
    kv_indices_list: List[int] = []
    kv_last_page_len_list: List[int] = []
    for i, seq_len in enumerate(kv_lens):
        assert seq_len > 0
        num_blocks = (seq_len + block_size - 1) // block_size
        # .tolist() moves the row to plain Python ints in one transfer
        # instead of collecting 0-dim device tensors one by one.
        kv_indices_list.extend(block_tables[i, :num_blocks].tolist())
        kv_indptr_list.append(kv_indptr_list[-1] + num_blocks)
        # A completely filled last block is reported as block_size, not 0.
        kv_last_page_len_list.append(seq_len % block_size or block_size)
        qo_indptr_list.append(qo_indptr_list[-1] + query_lens[i])

    qo_indptr = torch.tensor(qo_indptr_list, dtype=torch.int32)
    kv_indptr = torch.tensor(kv_indptr_list, dtype=torch.int32)
    kv_indices = torch.tensor(kv_indices_list, dtype=torch.int32)
    kv_last_page_lens = torch.tensor(kv_last_page_len_list,
                                     dtype=torch.int32)

    workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8)
    wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
        workspace_buffer, "NHD")
    wrapper.begin_forward(
        qo_indptr,
        kv_indptr,
        kv_indices,
        kv_last_page_lens,
        num_query_heads,
        num_kv_heads,
        head_size,
        block_size,
    )

    output = wrapper.forward(
        query,
        key_value_cache,
        logits_soft_cap=soft_cap,
    )

    ref_output = ref_paged_attn(query=query,
                                key_cache=key_cache,
                                value_cache=value_cache,
                                query_lens=query_lens,
                                kv_lens=kv_lens,
                                block_tables=block_tables,
                                scale=scale,
                                soft_cap=soft_cap)
    assert torch.allclose(output, ref_output, atol=1e-2, rtol=1e-2), \
        f"{torch.max(torch.abs(output - ref_output))}"
tests/kernels/test_fp8_quant.py
0 → 100644
View file @
e7c1b7f3
import
pytest
import
torch
import
vllm._custom_ops
as
ops
from
tests.kernels.quant_utils
import
(
ref_dynamic_per_tensor_fp8_quant
,
ref_dynamic_per_token_quant
)
# Parametrization values for the FP8 quantization tests in this file.
DTYPES = [torch.half, torch.bfloat16, torch.float]
HIDDEN_SIZES = [1, 2, 3, 4, 16, 67, 768, 2048, 5120, 5137, 8192,
                8193]  # Arbitrary values for testing
HIDDEN_SIZES += list(range(1024, 1033))  # vectorized conversion edge cases
NUM_TOKENS = [1, 7, 83, 4096]  # Arbitrary values for testing
# Whether the per-token test passes an upper bound for the scales.
SCALE_UBS = [True, False]
SEEDS = [0]
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("scale_ub", SCALE_UBS)
@pytest.mark.parametrize("seed", SEEDS)
@torch.inference_mode()
def test_dynamic_per_token_fp8_quant(num_tokens: int, hidden_size: int,
                                     dtype: torch.dtype, scale_ub: bool,
                                     seed: int) -> None:
    """Check dynamic per-token FP8 quantization against the reference.

    A random (num_tokens, hidden_size) activation is quantized by both
    ``ref_dynamic_per_token_quant`` and ``ops.scaled_fp8_quant`` (with
    ``use_per_token_if_dynamic=True``); the scales and the outputs
    (upconverted to float32) must match. When ``scale_ub`` is true, the
    mean of the input is passed to both as the scale upper bound.
    """
    torch.random.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    activations = torch.rand(num_tokens,
                             hidden_size,
                             dtype=dtype,
                             device="cuda")
    x = activations + 1e-6  # avoid nans

    # Optional upper bound handed to both implementations.
    if scale_ub:
        upper_bound = torch.mean(x).to(dtype=torch.float32, device='cuda')
    else:
        upper_bound = None

    ref_out, ref_scales = ref_dynamic_per_token_quant(x, torch.float8_e4m3fn,
                                                      upper_bound)
    ops_out, ops_scales = ops.scaled_fp8_quant(x,
                                               scale_ub=upper_bound,
                                               use_per_token_if_dynamic=True)

    assert torch.allclose(ref_scales, ops_scales)

    # Compare the quantized outputs after upconverting them to float32.
    ref_as_f32 = ref_out.to(dtype=torch.float32)
    ops_as_f32 = ops_out.to(dtype=torch.float32)
    assert torch.allclose(ref_as_f32, ops_as_f32)
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@torch.inference_mode()
def test_dynamic_per_tensor_fp8_quant(num_tokens: int, hidden_size: int,
                                      dtype: torch.dtype, seed: int) -> None:
    """Check dynamic per-tensor FP8 quantization against the reference.

    A random (num_tokens, hidden_size) activation is quantized by both
    ``ref_dynamic_per_tensor_fp8_quant`` and ``ops.scaled_fp8_quant``; the
    scales and the outputs (upconverted to float32) must match.
    """
    torch.random.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda")

    reference = ref_dynamic_per_tensor_fp8_quant(x)
    ref_out, ref_scale = reference
    kernel = ops.scaled_fp8_quant(x)
    ops_out, ops_scale = kernel

    assert torch.allclose(ref_scale, ops_scale)

    # Compare the quantized outputs after upconverting them to float32.
    ref_as_f32 = ref_out.to(dtype=torch.float32)
    ops_as_f32 = ops_out.to(dtype=torch.float32)
    assert torch.allclose(ref_as_f32, ops_as_f32)
# Regression test for a case with large activations where an int32 index cannot
# represent the number of elements.
@torch.inference_mode()
@pytest.mark.parametrize("seed", SEEDS)
def test_fp8_quant_large(seed: int) -> None:
    """Quantize a very large activation with a precomputed scale.

    The scale comes from the reference per-tensor quantization and is then
    handed to ``ops.scaled_fp8_quant``; both outputs must agree once they
    are converted back to the input dtype.
    """
    torch.random.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    num_tokens = 1024000  # Mistral-Nemo's max_position_embeddings
    hidden_size = 1152  # Smallest hidden_size to reproduce the error
    dtype = torch.bfloat16

    x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda")
    ref_out, scale = ref_dynamic_per_tensor_fp8_quant(x)
    ops_out, _ = ops.scaled_fp8_quant(x, scale)

    # Minimize memory footprint in this test by freeing x and upconverting
    # the outputs in place. (torch.allclose does not support fp8)
    del x
    ref_out = ref_out.to(dtype=dtype)
    ops_out = ops_out.to(dtype=dtype)

    outputs_match = torch.allclose(ref_out, ops_out)
    assert outputs_match
tests/kernels/test_int8_quant.py
View file @
e7c1b7f3
import
pytest
import
torch
# ruff: noqa: F401
import
vllm._C
from
tests.kernels.quant_utils
import
ref_dynamic_per_token_quant
from
vllm._custom_ops
import
scaled_int8_quant
DTYPES
=
[
torch
.
half
,
torch
.
bfloat16
,
torch
.
float
]
HIDDEN_SIZES
=
[
16
,
67
,
768
,
2048
,
5120
,
5137
,
8192
,
...
...
@@ -21,23 +21,16 @@ def test_dynamic_scaled_int8_quant(num_tokens: int, hidden_size: int,
dtype
:
torch
.
dtype
,
seed
:
int
)
->
None
:
torch
.
random
.
manual_seed
(
seed
)
torch
.
cuda
.
manual_seed
(
seed
)
int8_traits
=
torch
.
iinfo
(
torch
.
int8
)
x
=
torch
.
rand
(
num_tokens
,
hidden_size
,
dtype
=
dtype
,
device
=
"cuda"
)
*
1000
x_token_max
,
_
=
x
.
max
(
dim
=
1
)
x_token_max
=
x_token_max
.
to
(
dtype
=
torch
.
float32
)
scales
=
(
x_token_max
/
float
(
127.0
))[:,
None
].
to
(
device
=
"cuda"
,
dtype
=
torch
.
float32
)
torch_out
=
(
x
/
scales
).
round
().
clamp
(
int8_traits
.
min
,
int8_traits
.
max
).
to
(
torch
.
int8
)
ops_out
=
torch
.
empty_like
(
x
,
dtype
=
torch
.
int8
,
device
=
"cuda"
)
scales_out
=
torch
.
empty_like
(
scales
,
dtype
=
torch
.
float32
,
device
=
"cuda"
)
torch
.
ops
.
_C
.
dynamic_scaled_int8_quant
(
ops_out
,
x
,
scales_out
)
# reference
ref_out
,
ref_scales
=
ref_dynamic_per_token_quant
(
x
,
torch
.
int8
)
# kernel
ops_out
,
ops_scales
=
scaled_int8_quant
(
x
)
assert
torch
.
allclose
(
scales
_out
,
scales
)
assert
torch
.
allclose
(
torch
_out
,
ops
_out
,
assert
torch
.
allclose
(
ops_
scales
,
ref_
scales
)
assert
torch
.
allclose
(
ops
_out
,
ref
_out
,
atol
=
1
)
# big atol to account for rounding errors
...
...
@@ -55,12 +48,11 @@ def test_static_scaled_int8_quant(num_tokens: int, hidden_size: int,
int8_traits
=
torch
.
iinfo
(
torch
.
int8
)
x
=
torch
.
rand
(
num_tokens
,
hidden_size
,
dtype
=
dtype
,
device
=
"cuda"
)
*
1000
scale
=
torch
.
tensor
([
scale
],
dtype
=
torch
.
float32
,
device
=
"cuda"
)
out1
=
(
x
/
scale
).
round
().
clamp
(
int8_traits
.
min
,
int8_traits
.
max
).
to
(
torch
.
int8
)
out2
=
torch
.
empty_like
(
x
,
dtype
=
torch
.
int8
)
scale_argument
=
torch
.
tensor
([
scale
],
dtype
=
torch
.
float32
,
device
=
"cuda"
)
out2
,
_
=
scaled_int8_quant
(
x
,
scale
)
torch
.
ops
.
_C
.
static_scaled_int8_quant
(
out2
,
x
,
scale_argument
)
assert
torch
.
allclose
(
out1
,
out2
,
atol
=
1
)
# big atol to account for rounding errors
tests/kernels/test_marlin_gemm.py
View file @
e7c1b7f3
...
...
@@ -5,23 +5,33 @@ Run `pytest tests/kernels/marlin/test_marlin_gemm.py`.
import
pytest
import
torch
from
tests.quantization.utils
import
is_quant_method_supported
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.layers.quantization.gptq_marlin
import
(
GPTQ_MARLIN_MAX_PARALLEL
,
GPTQ_MARLIN_MIN_THREAD_N
,
GPTQ_MARLIN_SUPPORTED_GROUP_SIZES
,
GPTQ_MARLIN_SUPPORTED_NUM_BITS
)
from
vllm.model_executor.layers.quantization.gptq_marlin_24
import
(
GPTQ_MARLIN_24_MAX_PARALLEL
,
GPTQ_MARLIN_24_MIN_THREAD_N
,
GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES
,
GPTQ_MARLIN_24_SUPPORTED_NUM_BITS
)
from
vllm.model_executor.layers.quantization.utils.marlin_perms
import
(
marlin_perm
)
GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES
,
GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES
)
from
vllm.model_executor.layers.quantization.qqq
import
(
MARLIN_QQQ_MAX_PARALLEL
,
MARLIN_QQQ_MIN_THREAD_N
,
MARLIN_QQQ_SUPPORTED_GROUP_SIZES
,
MARLIN_QQQ_SUPPORTED_NUM_BITS
)
from
vllm.model_executor.layers.quantization.utils.marlin_utils
import
(
MarlinWorkspace
,
compute_max_diff
,
is_marlin_supported
,
marlin_24_quantize
,
marlin_quantize
,
marlin_weights
)
GPTQ_MARLIN_MAX_PARALLEL
,
GPTQ_MARLIN_MIN_THREAD_N
,
MARLIN_SUPPORTED_GROUP_SIZES
,
marlin_make_empty_g_idx
,
marlin_permute_scales
,
query_marlin_supported_quant_types
)
from
vllm.model_executor.layers.quantization.utils.marlin_utils_fp8
import
(
pack_fp8_to_int32
)
from
vllm.model_executor.layers.quantization.utils.marlin_utils_test
import
(
MarlinWorkspace
,
awq_marlin_quantize
,
get_weight_perm
,
marlin_quantize
,
marlin_weights
)
from
vllm.model_executor.layers.quantization.utils.marlin_utils_test_24
import
(
marlin_24_quantize
)
from
vllm.model_executor.layers.quantization.utils.marlin_utils_test_qqq
import
(
# noqa: E501
marlin_qqq_quantize
)
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
gptq_pack
,
quantize_weights
,
sort_weights
)
awq_pack
,
gptq_pack
,
gptq_quantize_weights
,
quantize_weights
,
sort_weights
)
ACT_ORDER_OPTS
=
[
False
,
True
]
K_FULL_OPTS
=
[
False
,
True
]
USE_FP32_REDUCE_OPTS
=
[
False
,
True
]
MARLIN_K_CHUNKS
=
[
128
]
MARLIN_N_CHUNKS
=
[
64
,
128
,
256
]
...
...
@@ -38,21 +48,29 @@ MNK_FACTORS = [
(
67
,
13
,
11
),
]
DTYPES
=
[
torch
.
float16
,
torch
.
bfloat16
]
def compute_max_diff(output, output_ref):
    """Return the mean absolute error of ``output`` relative to ``output_ref``.

    Despite the name this is not a maximum: it is the mean of
    ``|output - output_ref|`` divided by the mean of ``|output_ref|``.
    """
    mean_abs_error = (output - output_ref).abs().mean()
    mean_abs_ref = output_ref.abs().mean()
    return mean_abs_error / mean_abs_ref
def
rand_data
(
shape
):
return
torch
.
randn
(
shape
,
dtype
=
torch
.
half
,
device
=
"cuda"
)
def rand_data(shape, dtype=torch.float16):
    """Return a normally distributed random CUDA tensor of the given shape."""
    sample = torch.randn(shape, dtype=dtype, device="cuda")
    return sample
@
pytest
.
mark
.
skipif
(
not
is_
marlin_supported
(
),
@
pytest
.
mark
.
skipif
(
not
is_
quant_method_supported
(
"gptq_marlin"
),
reason
=
"Marlin is not supported on this GPU type."
)
@
pytest
.
mark
.
parametrize
(
"k_chunk"
,
MARLIN_K_CHUNKS
)
@
pytest
.
mark
.
parametrize
(
"n_chunk"
,
MARLIN_N_CHUNKS
)
@
pytest
.
mark
.
parametrize
(
"num_bits"
,
GPTQ_MARLIN_SUPPORTED_NUM_BITS
)
@
pytest
.
mark
.
parametrize
(
"group_size"
,
GPTQ_MARLIN_SUPPORTED_GROUP_SIZES
)
@
pytest
.
mark
.
parametrize
(
"quant_type"
,
query_marlin_supported_quant_types
(
False
))
@
pytest
.
mark
.
parametrize
(
"group_size"
,
MARLIN_SUPPORTED_GROUP_SIZES
)
@
pytest
.
mark
.
parametrize
(
"act_order"
,
ACT_ORDER_OPTS
)
@
pytest
.
mark
.
parametrize
(
"mnk_factors"
,
MNK_FACTORS
)
def
test_marlin_repack
(
k_chunk
,
n_chunk
,
num_bits
,
group_size
,
act_order
,
mnk_factors
):
def
test_
gptq_
marlin_repack
(
k_chunk
,
n_chunk
,
quant_type
,
group_size
,
act_order
,
mnk_factors
):
m_factor
,
n_factor
,
k_factor
=
mnk_factors
size_m
=
m_factor
...
...
@@ -77,11 +95,11 @@ def test_marlin_repack(k_chunk, n_chunk, num_bits, group_size, act_order,
b_weight
=
rand_data
((
size_k
,
size_n
))
# Quantize (and apply act_order if provided)
w_ref
,
q_w
,
s
,
g_idx
,
rand_perm
=
quantize_weights
(
b_weight
,
num_bits
,
group_size
,
act_order
)
w_ref
,
q_w
,
s
,
g_idx
,
rand_perm
=
gptq_
quantize_weights
(
b_weight
,
quant_type
,
group_size
,
act_order
)
# Pack to GPTQ format
q_w_gptq
=
gptq_pack
(
q_w
,
num
_bits
,
size_k
,
size_n
)
q_w_gptq
=
gptq_pack
(
q_w
,
quant_type
.
size
_bits
,
size_k
,
size_n
)
# For act_order, sort the "weights" and "g_idx" so that group ids are
# increasing
...
...
@@ -90,8 +108,9 @@ def test_marlin_repack(k_chunk, n_chunk, num_bits, group_size, act_order,
q_w
,
g_idx
,
sort_indices
=
sort_weights
(
q_w
,
g_idx
)
# Pack to Marlin format
marlin_q_w_1
=
marlin_weights
(
q_w
,
size_k
,
size_n
,
num_bits
,
marlin_perm
[
num_bits
])
weight_perm
=
get_weight_perm
(
quant_type
.
size_bits
)
marlin_q_w_1
=
marlin_weights
(
q_w
,
size_k
,
size_n
,
quant_type
.
size_bits
,
weight_perm
)
# Run Marlin repack GPU kernel
marlin_q_w_2
=
ops
.
gptq_marlin_repack
(
...
...
@@ -99,30 +118,85 @@ def test_marlin_repack(k_chunk, n_chunk, num_bits, group_size, act_order,
sort_indices
,
size_k
,
size_n
,
num_bits
,
quant_type
.
size_bits
,
)
torch
.
cuda
.
synchronize
()
assert
torch
.
allclose
(
marlin_q_w_1
,
marlin_q_w_2
)
@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"),
                    reason="Marlin is not supported on this GPU type.")
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
@pytest.mark.parametrize("quant_type",
                         query_marlin_supported_quant_types(False))
@pytest.mark.parametrize("group_size", MARLIN_SUPPORTED_GROUP_SIZES)
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
def test_awq_marlin_repack(k_chunk, n_chunk, quant_type, group_size,
                           mnk_factors):
    """Check the AWQ -> Marlin repack kernel against the Python packer.

    Random weights are quantized with zero points, packed once directly to
    the Marlin layout (``marlin_weights``) and once to the AWQ layout
    followed by ``ops.awq_marlin_repack``; both results must be equal.
    """
    m_factor, n_factor, k_factor = mnk_factors

    size_m = m_factor
    size_k = k_chunk * k_factor
    size_n = n_chunk * n_factor

    print(f"MNK = {size_m} {size_n} {size_k}")

    # Normalize group_size: -1 stands for one group spanning all of K
    if group_size == -1:
        group_size = size_k
    assert group_size <= size_k

    num_bits = quant_type.size_bits

    # Create input
    b_weight = rand_data((size_k, size_n))

    # Quantize; only the quantized weights are needed here
    quantized = quantize_weights(b_weight,
                                 quant_type,
                                 group_size,
                                 zero_points=True)
    _w_ref, q_w, _scales, _zero_points = quantized

    # Pack to AWQ format
    q_w_awq = awq_pack(q_w, num_bits, size_k, size_n)

    # Pack to Marlin format (expected result)
    weight_perm = get_weight_perm(num_bits)
    expected = marlin_weights(q_w, size_k, size_n, num_bits, weight_perm)

    # Run Marlin repack GPU kernel
    repacked = ops.awq_marlin_repack(
        q_w_awq,
        size_k,
        size_n,
        num_bits,
    )
    torch.cuda.synchronize()

    assert torch.allclose(expected, repacked)
@
pytest
.
mark
.
skipif
(
not
is_
marlin_supported
(
),
@
pytest
.
mark
.
skipif
(
not
is_
quant_method_supported
(
"gptq_marlin"
),
reason
=
"Marlin is not supported on this GPU type."
)
@
pytest
.
mark
.
parametrize
(
"k_chunk"
,
MARLIN_K_CHUNKS
)
@
pytest
.
mark
.
parametrize
(
"n_chunk"
,
MARLIN_N_CHUNKS
)
@
pytest
.
mark
.
parametrize
(
"num_bits"
,
GPTQ_MARLIN_SUPPORTED_NUM_BITS
)
@
pytest
.
mark
.
parametrize
(
"group_size"
,
GPTQ_MARLIN_SUPPORTED_GROUP_SIZES
)
@
pytest
.
mark
.
parametrize
(
"quant_type"
,
query_marlin_supported_quant_types
(
False
))
@
pytest
.
mark
.
parametrize
(
"group_size"
,
MARLIN_SUPPORTED_GROUP_SIZES
)
@
pytest
.
mark
.
parametrize
(
"mnk_factors"
,
MNK_FACTORS
)
@
pytest
.
mark
.
parametrize
(
"act_order"
,
ACT_ORDER_OPTS
)
@
pytest
.
mark
.
parametrize
(
"is_k_full"
,
K_FULL_OPTS
)
def
test_marlin_gemm
(
@
pytest
.
mark
.
parametrize
(
"use_fp32_reduce"
,
USE_FP32_REDUCE_OPTS
)
def
test_gptq_marlin_gemm
(
k_chunk
,
n_chunk
,
num_bits
,
quant_type
,
group_size
,
mnk_factors
,
act_order
,
is_k_full
,
use_fp32_reduce
,
):
m_factor
,
n_factor
,
k_factor
=
mnk_factors
...
...
@@ -143,7 +217,9 @@ def test_marlin_gemm(
b_weight
=
rand_data
((
size_k
,
size_n
))
w_ref
,
marlin_q_w
,
marlin_s
,
g_idx
,
sort_indices
,
_
=
marlin_quantize
(
b_weight
,
num_bits
,
group_size
,
act_order
)
b_weight
,
quant_type
,
group_size
,
act_order
)
marlin_zp
=
marlin_make_empty_g_idx
(
marlin_s
.
device
)
workspace
=
MarlinWorkspace
(
size_n
,
GPTQ_MARLIN_MIN_THREAD_N
,
GPTQ_MARLIN_MAX_PARALLEL
)
...
...
@@ -152,14 +228,17 @@ def test_marlin_gemm(
a_input
,
marlin_q_w
,
marlin_s
,
marlin_zp
,
g_idx
,
sort_indices
,
workspace
.
scratch
,
num_bits
,
quant_type
,
a_input
.
shape
[
0
],
b_weight
.
shape
[
1
],
a_input
.
shape
[
1
],
is_k_full
,
is_k_full
=
is_k_full
,
has_zp
=
False
,
use_fp32_reduce
=
use_fp32_reduce
,
)
output_ref
=
torch
.
matmul
(
a_input
,
w_ref
)
...
...
@@ -171,14 +250,15 @@ def test_marlin_gemm(
assert
max_diff
<
0.04
@
pytest
.
mark
.
skipif
(
not
is_
marlin_supported
(
),
@
pytest
.
mark
.
skipif
(
not
is_
quant_method_supported
(
"gptq_marlin"
),
reason
=
"Marlin is not supported on this GPU type."
)
@
pytest
.
mark
.
parametrize
(
"k_chunk"
,
MARLIN_24_K_CHUNKS
)
@
pytest
.
mark
.
parametrize
(
"n_chunk"
,
MARLIN_24_N_CHUNKS
)
@
pytest
.
mark
.
parametrize
(
"
num_bits
"
,
GPTQ_MARLIN_24_SUPPORTED_
NUM_BIT
S
)
@
pytest
.
mark
.
parametrize
(
"
quant_type
"
,
GPTQ_MARLIN_24_SUPPORTED_
QUANT_TYPE
S
)
@
pytest
.
mark
.
parametrize
(
"group_size"
,
GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES
)
@
pytest
.
mark
.
parametrize
(
"mnk_factors"
,
MNK_FACTORS
)
def
test_marlin_24_gemm
(
k_chunk
,
n_chunk
,
num_bits
,
group_size
,
mnk_factors
):
def
test_gptq_marlin_24_gemm
(
k_chunk
,
n_chunk
,
quant_type
,
group_size
,
mnk_factors
):
m_factor
,
n_factor
,
k_factor
=
mnk_factors
size_m
=
m_factor
...
...
@@ -192,7 +272,7 @@ def test_marlin_24_gemm(k_chunk, n_chunk, num_bits, group_size, mnk_factors):
b_weight
=
rand_data
((
size_k
,
size_n
))
(
w_24_ref
,
marlin_24_q_w_comp
,
marlin_24_meta
,
marlin_24_s
)
=
marlin_24_quantize
(
b_weight
,
num_bits
,
group_size
)
marlin_24_s
)
=
marlin_24_quantize
(
b_weight
,
quant_type
,
group_size
)
workspace_24
=
MarlinWorkspace
(
size_n
,
GPTQ_MARLIN_24_MIN_THREAD_N
,
GPTQ_MARLIN_24_MAX_PARALLEL
)
...
...
@@ -205,7 +285,7 @@ def test_marlin_24_gemm(k_chunk, n_chunk, num_bits, group_size, mnk_factors):
marlin_24_meta
,
marlin_24_s
,
workspace_24
.
scratch
,
num_bits
,
quant_type
,
a_input
.
shape
[
0
],
b_weight
.
shape
[
1
],
a_input
.
shape
[
1
],
...
...
@@ -217,3 +297,204 @@ def test_marlin_24_gemm(k_chunk, n_chunk, num_bits, group_size, mnk_factors):
print
(
"max_diff = {}"
.
format
(
max_diff
))
assert
max_diff
<
0.04
@pytest.mark.skipif(not is_quant_method_supported("fp8"),
                    reason="Marlin is not supported on this GPU type.")
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
@pytest.mark.parametrize("num_bits", [8])
@pytest.mark.parametrize("group_size", [-1])
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
@pytest.mark.parametrize("dtype", DTYPES)
def test_fp8_marlin_gemm(
    k_chunk,
    n_chunk,
    num_bits,
    group_size,
    mnk_factors,
    dtype,
):
    """Check `ops.fp8_marlin_gemm` against a plain `torch.matmul`.

    The weight is FP8-quantized, packed into int32, repacked to the Marlin
    layout and multiplied by the kernel; the result must stay within a
    max-diff of 0.04 of the unquantized matmul.
    """
    m_factor, n_factor, k_factor = mnk_factors

    size_m = m_factor
    size_k = k_chunk * k_factor
    size_n = n_chunk * n_factor

    print(f"MNK = {size_m} {size_n} {size_k}")
    print(f"groupsize = {group_size}")

    a_input = rand_data((size_m, size_k), dtype=dtype)
    b_weight = rand_data((size_k, size_n), dtype=dtype)

    # --- Weights ---------------------------------------------------------
    fp8_weight, weight_scale = ops.scaled_fp8_quant(b_weight, scale=None)

    # GPTQ layout: FP8 values packed into int32 elements
    packed_gptq_qweight = pack_fp8_to_int32(fp8_weight)

    # Marlin layout (an empty perm tensor is passed: no act-order here)
    empty_perm = torch.empty(0, dtype=torch.int, device="cuda")
    marlin_qweight = ops.gptq_marlin_repack(
        b_q_weight=packed_gptq_qweight,
        perm=empty_perm,
        size_k=size_k,
        size_n=size_n,
        num_bits=8,
    )

    # --- Weight scales ---------------------------------------------------
    # Currently Marlin doesn't support per-tensor scales, so the scale is
    # expanded to channelwise before being permuted.
    scales = weight_scale.repeat(1, size_n)
    scales = scales.to(a_input.dtype).to("cuda")
    marlin_scales = marlin_permute_scales(s=scales,
                                          size_k=size_k,
                                          size_n=size_n,
                                          group_size=-1)

    workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_MIN_THREAD_N,
                                GPTQ_MARLIN_MAX_PARALLEL)

    gemm_m, gemm_k = a_input.shape[0], a_input.shape[1]
    gemm_n = b_weight.shape[1]
    output = ops.fp8_marlin_gemm(
        a=a_input,
        b_q_weight=marlin_qweight,
        b_scales=marlin_scales,
        workspace=workspace.scratch,
        num_bits=num_bits,
        size_m=gemm_m,
        size_n=gemm_n,
        size_k=gemm_k,
    )
    output_ref = torch.matmul(a_input, b_weight)

    torch.cuda.synchronize()

    max_diff = compute_max_diff(output, output_ref)
    print("max_diff = {}".format(max_diff))

    assert max_diff < 0.04
@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"),
                    reason="Marlin is not supported on this GPU type.")
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
@pytest.mark.parametrize("quant_type",
                         query_marlin_supported_quant_types(True))
@pytest.mark.parametrize("group_size", MARLIN_SUPPORTED_GROUP_SIZES)
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
@pytest.mark.parametrize("use_fp32_reduce", USE_FP32_REDUCE_OPTS)
def test_awq_marlin_gemm(
    k_chunk,
    n_chunk,
    quant_type,
    group_size,
    mnk_factors,
    use_fp32_reduce,
):
    """Check `ops.gptq_marlin_gemm` on AWQ-quantized weights (with zero
    points) against a matmul with the dequantized reference weight.
    """
    m_factor, n_factor, k_factor = mnk_factors

    size_m = m_factor
    size_k = k_chunk * k_factor
    size_n = n_chunk * n_factor

    print(f"MNK = {size_m} {size_n} {size_k}")
    print(f"groupsize = {group_size}")

    a_input = rand_data((size_m, size_k))
    b_weight = rand_data((size_k, size_n))

    w_ref, marlin_q_w, marlin_s, marlin_zp = awq_marlin_quantize(
        b_weight, quant_type, group_size)

    # The AWQ path passes empty g_idx / sort_indices tensors
    weight_device = marlin_q_w.device
    g_idx = torch.empty(0, dtype=torch.int, device=weight_device)
    sort_indices = torch.empty(0, dtype=torch.int, device=weight_device)

    workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_MIN_THREAD_N,
                                GPTQ_MARLIN_MAX_PARALLEL)

    output = ops.gptq_marlin_gemm(
        a_input,
        marlin_q_w,
        marlin_s,
        marlin_zp,
        g_idx,
        sort_indices,
        workspace.scratch,
        quant_type,
        a_input.shape[0],
        b_weight.shape[1],
        a_input.shape[1],
        is_k_full=True,
        has_zp=True,
        use_fp32_reduce=use_fp32_reduce,
    )
    output_ref = torch.matmul(a_input, w_ref)

    torch.cuda.synchronize()

    max_diff = compute_max_diff(output, output_ref)
    print("max_diff = {}".format(max_diff))

    assert max_diff < 0.04
@pytest.mark.skipif(not is_quant_method_supported("qqq"),
                    reason="Marlin is not supported on this GPU type.")
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
@pytest.mark.parametrize("num_bits", MARLIN_QQQ_SUPPORTED_NUM_BITS)
@pytest.mark.parametrize("group_size", MARLIN_QQQ_SUPPORTED_GROUP_SIZES)
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
def test_marlin_qqq_gemm(
    k_chunk,
    n_chunk,
    num_bits,
    group_size,
    mnk_factors,
):
    """Check `ops.marlin_qqq_gemm` against a half-precision matmul of the
    int8-quantized activations (rescaled) with the reference weight.
    """
    int8_traits = torch.iinfo(torch.int8)
    m_factor, n_factor, k_factor = mnk_factors

    size_m = m_factor
    size_k = k_chunk * k_factor
    size_n = n_chunk * n_factor

    print(f"MNK = {size_m} {size_n} {size_k}")
    print(f"groupsize = {group_size}")

    a_input = rand_data((size_m, size_k))
    b_weight = rand_data((size_k, size_n))

    # Quantize activations: one scale per row, chosen so that the row's
    # largest magnitude maps onto the int8 maximum.
    row_abs_max = a_input.abs().max(dim=-1, keepdim=True).values
    s_a = row_abs_max.div(int8_traits.max).to(torch.float)
    q_a = (a_input / s_a).round()
    q_a = q_a.clamp(int8_traits.min, int8_traits.max).to(torch.int8)

    # Quantize weights
    (w_ref, marlin_qqq_q_w, marlin_qqq_s_group,
     marlin_qqq_s_channel) = marlin_qqq_quantize(b_weight, num_bits,
                                                 group_size)

    workspace = MarlinWorkspace(size_n, MARLIN_QQQ_MIN_THREAD_N,
                                MARLIN_QQQ_MAX_PARALLEL)

    output = ops.marlin_qqq_gemm(
        q_a,
        marlin_qqq_q_w,
        s_a,
        marlin_qqq_s_channel,
        marlin_qqq_s_group,
        workspace.scratch,
        a_input.shape[0],
        b_weight.shape[1],
        a_input.shape[1],
    )
    dequantized_a = q_a.half() * s_a.half()
    output_ref = torch.matmul(dequantized_a, w_ref)

    torch.cuda.synchronize()

    max_diff = compute_max_diff(output, output_ref)
    print("max_diff = {}".format(max_diff))

    assert max_diff < 0.04
\ No newline at end of file
tests/kernels/test_moe.py
View file @
e7c1b7f3
...
...
@@ -29,7 +29,7 @@ def torch_moe(a, w1, w2, score, topk):
topk_weight
.
view
(
B
,
-
1
,
1
).
to
(
out
.
dtype
)).
sum
(
dim
=
1
)
@
pytest
.
mark
.
parametrize
(
"m"
,
[
512
,
222
,
33
,
1
])
@
pytest
.
mark
.
parametrize
(
"m"
,
[
1024
*
128
,
512
,
222
,
33
,
1
])
@
pytest
.
mark
.
parametrize
(
"n"
,
[
2048
,
256
,
1024
])
@
pytest
.
mark
.
parametrize
(
"k"
,
[
128
,
511
,
1024
])
@
pytest
.
mark
.
parametrize
(
"e"
,
[
8
,
64
])
...
...
@@ -77,8 +77,8 @@ def test_mixtral_moe(dtype: torch.dtype):
for
i
in
range
(
config
.
num_local_experts
):
weights
=
(
hf_moe
.
experts
[
i
].
w1
.
weight
.
data
,
hf_moe
.
experts
[
i
].
w3
.
weight
.
data
)
vllm_moe
.
w13_weight
[
i
][:]
=
torch
.
cat
(
weights
,
dim
=
0
)
vllm_moe
.
w2_weight
[
i
][:]
=
hf_moe
.
experts
[
i
].
w2
.
weight
.
data
vllm_moe
.
experts
.
w13_weight
[
i
][:]
=
torch
.
cat
(
weights
,
dim
=
0
)
vllm_moe
.
experts
.
w2_weight
[
i
][:]
=
hf_moe
.
experts
[
i
].
w2
.
weight
.
data
# Generate input batch of dimensions [batch_size, seq_len, hidden_dim]
hf_inputs
=
torch
.
randn
((
1
,
64
,
config
.
hidden_size
)).
to
(
dtype
).
to
(
"cuda"
)
...
...
tests/kernels/test_pos_encoding.py
View file @
e7c1b7f3
from
itertools
import
accumulate
,
product
from
typing
import
List
,
Optional
from
typing
import
Dict
,
List
,
Optional
import
pytest
import
torch
...
...
@@ -10,7 +10,7 @@ from .allclose_default import get_default_atol, get_default_rtol
IS_NEOX_STYLE
=
[
True
,
False
]
DTYPES
=
[
torch
.
half
,
torch
.
bfloat16
,
torch
.
float
]
HEAD_SIZES
=
[
64
,
80
,
96
,
112
,
128
,
192
,
256
]
HEAD_SIZES
=
[
64
,
80
,
96
,
112
,
120
,
128
,
192
,
256
]
ROTARY_DIMS
=
[
None
,
32
]
# None means rotary dim == head size
NUM_HEADS
=
[
7
,
17
]
# Arbitrary values for testing
BATCH_SIZES
=
[
1
,
5
]
# Arbitrary values for testing
...
...
@@ -126,7 +126,7 @@ def test_batched_rotary_embedding(
query
,
key
,
offsets
=
torch
.
zeros
(
batch_size
*
seq_len
,
dtype
=
int
,
dtype
=
torch
.
long
,
device
=
device
))
# Compare the results.
assert
torch
.
allclose
(
out_query
,
...
...
@@ -214,20 +214,16 @@ def test_batched_rotary_embedding_multi_lora(
def
test_rope_module_cache
():
MAX_POSITIONS
=
[
123
,
1234
]
BASES
=
[
10000
,
1000000
]
ROPE_SCALINGS
=
[
None
,
{
"type"
:
"linear"
,
"factor"
:
(
1
,
)
},
{
"type"
:
"dynamic"
,
"factor"
:
1
}
]
settings
=
[
HEAD_SIZES
,
ROTARY_DIMS
,
MAX_POSITIONS
,
BASES
,
IS_NEOX_STYLE
,
ROPE_SCALINGS
,
DTYPES
]
rope_setting_id_map
=
{}
ROPE_SCALINGS
=
(
None
,
{
"type"
:
"linear"
,
"factor"
:
(
1
,
)
},
{
"type"
:
"dynamic"
,
"factor"
:
1
})
settings
=
(
HEAD_SIZES
,
ROTARY_DIMS
,
MAX_POSITIONS
,
BASES
,
IS_NEOX_STYLE
,
ROPE_SCALINGS
,
DTYPES
)
rope_setting_id_map
:
Dict
[
str
,
int
]
=
{}
for
setting
in
product
(
*
settings
):
head_size
,
rotary_dim
,
max_position
,
base
,
\
is_neox_stype
,
rope_scaling
,
dtype
=
setting
...
...
tests/kernels/test_sampler.py
View file @
e7c1b7f3
import
gc
from
unittest.mock
import
patch
import
pytest
import
torch
import
triton
import
triton.language
as
tl
from
vllm.model_executor.layers.ops.sample
import
(
MAX_TRITON_N_COLS
,
_uniform_to_exponential
,
get_num_triton_sampler_splits
,
sample
)
from
vllm.model_executor.layers.ops.sample
import
(
_sample_triton
,
_uniform_to_exponential
,
sample
)
from
vllm.model_executor.sampling_metadata
import
SamplingTensors
from
vllm.model_executor.utils
import
set_random_seed
from
vllm.triton_utils.libentry
import
LibEntry
from
vllm.triton_utils.sample
import
(
MAX_TRITON_N_COLS
,
get_num_triton_sampler_splits
)
SINGLE_SPLIT_VOCAB_SIZE
=
32000
# llama/mistral/mixtral vocab size
MULTI_SPLIT_VOCAB_SIZE
=
MAX_TRITON_N_COLS
+
100
...
...
@@ -75,15 +79,20 @@ def test_sample_decoding_only(random_sampling, max_best_of,
seeds
=
torch
.
randint
(
1
,
torch
.
iinfo
(
torch
.
long
).
max
,
(
n_splits
,
bs
),
device
=
"cuda"
).
mul_
(
random_sampling_mask
)
sampled_tokens
,
sampled_logprobs
,
sampled_modified_probs
=
sample
(
probs
=
probs
,
logprobs
=
logprobs
,
sample_indices
=
sample_indices
,
seeds
=
seeds
,
max_best_of
=
max_best_of
,
modify_greedy_probs
=
modify_greedy_probs
,
save_logprobs
=
save_logprobs
,
_save_modified_probs
=
True
)
#The current _sample_triton does not utilize the
# libentry decoration. The purpose of adding this patch is to test
# the correctness of libentry.
with
patch
(
"vllm.model_executor.layers.ops.sample._sample_triton"
,
LibEntry
(
_sample_triton
)):
sampled_tokens
,
sampled_logprobs
,
sampled_modified_probs
=
sample
(
probs
=
probs
,
logprobs
=
logprobs
,
sample_indices
=
sample_indices
,
seeds
=
seeds
,
max_best_of
=
max_best_of
,
modify_greedy_probs
=
modify_greedy_probs
,
save_logprobs
=
save_logprobs
,
_save_modified_probs
=
True
)
assert
sampled_tokens
.
shape
==
(
bs
,
max_best_of
)
for
i
in
range
(
bs
):
assert
torch
.
all
(
sampled_tokens
[
i
]
==
i
*
(
vocab_size
//
bs
))
...
...
@@ -129,6 +138,7 @@ def test_sample_decoding_only(random_sampling, max_best_of,
[
SINGLE_SPLIT_VOCAB_SIZE
,
MULTI_SPLIT_VOCAB_SIZE
])
def
test_sample_prompt_logprobs
(
random_sampling
,
max_best_of
,
modify_greedy_probs
,
seed
,
vocab_size
):
set_random_seed
(
seed
)
prompt_sizes
=
[
16
,
32
,
64
,
128
]
*
2
samples
=
8
...
...
@@ -156,14 +166,17 @@ def test_sample_prompt_logprobs(random_sampling, max_best_of,
seeds
=
torch
.
randint
(
1
,
torch
.
iinfo
(
torch
.
long
).
max
,
(
n_splits
,
samples
),
device
=
"cuda"
).
mul_
(
random_sampling_mask
)
sampled_tokens
,
sampled_logprobs
,
_
=
sample
(
probs
=
probs
,
logprobs
=
logprobs
,
sample_indices
=
sample_indices
,
seeds
=
seeds
,
max_best_of
=
max_best_of
,
modify_greedy_probs
=
modify_greedy_probs
,
save_logprobs
=
True
)
#ditto
with
patch
(
"vllm.model_executor.layers.ops.sample._sample_triton"
,
LibEntry
(
_sample_triton
)):
sampled_tokens
,
sampled_logprobs
,
_
=
sample
(
probs
=
probs
,
logprobs
=
logprobs
,
sample_indices
=
sample_indices
,
seeds
=
seeds
,
max_best_of
=
max_best_of
,
modify_greedy_probs
=
modify_greedy_probs
,
save_logprobs
=
True
)
assert
sampled_tokens
.
shape
==
(
samples
,
max_best_of
)
assert
sampled_logprobs
.
shape
==
(
samples
,
max_best_of
)
for
i
,
t
in
enumerate
(
sample_indices
):
...
...
tests/kernels/utils.py
View file @
e7c1b7f3
"""Kernel test utils"""
import
itertools
import
random
from
numbers
import
Number
from
typing
import
Any
,
List
,
NamedTuple
,
Optional
,
Tuple
,
Union
import
pytest
import
torch
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionMetadata
,
AttentionType
)
from
vllm.attention.backends.xformers
import
XFormersBackend
from
vllm.utils
import
make_tensor_with_pad
# Name of the environment variable which may be set in order to
# force the attention backend that the Attention wrapper selects
STR_BACKEND_ENV_VAR: str = "VLLM_ATTENTION_BACKEND"

# Possible string values of the STR_BACKEND_ENV_VAR
# environment variable, corresponding to possible backends
STR_FLASHINFER_ATTN_VAL: str = "FLASHINFER"
STR_TORCH_SDPA_ATTN_VAL: str = "TORCH_SDPA"
STR_ROCM_FLASH_ATTN_VAL: str = "ROCM_FLASH"
STR_XFORMERS_ATTN_VAL: str = "XFORMERS"
STR_FLASH_ATTN_VAL: str = "FLASH_ATTN"
# Placeholder for a backend name that is not valid
STR_INVALID_VAL: str = "INVALID"
class QKVInputs(NamedTuple):
    '''
    Data structure for representing unpacked attention inputs,
    query/key/values and their sequence lengths.

    Attributes:

        * {query,key,value}: unpacked (batch_size x padded_seq_len x
                             num_heads x head_size) attention inputs
        * q_seq_lens: query sequence lengths list (one entry per batch
                      element)
        * kv_seq_lens: shared key/value sequence lengths list (one entry
                       per batch element)
    '''

    query: torch.Tensor
    key: torch.Tensor
    value: torch.Tensor
    q_seq_lens: List[int]
    kv_seq_lens: List[int]
class QKVO(NamedTuple):
    '''
    Data structure for representing unpacked attention inputs,
    alongside unpacked known-correct attention output

    Attributes:

        * qkv: unpacked (batch_size x padded_seq_len x
               num_heads x head_size) attention inputs
        * ideal_output: unpacked (batch_size x padded_seq_len x
                        num_heads x head_size) known-correct attention
                        output
    '''

    qkv: QKVInputs
    ideal_output: torch.Tensor
class PackedQKVInputs(NamedTuple):
    '''
    Data structure for representing packed attention inputs

    Attributes:

        * {query,key,value}: packed (number_of_tokens x num_heads
                             x head_size) attention inputs
        * q_start_loc_list: list of query start locations within packed
                            tensor
        * kv_start_loc_list: shared list of key/value start locations within
                             packed tensor
        * q_seq_lens: query sequence lengths list
        * kv_seq_lens: shared key/value sequence lengths list
    '''

    # NOTE: pack_qkv() stores None in `query` (and in q_start_loc_list /
    # q_seq_lens) when the unpacked query it receives is None.
    query: torch.Tensor
    key: torch.Tensor
    value: torch.Tensor
    q_start_loc_list: Optional[List[int]]
    kv_start_loc_list: Optional[List[int]]
    q_seq_lens: Optional[List[int]]
    kv_seq_lens: Optional[List[int]]
class PackedQKVO(NamedTuple):
    '''
    Data structure for representing packed attention inputs,
    alongside packed known-correct attention output

    Attributes:

        * packed_qkv: packed (number_of_tokens x num_heads
                      x head_size) attention inputs; may be None
        * ideal_output: packed (number_of_tokens x num_heads
                        x head_size) known-correct attention output
    '''

    packed_qkv: Optional[PackedQKVInputs]
    ideal_output: torch.Tensor
class KVMemoryMap(NamedTuple):
    '''
    Data structure for encapsulating KV cache memory mapping.

    Attributes:

        * block_tables: KV cache block tables
        * slot_mapping: mapping of sequence offset to physical address
    '''

    block_tables: torch.Tensor
    slot_mapping: torch.Tensor
class PhaseTestParameters(NamedTuple):
    '''
    Data structure for encapsulating the test parameters
    for a given test "phase" (prefill or decode phase) and attention
    scenario (encoder, decoder-self, encoder/decoder-cross)

    Attributes:

        * packed_qkvo: packed (number_of_tokens x num_heads
                       x head_size) attention inputs & known-correct
                       output
        * kv_mmap: KV cache memory mapping, specific to this test phase &
                   attention scenario; may be None
    '''

    packed_qkvo: PackedQKVO
    kv_mmap: Optional[KVMemoryMap]
def maybe_make_int_tensor(
    _list: Optional[List[int]],
    device: Union[torch.device, str],
) -> Optional[torch.Tensor]:
    '''
    Convert Python int list to a 1D int torch.Tensor on `device`

    Arguments:

    * _list: list of Python ints, or None
    * device: device on which the tensor is created

    Returns:

    * If _list is not None: 1D int (torch.int) torch.Tensor on `device`
    * None otherwise
    '''
    # The return annotation is Optional: a None input is passed through
    # unchanged instead of being converted.
    if _list is None:
        return None
    return torch.tensor(_list, dtype=torch.int, device=device)
def maybe_make_long_tensor(
    _list: Optional[List[int]],
    device: Union[torch.device, str],
) -> Optional[torch.Tensor]:
    '''
    Convert Python int list to a 1D long torch.Tensor on `device`

    Arguments:

    * _list: list of Python ints, or None
    * device: device on which the tensor is created

    Returns:

    * If _list is not None: 1D long (torch.long) torch.Tensor on `device`
    * None otherwise
    '''
    # The return annotation is Optional: a None input is passed through
    # unchanged instead of being converted.
    if _list is None:
        return None
    return torch.tensor(_list, dtype=torch.long, device=device)
def maybe_max(_list: Optional[List]) -> Optional[Number]:
    '''
    Largest element of an optional list.

    Returns:

    * If _list is not None: max(_list)
    * None otherwise
    '''
    if _list is None:
        return None
    return max(_list)
def make_causal_mask(
    q_max_seq_len: int,
    kv_max_seq_len: int,
) -> torch.Tensor:
    '''
    Create a q_max_seq_len x kv_max_seq_len causal mask

    Arguments:

    * q_max_seq_len: query max seq len
    * kv_max_seq_len: key/value max seq len

    Returns:

    * 2D tensor, q_max_seq_len x kv_max_seq_len, holding float('-inf')
      at every entry (i, j) with j > i and 0 everywhere else
    '''

    # Boolean matrix which is True strictly above the main diagonal,
    # i.e. where the key/value index j lies ahead of the query index i
    row_idx = torch.arange(q_max_seq_len).unsqueeze(1)
    col_idx = torch.arange(kv_max_seq_len).unsqueeze(0)
    is_future = col_idx > row_idx

    # Start from all zeros and block the future positions with -inf
    mask = torch.zeros(q_max_seq_len, kv_max_seq_len)
    return mask.masked_fill(is_future, float('-inf'))
def
override_backend_env_variable
(
mpatch
:
pytest
.
MonkeyPatch
,
backend_name
:
str
)
->
None
:
'''
...
...
@@ -20,3 +219,724 @@ def override_backend_env_variable(mpatch: pytest.MonkeyPatch,
* backend_name: attention backend name to force
'''
mpatch
.
setenv
(
STR_BACKEND_ENV_VAR
,
backend_name
)
def ref_masked_attention(query: torch.Tensor,
                         key: torch.Tensor,
                         value: torch.Tensor,
                         scale: float,
                         custom_mask: Optional[torch.Tensor] = None,
                         q_seq_lens: Optional[List] = None,
                         kv_seq_lens: Optional[List] = None) -> torch.Tensor:
    '''
    "Golden" masked attention reference. Two masks are combined:

    * Padding mask, built from {q,kv}_seq_lens, which masks out the
      padding elements of every batch entry
    * Optional custom attention mask, which can force an arbitrary mask
      tensor, i.e. causal

    Arguments:

    * query: batch_size x q_padded_seq_len x num_heads x head_size
    * key: batch_size x kv_padded_seq_len x num_heads x head_size
    * value: batch_size x kv_padded_seq_len x num_heads x head_size
    * scale: Attention scale factor
    * custom_mask: custom attention mask, added to the attention scores;
                   good place to inject a causal attention mask
    * q_seq_lens: list of unpadded query seq_lens for each batch index
                  (required, despite the None default)
    * kv_seq_lens: list of unpadded key/value seq_lens for each batch index
                   (required, despite the None default)

    Returns:

    * Attention result, batch_size x q_padded_seq_len x num_heads x head_size
    '''

    assert q_seq_lens is not None
    assert kv_seq_lens is not None

    batch_size = query.shape[0]
    assert (len(q_seq_lens) == batch_size)
    assert (len(kv_seq_lens) == batch_size)

    # Scores are laid out as batch x heads x q_len x kv_len
    scores = scale * torch.einsum("bqhd,bkhd->bhqk", query, key).float()

    # Padding mask: -inf on every query row and every key/value column
    # that lies beyond the unpadded length of the batch entry
    padding_mask = torch.zeros_like(scores)
    for batch_idx, (q_len, kv_len) in enumerate(zip(q_seq_lens,
                                                    kv_seq_lens)):
        padding_mask[batch_idx, :, q_len:, :] = -torch.inf
        padding_mask[batch_idx, :, :, kv_len:] = -torch.inf
    scores = scores + padding_mask.float()

    # Custom attention mask
    if custom_mask is not None:
        scores = scores + custom_mask.float()

    probs = torch.softmax(scores, dim=-1).to(value.dtype)
    return torch.einsum("bhqk,bkhd->bqhd", probs, value)
def make_qkv(
    batch_size: int,
    max_q_seq_len: int,
    max_kv_seq_len: Optional[int],
    num_heads: int,
    head_size: int,
    device: Union[torch.device, str],
    force_kv_seq_lens: Optional[List[int]] = None,
    attn_type: AttentionType = AttentionType.ENCODER_DECODER,
    force_max_len: bool = False,
) -> Tuple[QKVInputs, QKVInputs, QKVInputs]:
    '''
    Construct QKV test tensors for self- and cross-attention.

    Generates three query/key/value triplets:

    * "Baseline" query/key/value (for input to reference attention function)
    * "Prefill" query/key/value (last sequence offset zero'd out, for use as
      input to prefill kernel)
    * "Decode" query/key/value (only the last sequence offset from baseline,
      for use as input to decode kernel)

    Each Q/K/V triplet is associated with a list of q seqlens and a list of k/v
    seqlens

    Arguments:

    * batch_size
    * max_q_seq_len: max query seq len
    * max_kv_seq_len: max key/value seq len; always used as the padded
      length of the key/value tensors
    * num_heads
    * head_size
    * device: CPU or CUDA device
    * force_kv_seq_lens: if not None, overrides kv sequence lengths
    * attn_type: encoder, decoder self, or enc/dec cross attention. For
      enc/dec cross attention the key/value seqlens are chosen
      independently of the query seqlens; o/w query/key/value seqlens
      match at each batch index
    * force_max_len: if True, all query seqlens are max_q_seq_len; o/w query
      seqlens are random in [2,max_q_seq_len]. Same for key/value seqlens
      and max_kv_seq_len, unless they are tied to the query seqlens or
      forced via force_kv_seq_lens

    Returns:

    * Overall QKVInputs structure (containing full unpacked Q/K/V tensors)
    * Prefill QKVInputs structure (containing all but the last sequence offset)
    * Decode QKVInputs structure (containing only the last sequence offset)
    '''

    def _pick_seq_lens(upper_bound: int) -> List[int]:
        # One length per batch entry: either pinned to the maximum or
        # drawn uniformly from [2, upper_bound]
        if force_max_len:
            return [upper_bound] * batch_size
        return [random.randint(2, upper_bound) for _ in range(batch_size)]

    q_seq_lens = _pick_seq_lens(max_q_seq_len)

    if force_kv_seq_lens is not None:
        kv_seq_lens = force_kv_seq_lens
    elif attn_type != AttentionType.ENCODER_DECODER:
        # K,V seq lens match Q for self-attention
        kv_seq_lens = q_seq_lens
    else:
        # K,V seq lens are distinct from Q seq lens & random
        assert max_kv_seq_len is not None
        kv_seq_lens = _pick_seq_lens(max_kv_seq_len)

    q_shape = (batch_size, max_q_seq_len, num_heads, head_size)
    kv_shape = (batch_size, max_kv_seq_len, num_heads, head_size)
    single_tok_shape = (batch_size, 1, num_heads, head_size)

    # Baseline tensors (random), drawn in query/key/value order
    query = torch.rand(q_shape).to(device)
    key = torch.rand(kv_shape).to(device)
    value = torch.rand(kv_shape).to(device)

    # Prefill tensors start out as zeros with the baseline shapes
    prefill_query = torch.zeros(q_shape).to(device)
    prefill_key = torch.zeros(kv_shape).to(device)
    prefill_value = torch.zeros(kv_shape).to(device)

    # Decode tensors hold exactly one sequence offset per batch entry
    decode_query = torch.zeros(single_tok_shape).to(device)
    decode_key = torch.zeros(single_tok_shape).to(device)
    decode_value = torch.zeros(single_tok_shape).to(device)

    for bdx, (q_len, kv_len) in enumerate(zip(q_seq_lens, kv_seq_lens)):
        # Zero out the padding of the baseline tensors
        query[bdx, q_len:, :, :] = 0
        key[bdx, kv_len:, :, :] = 0
        value[bdx, kv_len:, :, :] = 0

        last_q = q_len - 1
        last_kv = kv_len - 1

        # Prefill: everything except the last sequence offset
        prefill_query[bdx, 0:last_q, :, :] = query[bdx, 0:last_q, :, :]
        prefill_key[bdx, 0:last_kv, :, :] = key[bdx, 0:last_kv, :, :]
        prefill_value[bdx, 0:last_kv, :, :] = value[bdx, 0:last_kv, :, :]

        # Decode: only the last sequence offset
        decode_query[bdx, :, :, :] = query[bdx, last_q:q_len, :, :]
        decode_key[bdx, :, :, :] = key[bdx, last_kv:kv_len, :, :]
        decode_value[bdx, :, :, :] = value[bdx, last_kv:kv_len, :, :]

    prefill_q_seq_lens = [seq_len - 1 for seq_len in q_seq_lens]
    prefill_kv_seq_lens = [seq_len - 1 for seq_len in kv_seq_lens]

    decode_q_seq_lens = [1] * len(q_seq_lens)
    decode_kv_seq_lens = [1] * len(kv_seq_lens)

    # Overall QKV inputs
    baseline = QKVInputs(query, key, value, q_seq_lens, kv_seq_lens)
    # Prefill subset of QKV sequences
    prefill = QKVInputs(prefill_query, prefill_key, prefill_value,
                        prefill_q_seq_lens, prefill_kv_seq_lens)
    # Decode subset of KV sequences
    decode = QKVInputs(decode_query, decode_key, decode_value,
                       decode_q_seq_lens, decode_kv_seq_lens)

    return (baseline, prefill, decode)
def pack_tensor(
        unpacked_tensor: torch.Tensor, seq_lens: List[int],
        device: Union[torch.device, str]) -> Tuple[torch.Tensor, List[int]]:
    '''
    Pack a batch_size x padded_seq_len x num_heads x head_size tensor into an
    unpadded number_of_tokens x num_heads x head_size tensor, where
    number_of_tokens = sum(seq_lens)

    Arguments:

    * unpacked_tensor: batch_size x padded_seq_len x num_heads x head_size
    * seq_lens: list of token counts for each seq
    * device: CPU or CUDA device

    Returns

    * packed_tensor: number_of_tokens x num_heads x head_size
    * start_loc_list: start idx of each batch elt in packed_tensor, followed
                      by the total token count; [0] +
                      list(itertools.accumulate(seq_lens))
    '''

    # Running offsets: entry i is where sequence i begins in the packed
    # tensor, the final entry is the total number of tokens
    start_loc_list = [0]
    for seq_len in seq_lens:
        start_loc_list.append(start_loc_list[-1] + seq_len)
    num_tok = start_loc_list[-1]

    num_heads, head_size = unpacked_tensor.shape[-2:]
    packed_tensor = torch.zeros((num_tok, num_heads, head_size),
                                device=device)

    # Copy the unpadded prefix of every sequence into its packed slot
    for bdx, seq_len in enumerate(seq_lens):
        begin = start_loc_list[bdx]
        end = begin + seq_len
        packed_tensor[begin:end, :, :] = unpacked_tensor[bdx, :seq_len, :, :]

    return packed_tensor, start_loc_list
def pack_qkv(qkv: QKVInputs, device: Union[torch.device,
                                           str]) -> PackedQKVInputs:
    '''
    Individually pack each of Q, K and V, each with dimensions batch_size x
    padded_seq_len x num_heads x head_size, into respective number_of_tokens x
    num_heads x head_size tensors.

    For Q, number_of_tokens = sum(q_seq_lens).

    For K and V, number_of_tokens = sum(kv_seq_lens)

    Arguments:

    * qkv: Unpacked (batch_size x padded_seq_len x num_heads x head_size)
           attention inputs; qkv.query may be None
    * device: CPU or CUDA device

    Returns

    * Packed (number_of_tokens x num_heads x head_size) QKV inputs
      derived from unpacked inputs. When qkv.query is None, the packed
      query, its start locations and its seq lens are all None.
    '''

    # Query side is optional; everything stays None when it is absent
    packed_query = None
    q_start_loc_list = None
    q_seq_lens = None
    if qkv.query is not None:
        packed_query, q_start_loc_list = pack_tensor(qkv.query,
                                                     qkv.q_seq_lens,
                                                     device=device)
        q_seq_lens = qkv.q_seq_lens

    # Key and value share sequence lengths, hence also start locations
    packed_key, kv_start_loc_list = pack_tensor(qkv.key,
                                                qkv.kv_seq_lens,
                                                device=device)
    packed_value, _ = pack_tensor(qkv.value, qkv.kv_seq_lens, device=device)

    return PackedQKVInputs(
        query=packed_query,
        key=packed_key,
        value=packed_value,
        q_start_loc_list=q_start_loc_list,
        kv_start_loc_list=kv_start_loc_list,
        q_seq_lens=q_seq_lens,
        kv_seq_lens=qkv.kv_seq_lens,
    )
def make_backend(backend_name: str) -> AttentionBackend:
    '''
    Construct the backend instance determined by the backend_name string
    argument.

    "XFORMERS" -> construct xformers backend

    TODO: other backends

    Note: at time of writing the Attention wrapper automatically selects
    its own backend for Attention.forward(); so the backend instance which
    you generate with this function is not meant to be used for *running*
    inference, but rather for generating compatible metadata structures
    using backend.make_metadata()

    Returns:

    * Backend instance

    Raises:

    * AssertionError if backend_name is not a recognized backend
    '''
    # Guard clause: only the xformers backend is handled here
    if backend_name != STR_XFORMERS_ATTN_VAL:
        raise AssertionError(
            f"Unrecognized backend_name {backend_name} for unit test")
    return XFormersBackend()
def _make_metadata_tensors(
    seq_lens: Optional[List[int]], context_lens: Optional[List[int]],
    encoder_seq_lens: Optional[List[int]], device: Union[torch.device, str]
) -> Tuple[torch.Tensor, torch.Tensor, Any, Any, Optional[List[int]],
           torch.Tensor, Optional[int]]:
    '''
    Build scalar & tensor values required to build attention metadata structure.

    Any of the list arguments may be None; they are handed to the maybe_*
    helpers, which are expected to tolerate None (NOTE(review): confirm
    against the helper definitions).

    Arguments:

    * seq_lens: list of token-counts for each decoder input seq
    * context_lens: list of context length values for each seq
    * encoder_seq_lens: list of token-counts for each encoder input seq
    * device: CPU or CUDA device

    Returns (7-tuple, in this order):

    * seq_lens_tensor: decoder seq_lens list, as tensor
    * context_lens_tensor: context_lens list, as tensor
    * max_context_len: maybe_max(context_lens)
    * max_seq_len: maybe_max(seq_lens)
    * seq_start_loc: start idx of each sequence; this implementation
                     always returns None here
    * encoder_seq_lens_tensor: encoder seq_lens list, as tensor
    * max_encoder_seq_len: max(encoder_seq_lens), or None when
                           encoder_seq_lens is None
    '''
    seq_lens_tensor = maybe_make_int_tensor(seq_lens, device)
    context_lens_tensor = maybe_make_int_tensor(context_lens, device)
    max_context_len = maybe_max(context_lens)
    max_seq_len = maybe_max(seq_lens)

    encoder_seq_lens_tensor = maybe_make_int_tensor(encoder_seq_lens, device)
    max_encoder_seq_len = (None if encoder_seq_lens is None else
                           max(encoder_seq_lens))

    # Start locations are not computed by this helper
    seq_start_loc = None

    return (seq_lens_tensor, context_lens_tensor, max_context_len, max_seq_len,
            seq_start_loc, encoder_seq_lens_tensor, max_encoder_seq_len)
def make_kv_cache(num_blocks: int,
                  num_heads: int,
                  head_size: int,
                  block_size: int,
                  device: Union[torch.device, str],
                  default_val: Optional[float] = 0.0) -> torch.Tensor:
    '''
    Create a fake KV cache.

    The cache is first allocated with torch.rand() and moved to `device`;
    it is then overwritten with default_val unless default_val is None.

    Arguments:

    * num_blocks: number of blocks in the KV cache
    * num_heads: number of attention heads
    * head_size: head dimension
    * block_size: number of offsets within a block
    * device: CPU or CUDA device
    * default_val: initialization value for KV cache elements; pass None
                   to keep the random values produced by torch.rand()

    Returns:

    * kv_cache: 2 x num_blocks x (block_size * num_heads * head_size)
    '''

    cache_shape = (2, num_blocks, block_size * num_heads * head_size)

    # The random tensor is generated before the transfer to `device`,
    # exactly as before, so the RNG consumption is unchanged
    kv_cache = torch.rand(cache_shape).to(device)
    if default_val is not None:
        kv_cache[:, :, :] = default_val
    return kv_cache
def _num_tokens_to_min_blocks(num_tokens: int, block_size: int) -> int:
    '''
    Compute the number of blocks provisioned to hold num_tokens tokens,
    given block_size.

    The result is floor(num_tokens / block_size) + 1, i.e. the same value
    as (num_tokens + block_size) // block_size. Note that a token count
    which is an exact multiple of block_size (including 0) therefore gets
    one block more than the strict ceiling.

    Arguments:

    * num_tokens: number of tokens to store
    * block_size: number of offsets per block

    Returns:

    * Number of blocks
    '''
    full_blocks = num_tokens // block_size
    return full_blocks + 1
def make_empty_slot_mapping_tensor(device: Union[torch.device, str]):
    '''
    Build a slot mapping with no entries, by handing an empty list to
    maybe_make_long_tensor().

    Arguments:

    * device: CPU or CUDA device

    Returns:

    * Whatever maybe_make_long_tensor() produces for an empty list
      on `device`
    '''
    no_slots = []
    return maybe_make_long_tensor(no_slots, device)
def make_empty_block_tables_tensor(device: Union[torch.device, str]):
    '''
    Build a block-tables tensor with no entries.

    No dtype is requested, so the tensor is created from an empty list
    with torch.tensor()'s default dtype inference.

    Arguments:

    * device: CPU or CUDA device

    Returns:

    * Empty 1D tensor on `device`
    '''
    empty_block_tables = torch.tensor([], device=device)
    return empty_block_tables
def split_slot_mapping(slot_mapping_list: List[int], seq_lens: List[int],
                       device: Union[torch.device, str]):
    '''
    Split a slot mapping into valid prefill- and decode-phase slot mappings.

    Context:

    * Your goal is to test (1) prefill of N prompts, with prompt-lengths
      {K_i for all i in [0,N)}, followed by (2) decoding of a single token
      for all N prompts (N tokens total); the resultant sequence lengths
      after decode would be {K_i + 1 for i in [0,N)}
    * The test you want to do requires (1) having the prefill slot mapping
      for all tokens present during prefill, the number of which is
      M = sum_i{K_i}, and (2) having the decode slot mapping for all N
      decoded tokens

    This function consumes a single 1D slot mapping, which is the
    concatenation of N slot mappings each of length K_i + 1 (corresponding
    to the sequence lengths after decode), with a total length of
    P = sum_i{K_i + 1} = M + N

    The prefill-phase slot mapping results from excising the (K_i + 1)-th entry
    from each of the N subsequences in the slot mapping (i.e. omitting the
    decoded token's mapping.)

    The N excised entries are collected, in sequence order, to obtain the
    decode-phase slot mapping

    Arguments:

    * slot_mapping_list: Length-P 1D slot mapping (as List) reflecting all N
      post-decode sequences
    * seq_lens: List of N post-decode sequence lengths (K_i + 1 in the
      description above)
    * device: cuda, cpu, etc.

    Returns:

    * prefill_slot_mapping: Length-M 1D slot mapping (as Tensor)
      reflecting all N prefill prompts
    * decode_slot_mapping: Length-N 1D slot mapping (as Tensor) reflecting
      all N decoded tokens
    '''

    prefill_slot_mapping = []
    decode_slot_mapping = []

    # base_idx is the offset of the current sequence's first entry
    base_idx = 0
    for seq_len in seq_lens:
        # Index of the sequence's final entry, i.e. the decoded token
        last_idx = base_idx + seq_len - 1
        prefill_slot_mapping.extend(slot_mapping_list[base_idx:last_idx])
        decode_slot_mapping.append(slot_mapping_list[last_idx])
        base_idx += seq_len

    return (maybe_make_long_tensor(prefill_slot_mapping, device),
            maybe_make_long_tensor(decode_slot_mapping, device))
def make_block_tables_slot_mapping(
        block_size: int,
        seq_lens: List[int],
        device: Union[torch.device, str],
        block_base_addr: int = 0) -> Tuple[torch.Tensor, List[int], int]:
    '''
    Construct fake block tables & slot mappings.

    Each sequence with num_tokens tokens is given

    num_blocks = (num_tokens + block_size) // block_size

    KV cache blocks (see _num_tokens_to_min_blocks()), so the KV cache
    size in blocks is

    total_cache_blocks = sum(num_blocks for all seqs)

    Block indices are handed out in descending order, starting at

    block_base_addr + total_cache_blocks

    for the first block of the first sequence and decreasing by one for
    every further block.

    The constructed block-tables and slot-mapping are sized to the
    lengths of the sequences in their entirety (as reflected by seq_lens),
    i.e. the total of prefill prompt tokens + decoded tokens.

    Arguments:

    * block_size: number of offsets per block
    * seq_lens: list of token-counts for each sequence
    * device: CPU or CUDA device
    * block_base_addr: the block table base address

    Return:

    * block_tables_tensor: block tables for all sequences, built with
      make_tensor_with_pad() using pad value 0 and a max_len of the
      longest block table plus 10
    * slot_mapping_list: slot mapping for all sequences, concatenated
      in sequence order
    * max_block_idx: the highest block address within this block table
    '''

    # Number of KV cache blocks provisioned for each sequence
    blocks_per_seq = [
        _num_tokens_to_min_blocks(token_count, block_size)
        for token_count in seq_lens
    ]
    longest_table_len = max(blocks_per_seq)
    extra_pad_columns = 10

    # Uppermost block index; allocation proceeds downward from here
    max_block_idx = block_base_addr + sum(blocks_per_seq)
    next_block_idx = max_block_idx

    block_tables = []
    slot_mapping_list = []
    for token_count, block_count in zip(seq_lens, blocks_per_seq):
        # Descending run of block indices owned by this sequence
        block_table = list(
            range(next_block_idx, next_block_idx - block_count, -1))

        # Slot = (owning block index * block_size) + offset within block
        slot_mapping_list.extend(
            (token_pos % block_size) +
            block_table[token_pos // block_size] * block_size
            for token_pos in range(token_count))

        next_block_idx -= block_count
        block_tables.append(block_table)

    block_tables_tensor = make_tensor_with_pad(
        block_tables,
        max_len=longest_table_len + extra_pad_columns,
        pad=0,
        dtype=torch.int,
        device=device,
    )

    return (block_tables_tensor, slot_mapping_list, max_block_idx)
def make_test_metadata(
    attn_backend: AttentionBackend,
    is_prompt: bool,
    seq_lens: Optional[List[int]],
    decoder_test_params: Optional[PhaseTestParameters],
    device: Union[torch.device, str],
    encoder_test_params: Optional[PhaseTestParameters] = None,
    cross_test_params: Optional[PhaseTestParameters] = None
) -> AttentionMetadata:
    '''
    Construct fake attention metadata for a given test phase
    (prefill-phase or decode-phase).

    encoder_test_params and cross_test_params arguments allow encoder
    attention and enc/dec cross-attention (respectively) to use distinct
    metadata values from decoder self-attention (decoder_test_params.)

    if encoder_test_params and cross_test_params are None, the attention
    metadata will support decoder-only scenario.

    Assumptions:

    * No chunked prefill -> a batch is 100% prefill or 100% decode, never both

    Arguments:

    * attn_backend: Backend for sourcing attention kernels
    * is_prompt: prefill if True, o/w decode
    * seq_lens: list of token counts for each sequence; None signals the
                encoder-only scenario (prefill phase only)
    * decoder_test_params: decoder self-attention test params;
                           this function requires
                           kv_mmap (memory mapping) field. None signals
                           the encoder-only scenario (prefill phase only)
    * device: CPU or CUDA device
    * encoder_test_params: encoder attention test params;
                           this function requires encoder query
                           sequence lengths field. If None,
                           encoder query sequence lengths are
                           treated as None
    * cross_test_params: enc/dec cross-attention test params;
                         this function requires kv_mmap field.
                         If None, KV cache memory map data
                         structures are treated as None

    Return:

    * AttentionMetadata structure
    '''

    # Decoder self-attention memory mapping; stays None in the
    # encoder-only scenario (decoder_test_params is None)
    kv_mmap = None
    if decoder_test_params is not None:
        kv_mmap = decoder_test_params.kv_mmap

    # No chunked prefill is assumed, so the batch is either all prefill
    # or all decode:
    #
    # - prefill: one entry per prompt, sum(seq_lens) tokens
    # - decode:  one entry per sequence, one token per sequence
    #
    # Both counts are None (and unused) when seq_lens is None
    if seq_lens is None:
        num_seqs = None
        num_phase_tokens = None
    else:
        num_seqs = len(seq_lens)
        num_phase_tokens = sum(seq_lens) if is_prompt else len(seq_lens)

    # Seems for non-prefix-caching scenarios context_lens
    # is never needed
    context_lens = None

    # Encoder/decoder or encoder-only models only:
    # extract encoder input sequence lengths
    encoder_seq_lens = None
    num_encoder_tokens = None
    if encoder_test_params is not None:
        assert encoder_test_params.packed_qkvo.packed_qkv is not None
        encoder_seq_lens = encoder_test_params.packed_qkvo.packed_qkv.q_seq_lens
        if encoder_seq_lens is not None:
            num_encoder_tokens = sum(encoder_seq_lens)

    # Encoder/decoder models only: *cross-attention* slot mapping and
    # block table (kv_mmap)
    cross_kv_mmap = None
    if cross_test_params is not None:
        cross_kv_mmap = cross_test_params.kv_mmap

    if not is_prompt:
        # Decode phase always needs the decoder KV memory map and
        # the sequence lengths
        assert kv_mmap is not None
        assert num_phase_tokens is not None
        assert seq_lens is not None

    (
        seq_lens_tensor,
        context_lens_tensor,
        _,
        _,
        _,
        encoder_seq_lens_tensor,
        max_encoder_seq_len,
    ) = _make_metadata_tensors(seq_lens,
                               context_lens,
                               encoder_seq_lens,
                               device=device)

    # Fields whose values depend on the phase
    if is_prompt:
        # Prefill-phase scenario
        phase_fields = dict(
            num_prefills=num_seqs,
            slot_mapping=(None if kv_mmap is None else kv_mmap.slot_mapping),
            num_prefill_tokens=num_phase_tokens,
            num_decode_tokens=0,
            max_prefill_seq_len=None if seq_lens is None else max(seq_lens),
            max_decode_seq_len=0,
            block_tables=(None if kv_mmap is None else kv_mmap.block_tables),
        )
    else:
        # Decode-phase scenario
        phase_fields = dict(
            num_prefills=0,
            slot_mapping=kv_mmap.slot_mapping,
            num_prefill_tokens=0,
            num_decode_tokens=num_phase_tokens,
            max_prefill_seq_len=0,
            max_decode_seq_len=max(seq_lens),
            block_tables=kv_mmap.block_tables,
        )

    # Fields built the same way in both phases
    return attn_backend.make_metadata(
        seq_lens=seq_lens,
        seq_lens_tensor=seq_lens_tensor,
        context_lens_tensor=context_lens_tensor,
        use_cuda_graph=False,
        num_encoder_tokens=num_encoder_tokens,
        encoder_seq_lens=encoder_seq_lens,
        encoder_seq_lens_tensor=encoder_seq_lens_tensor,
        max_encoder_seq_len=max_encoder_seq_len,
        cross_slot_mapping=(None if cross_kv_mmap is None else
                            cross_kv_mmap.slot_mapping),
        cross_block_tables=(None if cross_kv_mmap is None else
                            cross_kv_mmap.block_tables),
        **phase_fields)
def assert_actual_matches_ideal(test_params: PhaseTestParameters,
                                output_under_test: torch.Tensor) -> None:
    '''
    Check the observed output against the ideal output stored in the
    test parameters data structure, using torch.allclose() with its
    default tolerances.

    The observed output is viewed with the ideal output's shape
    before the comparison.

    Arguments:

    * test_params: Test parameters including packed ideal output
    * output_under_test: actually observed output value

    Raises:

    * AssertionError if the two outputs are not close
    '''
    expected = test_params.packed_qkvo.ideal_output
    observed = output_under_test.view_as(expected)
    assert torch.allclose(expected, observed)
tests/lora/conftest.py
View file @
e7c1b7f3
...
...
@@ -2,6 +2,7 @@ import contextlib
import
gc
import
tempfile
from
collections
import
OrderedDict
from
typing
import
Dict
,
List
,
TypedDict
from
unittest.mock
import
MagicMock
,
patch
import
pytest
...
...
@@ -24,7 +25,18 @@ from vllm.model_executor.layers.sampler import Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
ParallelLMHead
from
vllm.model_executor.model_loader
import
get_model
LONG_LORA_INFOS
=
[{
class
ContextIDInfo
(
TypedDict
):
lora_id
:
int
context_length
:
str
class
ContextInfo
(
TypedDict
):
lora
:
str
context_length
:
str
LONG_LORA_INFOS
:
List
[
ContextIDInfo
]
=
[{
"lora_id"
:
1
,
"context_length"
:
"16k"
,
},
{
...
...
@@ -147,13 +159,21 @@ def dummy_model_gate_up() -> nn.Module:
@
pytest
.
fixture
(
scope
=
"session"
)
def
sql_lora_files
():
return
snapshot_download
(
repo_id
=
"yard1/llama-2-7b-sql-lora-test"
)
def
sql_lora_huggingface_id
():
# huggingface repo id is used to test lora runtime downloading.
return
"yard1/llama-2-7b-sql-lora-test"
@
pytest
.
fixture
(
scope
=
"session"
)
def
sql_lora_files
(
sql_lora_huggingface_id
):
return
snapshot_download
(
repo_id
=
sql_lora_huggingface_id
)
@
pytest
.
fixture
(
scope
=
"session"
)
def
mixtral_lora_files
():
return
snapshot_download
(
repo_id
=
"terrysun/mixtral-lora-adapter"
)
# Note: this module has incorrect adapter_config.json to test
# https://github.com/vllm-project/vllm/pull/5909/files.
return
snapshot_download
(
repo_id
=
"SangBinCho/mixtral-lora"
)
@
pytest
.
fixture
(
scope
=
"session"
)
...
...
@@ -207,7 +227,7 @@ def long_context_infos(long_context_lora_files_16k_1,
long_context_lora_files_16k_2
,
long_context_lora_files_32k
):
cleanup
()
infos
=
{}
infos
:
Dict
[
int
,
ContextInfo
]
=
{}
for
lora_checkpoint_info
in
LONG_LORA_INFOS
:
lora_id
=
lora_checkpoint_info
[
"lora_id"
]
if
lora_id
==
1
:
...
...
@@ -226,7 +246,7 @@ def long_context_infos(long_context_lora_files_16k_1,
@
pytest
.
fixture
def
llama_2_7b_engine_extra_embeddings
()
->
nn
.
Module
:
def
llama_2_7b_engine_extra_embeddings
():
cleanup
()
get_model_old
=
get_model
...
...
@@ -244,7 +264,6 @@ def llama_2_7b_engine_extra_embeddings() -> nn.Module:
@
pytest
.
fixture
def
llama_2_7b_model_extra_embeddings
(
llama_2_7b_engine_extra_embeddings
)
->
nn
.
Module
:
def
llama_2_7b_model_extra_embeddings
(
llama_2_7b_engine_extra_embeddings
):
yield
(
llama_2_7b_engine_extra_embeddings
.
model_executor
.
driver_worker
.
model_runner
.
model
)
Prev
1
…
12
13
14
15
16
17
18
19
20
…
23
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment