Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
e7c1b7f3
Commit
e7c1b7f3
authored
Sep 06, 2024
by
zhuwenwen
Browse files
Merge branch 'v0.5.4-dtk24.04.1'
parents
7462218e
04c62b93
Changes
442
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
3125 additions
and
270 deletions
+3125
-270
tests/entrypoints/openai/test_serving_chat.py
tests/entrypoints/openai/test_serving_chat.py
+44
-7
tests/entrypoints/openai/test_tokenization.py
tests/entrypoints/openai/test_tokenization.py
+152
-0
tests/entrypoints/openai/test_vision.py
tests/entrypoints/openai/test_vision.py
+14
-38
tests/kernels/quant_utils.py
tests/kernels/quant_utils.py
+72
-0
tests/kernels/test_attention.py
tests/kernels/test_attention.py
+25
-20
tests/kernels/test_attention_selector.py
tests/kernels/test_attention_selector.py
+15
-10
tests/kernels/test_blocksparse_attention.py
tests/kernels/test_blocksparse_attention.py
+16
-14
tests/kernels/test_cache.py
tests/kernels/test_cache.py
+60
-26
tests/kernels/test_cutlass.py
tests/kernels/test_cutlass.py
+102
-46
tests/kernels/test_encoder_decoder_attn.py
tests/kernels/test_encoder_decoder_attn.py
+953
-0
tests/kernels/test_flash_attn.py
tests/kernels/test_flash_attn.py
+15
-8
tests/kernels/test_flashinfer.py
tests/kernels/test_flashinfer.py
+248
-0
tests/kernels/test_fp8_quant.py
tests/kernels/test_fp8_quant.py
+87
-0
tests/kernels/test_int8_quant.py
tests/kernels/test_int8_quant.py
+10
-18
tests/kernels/test_marlin_gemm.py
tests/kernels/test_marlin_gemm.py
+316
-35
tests/kernels/test_moe.py
tests/kernels/test_moe.py
+3
-3
tests/kernels/test_pos_encoding.py
tests/kernels/test_pos_encoding.py
+13
-17
tests/kernels/test_sampler.py
tests/kernels/test_sampler.py
+33
-20
tests/kernels/utils.py
tests/kernels/utils.py
+920
-0
tests/lora/conftest.py
tests/lora/conftest.py
+27
-8
No files found.
Too many changes to show.
To preserve performance only
442 of 442+
files are displayed.
Plain diff
Email patch
tests/entrypoints/openai/test_serving_chat.py
View file @
e7c1b7f3
import
asyncio
import
asyncio
from
contextlib
import
suppress
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
unittest.mock
import
MagicMock
import
pytest
from
vllm.engine.async_llm_engine
import
AsyncLLMEngine
from
vllm.entrypoints.openai.protocol
import
ChatCompletionRequest
from
vllm.entrypoints.openai.serving_chat
import
OpenAIServingChat
from
vllm.entrypoints.openai.serving_chat
import
OpenAIServingChat
from
vllm.transformers_utils.tokenizer
import
get_tokenizer
MODEL_NAME
=
"openai-community/gpt2"
MODEL_NAME
=
"openai-community/gpt2"
CHAT_TEMPLATE
=
"Dummy chat template for testing {}"
CHAT_TEMPLATE
=
"Dummy chat template for testing {}"
pytestmark
=
pytest
.
mark
.
openai
@
dataclass
@
dataclass
class
MockModelConfig
:
class
MockModelConfig
:
...
@@ -36,11 +37,47 @@ async def _async_serving_chat_init():
...
@@ -36,11 +37,47 @@ async def _async_serving_chat_init():
model_config
,
model_config
,
served_model_names
=
[
MODEL_NAME
],
served_model_names
=
[
MODEL_NAME
],
response_role
=
"assistant"
,
response_role
=
"assistant"
,
chat_template
=
CHAT_TEMPLATE
)
chat_template
=
CHAT_TEMPLATE
,
lora_modules
=
None
,
prompt_adapters
=
None
,
request_logger
=
None
)
return
serving_completion
return
serving_completion
def
test_async_serving_chat_init
():
def
test_async_serving_chat_init
():
serving_completion
=
asyncio
.
run
(
_async_serving_chat_init
())
serving_completion
=
asyncio
.
run
(
_async_serving_chat_init
())
assert
serving_completion
.
tokenizer
is
not
None
assert
serving_completion
.
chat_template
==
CHAT_TEMPLATE
assert
serving_completion
.
tokenizer
.
chat_template
==
CHAT_TEMPLATE
def
test_serving_chat_should_set_correct_max_tokens
():
mock_engine
=
MagicMock
(
spec
=
AsyncLLMEngine
)
mock_engine
.
get_tokenizer
.
return_value
=
get_tokenizer
(
MODEL_NAME
)
serving_chat
=
OpenAIServingChat
(
mock_engine
,
MockModelConfig
(),
served_model_names
=
[
MODEL_NAME
],
response_role
=
"assistant"
,
chat_template
=
CHAT_TEMPLATE
,
lora_modules
=
None
,
prompt_adapters
=
None
,
request_logger
=
None
)
req
=
ChatCompletionRequest
(
model
=
MODEL_NAME
,
messages
=
[{
"role"
:
"user"
,
"content"
:
"what is 1+1?"
}],
guided_decoding_backend
=
"outlines"
,
)
with
suppress
(
Exception
):
asyncio
.
run
(
serving_chat
.
create_chat_completion
(
req
))
# AsyncLLMEngine.generate(inputs, sampling_params, ...)
assert
mock_engine
.
generate
.
call_args
.
args
[
1
].
max_tokens
==
93
req
.
max_tokens
=
10
with
suppress
(
Exception
):
asyncio
.
run
(
serving_chat
.
create_chat_completion
(
req
))
assert
mock_engine
.
generate
.
call_args
.
args
[
1
].
max_tokens
==
10
tests/entrypoints/openai/test_tokenization.py
0 → 100644
View file @
e7c1b7f3
import
openai
# use the official client for correctness check
import
pytest
import
requests
from
vllm.transformers_utils.tokenizer
import
get_tokenizer
from
...utils
import
RemoteOpenAIServer
from
.test_completion
import
zephyr_lora_added_tokens_files
# noqa: F401
from
.test_completion
import
zephyr_lora_files
# noqa: F401
# any model with a chat template should work here
MODEL_NAME
=
"HuggingFaceH4/zephyr-7b-beta"
@
pytest
.
fixture
(
scope
=
"module"
)
def
server
(
zephyr_lora_added_tokens_files
:
str
):
# noqa: F811
args
=
[
# use half precision for speed and memory savings in CI environment
"--dtype"
,
"bfloat16"
,
"--max-model-len"
,
"8192"
,
"--enforce-eager"
,
"--max-num-seqs"
,
"128"
,
# lora config
"--enable-lora"
,
"--lora-modules"
,
f
"zephyr-lora2=
{
zephyr_lora_added_tokens_files
}
"
,
"--max-lora-rank"
,
"64"
,
]
with
RemoteOpenAIServer
(
MODEL_NAME
,
args
)
as
remote_server
:
yield
remote_server
@
pytest
.
fixture
(
scope
=
"module"
)
def
tokenizer_name
(
model_name
:
str
,
zephyr_lora_added_tokens_files
:
str
):
# noqa: F811
return
zephyr_lora_added_tokens_files
if
(
model_name
==
"zephyr-lora2"
)
else
model_name
@
pytest
.
fixture
(
scope
=
"module"
)
def
client
(
server
):
return
server
.
get_async_client
()
@
pytest
.
mark
.
asyncio
@
pytest
.
mark
.
parametrize
(
"model_name,tokenizer_name"
,
[(
MODEL_NAME
,
MODEL_NAME
),
(
"zephyr-lora2"
,
"zephyr-lora2"
)],
indirect
=
[
"tokenizer_name"
],
)
async
def
test_tokenize_completions
(
client
:
openai
.
AsyncOpenAI
,
model_name
:
str
,
tokenizer_name
:
str
):
base_url
=
str
(
client
.
base_url
)[:
-
3
].
strip
(
"/"
)
tokenizer
=
get_tokenizer
(
tokenizer_name
=
tokenizer_name
,
tokenizer_mode
=
"fast"
)
for
add_special
in
[
False
,
True
]:
prompt
=
"vllm1 This is a test prompt."
tokens
=
tokenizer
.
encode
(
prompt
,
add_special_tokens
=
add_special
)
response
=
requests
.
post
(
base_url
+
"/tokenize"
,
json
=
{
"add_special_tokens"
:
add_special
,
"model"
:
model_name
,
"prompt"
:
prompt
})
response
.
raise_for_status
()
assert
response
.
json
()
==
{
"tokens"
:
tokens
,
"count"
:
len
(
tokens
),
"max_model_len"
:
8192
}
@
pytest
.
mark
.
asyncio
@
pytest
.
mark
.
parametrize
(
"model_name,tokenizer_name"
,
[(
MODEL_NAME
,
MODEL_NAME
),
(
"zephyr-lora2"
,
"zephyr-lora2"
)],
indirect
=
[
"tokenizer_name"
],
)
async
def
test_tokenize_chat
(
client
:
openai
.
AsyncOpenAI
,
model_name
:
str
,
tokenizer_name
:
str
):
base_url
=
str
(
client
.
base_url
)[:
-
3
].
strip
(
"/"
)
tokenizer
=
get_tokenizer
(
tokenizer_name
=
tokenizer_name
,
tokenizer_mode
=
"fast"
)
for
add_generation
in
[
False
,
True
]:
for
add_special
in
[
False
,
True
]:
conversation
=
[{
"role"
:
"user"
,
"content"
:
"Hi there!"
},
{
"role"
:
"assistant"
,
"content"
:
"Nice to meet you!"
},
{
"role"
:
"user"
,
"content"
:
"Can I ask a question? vllm1"
}]
prompt
=
tokenizer
.
apply_chat_template
(
add_generation_prompt
=
add_generation
,
conversation
=
conversation
,
tokenize
=
False
)
tokens
=
tokenizer
.
encode
(
prompt
,
add_special_tokens
=
add_special
)
response
=
requests
.
post
(
base_url
+
"/tokenize"
,
json
=
{
"add_generation_prompt"
:
add_generation
,
"add_special_tokens"
:
add_special
,
"messages"
:
conversation
,
"model"
:
model_name
})
response
.
raise_for_status
()
assert
response
.
json
()
==
{
"tokens"
:
tokens
,
"count"
:
len
(
tokens
),
"max_model_len"
:
8192
}
@
pytest
.
mark
.
asyncio
@
pytest
.
mark
.
parametrize
(
"model_name,tokenizer_name"
,
[(
MODEL_NAME
,
MODEL_NAME
),
(
"zephyr-lora2"
,
"zephyr-lora2"
)],
indirect
=
[
"tokenizer_name"
],
)
async
def
test_detokenize
(
client
:
openai
.
AsyncOpenAI
,
model_name
:
str
,
tokenizer_name
:
str
):
base_url
=
str
(
client
.
base_url
)[:
-
3
].
strip
(
"/"
)
tokenizer
=
get_tokenizer
(
tokenizer_name
=
tokenizer_name
,
tokenizer_mode
=
"fast"
)
prompt
=
"This is a test prompt. vllm1"
tokens
=
tokenizer
.
encode
(
prompt
,
add_special_tokens
=
False
)
print
(
f
"CALLING
{
base_url
}
FOR
{
model_name
}
"
)
response
=
requests
.
post
(
base_url
+
"/detokenize"
,
json
=
{
"model"
:
model_name
,
"tokens"
:
tokens
})
response
.
raise_for_status
()
assert
response
.
json
()
==
{
"prompt"
:
prompt
}
tests/entrypoints/
test_
openai_vision.py
→
tests/entrypoints/openai
/test
_vision.py
View file @
e7c1b7f3
from
pathlib
import
Path
from
typing
import
Dict
,
List
from
typing
import
Dict
import
openai
import
openai
import
pytest
import
pytest
import
pytest_asyncio
import
ray
from
vllm.multimodal.utils
import
ImageFetchAiohttp
,
encode_image_base64
from
vllm.multimodal.utils
import
encode_image_base64
,
fetch_image
from
..utils
import
VLLM_PATH
,
RemoteOpenAIServer
from
..
.
utils
import
VLLM_PATH
,
RemoteOpenAIServer
MODEL_NAME
=
"llava-hf/llava-1.5-7b-hf"
MODEL_NAME
=
"llava-hf/llava-1.5-7b-hf"
LLAVA_CHAT_TEMPLATE
=
(
Path
(
__file__
).
parent
.
parent
.
parent
/
LLAVA_CHAT_TEMPLATE
=
VLLM_PATH
/
"examples/template_llava.jinja"
"examples/template_llava.jinja"
)
assert
LLAVA_CHAT_TEMPLATE
.
exists
()
assert
LLAVA_CHAT_TEMPLATE
.
exists
()
# Test different image extensions (JPG/PNG) and formats (gray/RGB/RGBA)
# Test different image extensions (JPG/PNG) and formats (gray/RGB/RGBA)
TEST_IMAGE_URLS
=
[
TEST_IMAGE_URLS
=
[
"https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
,
"https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
,
...
@@ -22,37 +19,21 @@ TEST_IMAGE_URLS = [
...
@@ -22,37 +19,21 @@ TEST_IMAGE_URLS = [
"https://upload.wikimedia.org/wikipedia/commons/0/0b/RGBA_comp.png"
,
"https://upload.wikimedia.org/wikipedia/commons/0/0b/RGBA_comp.png"
,
]
]
pytestmark
=
pytest
.
mark
.
openai
@
pytest
.
fixture
(
scope
=
"module"
)
def
ray_ctx
():
ray
.
init
(
runtime_env
=
{
"working_dir"
:
VLLM_PATH
})
yield
ray
.
shutdown
()
@
pytest
.
fixture
(
scope
=
"module"
)
@
pytest
.
fixture
(
scope
=
"module"
)
def
server
():
def
server
():
return
RemoteOpenAIServer
([
args
=
[
"--model"
,
MODEL_NAME
,
"--dtype"
,
"--dtype"
,
"bfloat16"
,
"bfloat16"
,
"--max-model-len"
,
"--max-model-len"
,
"4096"
,
"4096"
,
"--enforce-eager"
,
"--enforce-eager"
,
"--image-input-type"
,
"pixel_values"
,
"--image-token-id"
,
"32000"
,
"--image-input-shape"
,
"1,3,336,336"
,
"--image-feature-size"
,
"576"
,
"--chat-template"
,
"--chat-template"
,
str
(
LLAVA_CHAT_TEMPLATE
),
str
(
LLAVA_CHAT_TEMPLATE
),
])
]
with
RemoteOpenAIServer
(
MODEL_NAME
,
args
)
as
remote_server
:
yield
remote_server
@
pytest
.
fixture
(
scope
=
"module"
)
@
pytest
.
fixture
(
scope
=
"module"
)
...
@@ -60,11 +41,10 @@ def client(server):
...
@@ -60,11 +41,10 @@ def client(server):
return
server
.
get_async_client
()
return
server
.
get_async_client
()
@
pytest
_asyncio
.
fixture
(
scope
=
"session"
)
@
pytest
.
fixture
(
scope
=
"session"
)
async
def
base64_encoded_image
()
->
Dict
[
str
,
str
]:
def
base64_encoded_image
()
->
Dict
[
str
,
str
]:
return
{
return
{
image_url
:
image_url
:
encode_image_base64
(
fetch_image
(
image_url
))
encode_image_base64
(
await
ImageFetchAiohttp
.
fetch_image
(
image_url
))
for
image_url
in
TEST_IMAGE_URLS
for
image_url
in
TEST_IMAGE_URLS
}
}
...
@@ -216,7 +196,7 @@ async def test_chat_streaming_image(client: openai.AsyncOpenAI,
...
@@ -216,7 +196,7 @@ async def test_chat_streaming_image(client: openai.AsyncOpenAI,
temperature
=
0.0
,
temperature
=
0.0
,
stream
=
True
,
stream
=
True
,
)
)
chunks
=
[]
chunks
:
List
[
str
]
=
[]
finish_reason_count
=
0
finish_reason_count
=
0
async
for
chunk
in
stream
:
async
for
chunk
in
stream
:
delta
=
chunk
.
choices
[
0
].
delta
delta
=
chunk
.
choices
[
0
].
delta
...
@@ -279,7 +259,3 @@ async def test_multi_image_input(client: openai.AsyncOpenAI, model_name: str,
...
@@ -279,7 +259,3 @@ async def test_multi_image_input(client: openai.AsyncOpenAI, model_name: str,
)
)
completion
=
completion
.
choices
[
0
].
text
completion
=
completion
.
choices
[
0
].
text
assert
completion
is
not
None
and
len
(
completion
)
>=
0
assert
completion
is
not
None
and
len
(
completion
)
>=
0
if
__name__
==
"__main__"
:
pytest
.
main
([
__file__
])
tests/kernels/quant_utils.py
0 → 100644
View file @
e7c1b7f3
from
typing
import
Optional
,
Tuple
,
Union
import
torch
def
as_float32_tensor
(
x
:
Union
[
float
,
torch
.
tensor
])
->
torch
.
tensor
:
return
torch
.
as_tensor
(
x
,
dtype
=
torch
.
float32
,
device
=
'cuda'
)
def
ref_dynamic_per_token_quant
(
x
:
torch
.
tensor
,
quant_dtype
:
torch
.
dtype
,
scale_ub
:
Optional
[
torch
.
tensor
]
=
None
)
\
->
Tuple
[
torch
.
tensor
,
torch
.
tensor
]:
assert
quant_dtype
in
[
torch
.
int8
,
torch
.
float8_e4m3fn
]
if
scale_ub
is
not
None
:
assert
quant_dtype
==
torch
.
float8_e4m3fn
qtype_traits
=
torch
.
iinfo
(
quant_dtype
)
if
quant_dtype
==
torch
.
int8
\
else
torch
.
finfo
(
quant_dtype
)
qtype_max
=
as_float32_tensor
(
qtype_traits
.
max
)
s_1
=
as_float32_tensor
(
1.0
)
s_512
=
as_float32_tensor
(
512.0
)
# For fp8, in order to match the cuda kernel output, we have to do exactly
# the same operations as in the corresponding fp8 kernel to prevent
# rounding errors.
# Compute scales
x_token_max
,
_
=
x
.
abs
().
max
(
dim
=-
1
)
x_token_max
=
as_float32_tensor
(
x_token_max
)
if
scale_ub
is
not
None
:
x_token_max
=
x_token_max
.
clamp
(
max
=
scale_ub
)
scales
=
(
x_token_max
/
qtype_max
)[:,
None
]
# Quant
if
quant_dtype
==
torch
.
int8
:
iscales
=
as_float32_tensor
(
s_1
/
scales
)
torch_out
=
as_float32_tensor
(
x
)
*
iscales
torch_out
=
torch_out
.
round
()
torch_out
=
torch_out
.
clamp
(
qtype_traits
.
min
,
qtype_traits
.
max
).
to
(
quant_dtype
)
else
:
assert
quant_dtype
==
torch
.
float8_e4m3fn
min_scaling_factor
=
s_1
/
(
qtype_max
*
s_512
)
scales
=
scales
.
clamp
(
min
=
min_scaling_factor
)
torch_out
=
as_float32_tensor
(
x
)
/
scales
torch_out
=
torch_out
.
clamp
(
qtype_traits
.
min
,
qtype_traits
.
max
).
to
(
quant_dtype
)
return
torch_out
,
scales
# The int8 version is very similar. Incorporate the int8 version, like in
# ref_dynamic_per_token_quant, when we have a dynamic_per_tensor int8 quant
# kernel
def
ref_dynamic_per_tensor_fp8_quant
(
x
:
torch
.
tensor
)
\
->
Tuple
[
torch
.
tensor
,
torch
.
tensor
]:
fp8_traits
=
torch
.
finfo
(
torch
.
float8_e4m3fn
)
fp8_max
=
as_float32_tensor
(
fp8_traits
.
max
)
one
=
as_float32_tensor
(
1.0
)
# For fp8, in order to match the cuda kernel output, we have to do exactly
# the same operations as in the corresponding fp8 kernel to prevent
# rounding errors.
x_max
=
as_float32_tensor
(
x
.
abs
().
max
())
ref_scale
=
x_max
/
fp8_max
ref_iscale
=
one
/
ref_scale
ref_out
=
(
as_float32_tensor
(
x
)
*
ref_iscale
).
clamp
(
fp8_traits
.
min
,
fp8_traits
.
max
).
to
(
dtype
=
torch
.
float8_e4m3fn
)
return
ref_out
,
ref_scale
tests/kernels/test_attention.py
View file @
e7c1b7f3
...
@@ -29,7 +29,7 @@ NUM_HEADS = [(40, 40), (64, 8)] # Arbitrary values for testing
...
@@ -29,7 +29,7 @@ NUM_HEADS = [(40, 40), (64, 8)] # Arbitrary values for testing
# FlashAttention forward only supports head dimension at most 128
# FlashAttention forward only supports head dimension at most 128
# https://github.com/ROCmSoftwarePlatform/flash-attention/blob/3d2b6f5d037782cc2c906909a46fb7e2e1b48b25/csrc/flash_attn_rocm/flash_api.cpp#L62
# https://github.com/ROCmSoftwarePlatform/flash-attention/blob/3d2b6f5d037782cc2c906909a46fb7e2e1b48b25/csrc/flash_attn_rocm/flash_api.cpp#L62
HEAD_SIZES
=
[
64
,
80
,
96
,
112
,
128
,
192
,
256
HEAD_SIZES
=
[
64
,
80
,
96
,
112
,
120
,
128
,
192
,
256
]
if
not
is_hip
()
else
[
64
,
80
,
96
,
112
,
128
]
]
if
not
is_hip
()
else
[
64
,
80
,
96
,
112
,
128
]
BLOCK_SIZES
=
[
16
,
32
]
BLOCK_SIZES
=
[
16
,
32
]
...
@@ -73,27 +73,27 @@ def ref_single_query_cached_kv_attention(
...
@@ -73,27 +73,27 @@ def ref_single_query_cached_kv_attention(
block_size
=
value_cache
.
shape
[
3
]
block_size
=
value_cache
.
shape
[
3
]
num_seqs
=
query
.
shape
[
0
]
num_seqs
=
query
.
shape
[
0
]
block_tables
=
block_tables
.
cpu
().
tolist
()
block_tables
_lst
=
block_tables
.
cpu
().
tolist
()
seq_lens
=
seq_lens
.
cpu
().
tolist
()
seq_lens
_lst
=
seq_lens
.
cpu
().
tolist
()
for
i
in
range
(
num_seqs
):
for
i
in
range
(
num_seqs
):
q
=
query
[
i
].
unsqueeze
(
0
)
q
=
query
[
i
].
unsqueeze
(
0
)
block_table
=
block_tables
[
i
]
block_table
=
block_tables
_lst
[
i
]
seq_len
=
int
(
seq_lens
[
i
])
seq_len
=
int
(
seq_lens
_lst
[
i
])
keys
=
[]
keys
_lst
:
List
[
torch
.
Tensor
]
=
[]
values
=
[]
values
_lst
:
List
[
torch
.
Tensor
]
=
[]
for
j
in
range
(
seq_len
):
for
j
in
range
(
seq_len
):
block_number
=
int
(
block_table
[
j
//
block_size
])
block_number
=
int
(
block_table
[
j
//
block_size
])
block_offset
=
j
%
block_size
block_offset
=
j
%
block_size
k
=
key_cache
[
block_number
,
:,
:,
block_offset
,
:]
k
=
key_cache
[
block_number
,
:,
:,
block_offset
,
:]
k
=
k
.
reshape
(
num_kv_heads
,
head_size
)
k
=
k
.
reshape
(
num_kv_heads
,
head_size
)
keys
.
append
(
k
)
keys
_lst
.
append
(
k
)
v
=
value_cache
[
block_number
,
:,
:,
block_offset
]
v
=
value_cache
[
block_number
,
:,
:,
block_offset
]
values
.
append
(
v
)
values
_lst
.
append
(
v
)
keys
=
torch
.
stack
(
keys
,
dim
=
0
)
keys
=
torch
.
stack
(
keys
_lst
,
dim
=
0
)
values
=
torch
.
stack
(
values
,
dim
=
0
)
values
=
torch
.
stack
(
values
_lst
,
dim
=
0
)
if
num_queries_per_kv
>
1
:
if
num_queries_per_kv
>
1
:
# Handle MQA and GQA
# Handle MQA and GQA
keys
=
torch
.
repeat_interleave
(
keys
,
num_queries_per_kv
,
dim
=
1
)
keys
=
torch
.
repeat_interleave
(
keys
,
num_queries_per_kv
,
dim
=
1
)
...
@@ -135,6 +135,8 @@ def test_paged_attention(
...
@@ -135,6 +135,8 @@ def test_paged_attention(
seed
:
int
,
seed
:
int
,
device
:
str
,
device
:
str
,
)
->
None
:
)
->
None
:
if
kv_cache_dtype
==
"fp8"
and
head_size
%
16
:
pytest
.
skip
()
random
.
seed
(
seed
)
random
.
seed
(
seed
)
torch
.
random
.
manual_seed
(
seed
)
torch
.
random
.
manual_seed
(
seed
)
if
torch
.
cuda
.
is_available
():
if
torch
.
cuda
.
is_available
():
...
@@ -158,14 +160,15 @@ def test_paged_attention(
...
@@ -158,14 +160,15 @@ def test_paged_attention(
# Create the block tables.
# Create the block tables.
max_num_blocks_per_seq
=
(
max_seq_len
+
block_size
-
1
)
//
block_size
max_num_blocks_per_seq
=
(
max_seq_len
+
block_size
-
1
)
//
block_size
block_tables
=
[]
block_tables
_lst
:
List
[
List
[
int
]]
=
[]
for
_
in
range
(
num_seqs
):
for
_
in
range
(
num_seqs
):
block_table
=
[
block_table
=
[
random
.
randint
(
0
,
NUM_BLOCKS
-
1
)
random
.
randint
(
0
,
NUM_BLOCKS
-
1
)
for
_
in
range
(
max_num_blocks_per_seq
)
for
_
in
range
(
max_num_blocks_per_seq
)
]
]
block_tables
.
append
(
block_table
)
block_tables_lst
.
append
(
block_table
)
block_tables
=
torch
.
tensor
(
block_tables
,
dtype
=
torch
.
int
)
block_tables
=
torch
.
tensor
(
block_tables_lst
,
dtype
=
torch
.
int
)
# Create the KV caches.
# Create the KV caches.
key_caches
,
value_caches
=
kv_cache_factory
(
NUM_BLOCKS
,
block_size
,
1
,
key_caches
,
value_caches
=
kv_cache_factory
(
NUM_BLOCKS
,
block_size
,
1
,
...
@@ -175,7 +178,7 @@ def test_paged_attention(
...
@@ -175,7 +178,7 @@ def test_paged_attention(
key_cache
,
value_cache
=
key_caches
[
0
],
value_caches
[
0
]
key_cache
,
value_cache
=
key_caches
[
0
],
value_caches
[
0
]
# Using default kv_scale
# Using default kv_scale
kv_scale
=
1.0
k
_scale
=
v_scale
=
1.0
# Call the paged attention kernel.
# Call the paged attention kernel.
output
=
torch
.
empty_like
(
query
)
output
=
torch
.
empty_like
(
query
)
...
@@ -193,7 +196,8 @@ def test_paged_attention(
...
@@ -193,7 +196,8 @@ def test_paged_attention(
max_seq_len
,
max_seq_len
,
alibi_slopes
,
alibi_slopes
,
kv_cache_dtype
,
kv_cache_dtype
,
kv_scale
,
k_scale
,
v_scale
,
)
)
elif
version
==
"v2"
:
elif
version
==
"v2"
:
num_partitions
=
((
max_seq_len
+
PARTITION_SIZE
-
1
)
//
PARTITION_SIZE
)
num_partitions
=
((
max_seq_len
+
PARTITION_SIZE
-
1
)
//
PARTITION_SIZE
)
...
@@ -224,7 +228,8 @@ def test_paged_attention(
...
@@ -224,7 +228,8 @@ def test_paged_attention(
max_seq_len
,
max_seq_len
,
alibi_slopes
,
alibi_slopes
,
kv_cache_dtype
,
kv_cache_dtype
,
kv_scale
,
k_scale
,
v_scale
,
)
)
else
:
else
:
raise
AssertionError
(
f
"Unknown version:
{
version
}
"
)
raise
AssertionError
(
f
"Unknown version:
{
version
}
"
)
...
@@ -284,7 +289,7 @@ def ref_multi_query_kv_attention(
...
@@ -284,7 +289,7 @@ def ref_multi_query_kv_attention(
dtype
:
torch
.
dtype
,
dtype
:
torch
.
dtype
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
num_seqs
=
len
(
cu_seq_lens
)
-
1
num_seqs
=
len
(
cu_seq_lens
)
-
1
ref_outputs
=
[]
ref_outputs
:
List
[
torch
.
Tensor
]
=
[]
for
i
in
range
(
num_seqs
):
for
i
in
range
(
num_seqs
):
start_idx
=
cu_seq_lens
[
i
]
start_idx
=
cu_seq_lens
[
i
]
end_idx
=
cu_seq_lens
[
i
+
1
]
end_idx
=
cu_seq_lens
[
i
+
1
]
...
@@ -304,8 +309,8 @@ def ref_multi_query_kv_attention(
...
@@ -304,8 +309,8 @@ def ref_multi_query_kv_attention(
attn_mask
=
attn_mask
,
attn_mask
=
attn_mask
,
)
)
ref_outputs
.
append
(
ref_output
)
ref_outputs
.
append
(
ref_output
)
ref_output
=
torch
.
cat
(
ref_outputs
,
dim
=
0
)
return
ref_output
return
torch
.
cat
(
ref_output
s
,
dim
=
0
)
# TODO(woosuk): Add tests for USE_ALIBI=True.
# TODO(woosuk): Add tests for USE_ALIBI=True.
...
...
tests/kernels/test_attention_selector.py
View file @
e7c1b7f3
...
@@ -9,8 +9,8 @@ from vllm.attention.selector import which_attn_to_use
...
@@ -9,8 +9,8 @@ from vllm.attention.selector import which_attn_to_use
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
"name"
,
[
"TORCH_SDPA"
,
"ROCM_FLASH"
,
"XFORMERS"
,
"FLASHINFER"
])
"name"
,
[
"TORCH_SDPA"
,
"ROCM_FLASH"
,
"XFORMERS"
,
"FLASHINFER"
,
"OPENVINO"
])
@
pytest
.
mark
.
parametrize
(
"device"
,
[
"cpu"
,
"hip"
,
"cuda"
])
@
pytest
.
mark
.
parametrize
(
"device"
,
[
"cpu"
,
"openvino"
,
"hip"
,
"cuda"
])
def
test_env
(
name
:
str
,
device
:
str
,
monkeypatch
):
def
test_env
(
name
:
str
,
device
:
str
,
monkeypatch
):
"""Test that the attention selector can be set via environment variable.
"""Test that the attention selector can be set via environment variable.
Note that we do not test FlashAttn because it is the default backend.
Note that we do not test FlashAttn because it is the default backend.
...
@@ -28,6 +28,11 @@ def test_env(name: str, device: str, monkeypatch):
...
@@ -28,6 +28,11 @@ def test_env(name: str, device: str, monkeypatch):
backend
=
which_attn_to_use
(
8
,
16
,
8
,
None
,
torch
.
float16
,
backend
=
which_attn_to_use
(
8
,
16
,
8
,
None
,
torch
.
float16
,
torch
.
float16
,
16
)
torch
.
float16
,
16
)
assert
backend
.
name
==
"ROCM_FLASH"
assert
backend
.
name
==
"ROCM_FLASH"
elif
device
==
"openvino"
:
with
patch
(
"vllm.attention.selector.is_openvino"
,
return_value
=
True
):
backend
=
which_attn_to_use
(
8
,
16
,
8
,
None
,
torch
.
float16
,
torch
.
float16
,
16
)
assert
backend
.
name
==
"OPENVINO"
else
:
else
:
backend
=
which_attn_to_use
(
8
,
16
,
8
,
None
,
torch
.
float16
,
backend
=
which_attn_to_use
(
8
,
16
,
8
,
None
,
torch
.
float16
,
torch
.
float16
,
16
)
torch
.
float16
,
16
)
...
@@ -42,36 +47,36 @@ def test_flash_attn(monkeypatch):
...
@@ -42,36 +47,36 @@ def test_flash_attn(monkeypatch):
# Unsupported CUDA arch
# Unsupported CUDA arch
with
patch
(
"torch.cuda.get_device_capability"
,
return_value
=
[
7
,
5
]):
with
patch
(
"torch.cuda.get_device_capability"
,
return_value
=
[
7
,
5
]):
backend
=
which_attn_to_use
(
8
,
16
,
8
,
None
,
torch
.
float16
,
None
,
16
)
backend
=
which_attn_to_use
(
8
,
16
,
8
,
None
,
torch
.
float16
,
None
,
16
)
assert
backend
.
name
!=
"
FLASH_ATTN
"
assert
backend
.
name
!=
STR_
FLASH_ATTN
_VAL
# Unsupported data type
# Unsupported data type
backend
=
which_attn_to_use
(
8
,
16
,
8
,
None
,
torch
.
float8_e4m3fn
,
None
,
16
)
backend
=
which_attn_to_use
(
8
,
16
,
8
,
None
,
torch
.
float8_e4m3fn
,
None
,
16
)
assert
backend
.
name
!=
"
FLASH_ATTN
"
assert
backend
.
name
!=
STR_
FLASH_ATTN
_VAL
# Unsupported kv cache data type
# Unsupported kv cache data type
backend
=
which_attn_to_use
(
8
,
16
,
8
,
None
,
torch
.
float16
,
"fp8"
,
16
)
backend
=
which_attn_to_use
(
8
,
16
,
8
,
None
,
torch
.
float16
,
"fp8"
,
16
)
assert
backend
.
name
!=
"
FLASH_ATTN
"
assert
backend
.
name
!=
STR_
FLASH_ATTN
_VAL
# Unsupported block size
# Unsupported block size
backend
=
which_attn_to_use
(
8
,
16
,
8
,
None
,
torch
.
float16
,
None
,
8
)
backend
=
which_attn_to_use
(
8
,
16
,
8
,
None
,
torch
.
float16
,
None
,
8
)
assert
backend
.
name
!=
"
FLASH_ATTN
"
assert
backend
.
name
!=
STR_
FLASH_ATTN
_VAL
# Unsupported sliding window
# Unsupported sliding window
backend
=
which_attn_to_use
(
8
,
16
,
8
,
1
,
torch
.
float16
,
None
,
16
)
backend
=
which_attn_to_use
(
8
,
16
,
8
,
1
,
torch
.
float16
,
None
,
16
)
assert
backend
.
name
!=
"
FLASH_ATTN
"
assert
backend
.
name
!=
STR_
FLASH_ATTN
_VAL
# flash-attn is not installed
# flash-attn is not installed
with
patch
.
dict
(
'sys.modules'
,
{
'vllm_flash_attn'
:
None
}):
with
patch
.
dict
(
'sys.modules'
,
{
'vllm_flash_attn'
:
None
}):
backend
=
which_attn_to_use
(
8
,
16
,
8
,
None
,
torch
.
float16
,
None
,
16
)
backend
=
which_attn_to_use
(
8
,
16
,
8
,
None
,
torch
.
float16
,
None
,
16
)
assert
backend
.
name
!=
"
FLASH_ATTN
"
assert
backend
.
name
!=
STR_
FLASH_ATTN
_VAL
# Unsupported head size
# Unsupported head size
backend
=
which_attn_to_use
(
8
,
17
,
8
,
None
,
torch
.
float16
,
None
,
16
)
backend
=
which_attn_to_use
(
8
,
17
,
8
,
None
,
torch
.
float16
,
None
,
16
)
assert
backend
.
name
!=
"
FLASH_ATTN
"
assert
backend
.
name
!=
STR_
FLASH_ATTN
_VAL
def
test_invalid_env
(
monkeypatch
):
def
test_invalid_env
(
monkeypatch
):
"""Throw an exception if the backend name is invalid."""
"""Throw an exception if the backend name is invalid."""
override_backend_env_variable
(
monkeypatch
,
STR_INVALID_VAL
)
override_backend_env_variable
(
monkeypatch
,
STR_INVALID_VAL
)
with
pytest
.
raises
(
ValueError
):
with
pytest
.
raises
(
ValueError
):
which_attn_to_use
(
8
,
16
,
8
,
None
,
torch
.
float16
,
None
,
16
)
which_attn_to_use
(
8
,
16
,
8
,
None
,
torch
.
float16
,
None
,
16
)
\ No newline at end of file
tests/kernels/test_blocksparse_attention.py
View file @
e7c1b7f3
...
@@ -77,27 +77,27 @@ def ref_single_query_cached_kv_attention(
...
@@ -77,27 +77,27 @@ def ref_single_query_cached_kv_attention(
block_size
=
value_cache
.
shape
[
3
]
block_size
=
value_cache
.
shape
[
3
]
num_seqs
=
query
.
shape
[
0
]
num_seqs
=
query
.
shape
[
0
]
block_tables
=
block_tables
.
cpu
().
tolist
()
block_tables
_lst
=
block_tables
.
cpu
().
tolist
()
seq_lens
=
seq_lens
.
cpu
().
tolist
()
seq_lens
_lst
=
seq_lens
.
cpu
().
tolist
()
for
i
in
range
(
num_seqs
):
for
i
in
range
(
num_seqs
):
q
=
query
[
i
].
unsqueeze
(
0
)
q
=
query
[
i
].
unsqueeze
(
0
)
block_table
=
block_tables
[
i
]
block_table
=
block_tables
_lst
[
i
]
seq_len
=
int
(
seq_lens
[
i
])
seq_len
=
int
(
seq_lens
_lst
[
i
])
keys
=
[]
keys
_lst
:
List
[
torch
.
Tensor
]
=
[]
values
=
[]
values
_lst
:
List
[
torch
.
Tensor
]
=
[]
for
j
in
range
(
seq_len
):
for
j
in
range
(
seq_len
):
block_number
=
int
(
block_table
[
j
//
block_size
])
block_number
=
int
(
block_table
[
j
//
block_size
])
block_offset
=
j
%
block_size
block_offset
=
j
%
block_size
k
=
key_cache
[
block_number
,
:,
:,
block_offset
,
:]
k
=
key_cache
[
block_number
,
:,
:,
block_offset
,
:]
k
=
k
.
reshape
(
num_kv_heads
,
head_size
)
k
=
k
.
reshape
(
num_kv_heads
,
head_size
)
keys
.
append
(
k
)
keys
_lst
.
append
(
k
)
v
=
value_cache
[
block_number
,
:,
:,
block_offset
]
v
=
value_cache
[
block_number
,
:,
:,
block_offset
]
values
.
append
(
v
)
values
_lst
.
append
(
v
)
keys
=
torch
.
stack
(
keys
,
dim
=
0
)
keys
=
torch
.
stack
(
keys
_lst
,
dim
=
0
)
values
=
torch
.
stack
(
values
,
dim
=
0
)
values
=
torch
.
stack
(
values
_lst
,
dim
=
0
)
if
num_queries_per_kv
>
1
:
if
num_queries_per_kv
>
1
:
# Handle MQA and GQA
# Handle MQA and GQA
keys
=
torch
.
repeat_interleave
(
keys
,
num_queries_per_kv
,
dim
=
1
)
keys
=
torch
.
repeat_interleave
(
keys
,
num_queries_per_kv
,
dim
=
1
)
...
@@ -212,7 +212,7 @@ def test_paged_attention(
...
@@ -212,7 +212,7 @@ def test_paged_attention(
key_cache
,
value_cache
=
key_caches
[
0
],
value_caches
[
0
]
key_cache
,
value_cache
=
key_caches
[
0
],
value_caches
[
0
]
# Using default kv_scale
# Using default kv_scale
kv_scale
=
1.0
k
_scale
=
v_scale
=
1.0
tp_rank
=
0
tp_rank
=
0
# Call the paged attention kernel.
# Call the paged attention kernel.
...
@@ -231,7 +231,8 @@ def test_paged_attention(
...
@@ -231,7 +231,8 @@ def test_paged_attention(
max_seq_len
,
max_seq_len
,
alibi_slopes
,
alibi_slopes
,
kv_cache_dtype
,
kv_cache_dtype
,
kv_scale
,
k_scale
,
v_scale
,
tp_rank
=
tp_rank
,
tp_rank
=
tp_rank
,
blocksparse_local_blocks
=
blocksparse_local_blocks
,
blocksparse_local_blocks
=
blocksparse_local_blocks
,
blocksparse_vert_stride
=
blocksparse_vert_stride
,
blocksparse_vert_stride
=
blocksparse_vert_stride
,
...
@@ -267,7 +268,8 @@ def test_paged_attention(
...
@@ -267,7 +268,8 @@ def test_paged_attention(
max_seq_len
,
max_seq_len
,
alibi_slopes
,
alibi_slopes
,
kv_cache_dtype
,
kv_cache_dtype
,
kv_scale
,
k_scale
,
v_scale
,
tp_rank
=
tp_rank
,
tp_rank
=
tp_rank
,
blocksparse_local_blocks
=
blocksparse_local_blocks
,
blocksparse_local_blocks
=
blocksparse_local_blocks
,
blocksparse_vert_stride
=
blocksparse_vert_stride
,
blocksparse_vert_stride
=
blocksparse_vert_stride
,
...
@@ -432,7 +434,7 @@ def test_varlen_blocksparse_attention_prefill(
...
@@ -432,7 +434,7 @@ def test_varlen_blocksparse_attention_prefill(
value
=
torch
.
repeat_interleave
(
value
,
num_queries_per_kv
,
dim
=
1
)
value
=
torch
.
repeat_interleave
(
value
,
num_queries_per_kv
,
dim
=
1
)
ref_output
=
ref_multi_query_kv_attention
(
ref_output
=
ref_multi_query_kv_attention
(
cu_seq_lens
,
cu_seq_lens
.
tolist
()
,
query
,
query
,
key
,
key
,
value
,
value
,
...
...
tests/kernels/test_cache.py
View file @
e7c1b7f3
import
random
import
random
from
typing
import
Tuple
from
typing
import
List
,
Tuple
import
pytest
import
pytest
import
torch
import
torch
...
@@ -11,7 +11,7 @@ DTYPES = [torch.half, torch.bfloat16, torch.float]
...
@@ -11,7 +11,7 @@ DTYPES = [torch.half, torch.bfloat16, torch.float]
NUM_TOKENS
=
[
42
]
# Arbitrary values for testing
NUM_TOKENS
=
[
42
]
# Arbitrary values for testing
NUM_LAYERS
=
[
1
]
# Arbitrary values for testing
NUM_LAYERS
=
[
1
]
# Arbitrary values for testing
NUM_HEADS
=
[
8
]
# Arbitrary values for testing
NUM_HEADS
=
[
8
]
# Arbitrary values for testing
HEAD_SIZES
=
[
64
,
80
,
96
,
112
,
128
,
192
,
256
]
HEAD_SIZES
=
[
64
,
80
,
96
,
112
,
120
,
128
,
192
,
256
]
BLOCK_SIZES
=
[
8
,
16
,
32
]
BLOCK_SIZES
=
[
8
,
16
,
32
]
# Arbitrary values for testing
# Arbitrary values for testing
...
@@ -53,6 +53,8 @@ def test_copy_blocks(
...
@@ -53,6 +53,8 @@ def test_copy_blocks(
kv_cache_dtype
:
str
,
kv_cache_dtype
:
str
,
device
:
str
,
device
:
str
,
)
->
None
:
)
->
None
:
if
kv_cache_dtype
==
"fp8"
and
head_size
%
16
:
pytest
.
skip
()
random
.
seed
(
seed
)
random
.
seed
(
seed
)
torch
.
random
.
manual_seed
(
seed
)
torch
.
random
.
manual_seed
(
seed
)
if
torch
.
cuda
.
is_available
():
if
torch
.
cuda
.
is_available
():
...
@@ -64,7 +66,7 @@ def test_copy_blocks(
...
@@ -64,7 +66,7 @@ def test_copy_blocks(
src_blocks
=
random
.
sample
(
range
(
num_blocks
),
num_mappings
)
src_blocks
=
random
.
sample
(
range
(
num_blocks
),
num_mappings
)
remainig_blocks
=
list
(
set
(
range
(
num_blocks
))
-
set
(
src_blocks
))
remainig_blocks
=
list
(
set
(
range
(
num_blocks
))
-
set
(
src_blocks
))
dst_blocks
=
random
.
sample
(
remainig_blocks
,
2
*
num_mappings
)
dst_blocks
=
random
.
sample
(
remainig_blocks
,
2
*
num_mappings
)
block_mapping
=
[]
block_mapping
:
List
[
Tuple
[
int
,
int
]]
=
[]
for
i
in
range
(
num_mappings
):
for
i
in
range
(
num_mappings
):
src
=
src_blocks
[
i
]
src
=
src_blocks
[
i
]
dst1
=
dst_blocks
[
2
*
i
]
dst1
=
dst_blocks
[
2
*
i
]
...
@@ -125,6 +127,8 @@ def test_reshape_and_cache(
...
@@ -125,6 +127,8 @@ def test_reshape_and_cache(
device
:
str
,
device
:
str
,
kv_cache_dtype
:
str
,
kv_cache_dtype
:
str
,
)
->
None
:
)
->
None
:
if
kv_cache_dtype
==
"fp8"
and
head_size
%
16
:
pytest
.
skip
()
random
.
seed
(
seed
)
random
.
seed
(
seed
)
torch
.
random
.
manual_seed
(
seed
)
torch
.
random
.
manual_seed
(
seed
)
if
torch
.
cuda
.
is_available
():
if
torch
.
cuda
.
is_available
():
...
@@ -132,8 +136,8 @@ def test_reshape_and_cache(
...
@@ -132,8 +136,8 @@ def test_reshape_and_cache(
torch
.
set_default_device
(
device
)
torch
.
set_default_device
(
device
)
# Create a random slot mapping.
# Create a random slot mapping.
num_slots
=
block_size
*
num_blocks
num_slots
=
block_size
*
num_blocks
slot_mapping
=
random
.
sample
(
range
(
num_slots
),
num_tokens
)
slot_mapping
_lst
=
random
.
sample
(
range
(
num_slots
),
num_tokens
)
slot_mapping
=
torch
.
tensor
(
slot_mapping
,
dtype
=
torch
.
long
)
slot_mapping
=
torch
.
tensor
(
slot_mapping
_lst
,
dtype
=
torch
.
long
)
qkv
=
torch
.
randn
(
num_tokens
,
3
,
num_heads
,
head_size
,
dtype
=
dtype
)
qkv
=
torch
.
randn
(
num_tokens
,
3
,
num_heads
,
head_size
,
dtype
=
dtype
)
_
,
key
,
value
=
qkv
.
unbind
(
dim
=
1
)
_
,
key
,
value
=
qkv
.
unbind
(
dim
=
1
)
...
@@ -156,11 +160,11 @@ def test_reshape_and_cache(
...
@@ -156,11 +160,11 @@ def test_reshape_and_cache(
cloned_value_cache
=
value_cache
.
clone
()
cloned_value_cache
=
value_cache
.
clone
()
# Using default kv_scale
# Using default kv_scale
kv_scale
=
1.0
k
_scale
=
v_scale
=
1.0
# Call the reshape_and_cache kernel.
# Call the reshape_and_cache kernel.
ops
.
reshape_and_cache
(
key
,
value
,
key_cache
,
value_cache
,
slot_mapping
,
ops
.
reshape_and_cache
(
key
,
value
,
key_cache
,
value_cache
,
slot_mapping
,
kv_cache_dtype
,
kv_scale
)
kv_cache_dtype
,
k
_scale
,
v_scale
)
if
kv_cache_dtype
==
"fp8"
:
if
kv_cache_dtype
==
"fp8"
:
result_key_cache
=
torch
.
empty_like
(
key_cache
,
dtype
=
torch
.
float16
)
result_key_cache
=
torch
.
empty_like
(
key_cache
,
dtype
=
torch
.
float16
)
...
@@ -171,12 +175,12 @@ def test_reshape_and_cache(
...
@@ -171,12 +175,12 @@ def test_reshape_and_cache(
# Run the reference implementation.
# Run the reference implementation.
reshaped_key
=
key
.
reshape
(
num_tokens
,
*
key_cache
[
0
,
:,
:,
0
,
:].
shape
)
reshaped_key
=
key
.
reshape
(
num_tokens
,
*
key_cache
[
0
,
:,
:,
0
,
:].
shape
)
block_indicies
=
torch
.
div
(
slot_mapping
,
block_size
,
rounding_mode
=
"floor"
)
block_indicies
=
torch
.
div
(
slot_mapping
,
block_size
,
rounding_mode
=
"floor"
)
block_indicies
=
block_indicies
.
cpu
().
tolist
()
block_indicies
_lst
=
block_indicies
.
cpu
().
tolist
()
block_offsets
=
slot_mapping
%
block_size
block_offsets
=
slot_mapping
%
block_size
block_offsets
=
block_offsets
.
cpu
().
tolist
()
block_offsets
_lst
=
block_offsets
.
cpu
().
tolist
()
for
i
in
range
(
num_tokens
):
for
i
in
range
(
num_tokens
):
block_idx
=
block_indicies
[
i
]
block_idx
=
block_indicies
_lst
[
i
]
block_offset
=
block_offsets
[
i
]
block_offset
=
block_offsets
_lst
[
i
]
cloned_key_cache
[
block_idx
,
:,
:,
block_offset
,
:]
=
reshaped_key
[
i
]
cloned_key_cache
[
block_idx
,
:,
:,
block_offset
,
:]
=
reshaped_key
[
i
]
cloned_value_cache
[
block_idx
,
:,
:,
block_offset
]
=
value
[
i
]
cloned_value_cache
[
block_idx
,
:,
:,
block_offset
]
=
value
[
i
]
...
@@ -216,8 +220,6 @@ def test_reshape_and_cache_flash(
...
@@ -216,8 +220,6 @@ def test_reshape_and_cache_flash(
device
:
str
,
device
:
str
,
kv_cache_dtype
:
str
,
kv_cache_dtype
:
str
,
)
->
None
:
)
->
None
:
if
kv_cache_dtype
==
"fp8"
:
pytest
.
skip
()
random
.
seed
(
seed
)
random
.
seed
(
seed
)
torch
.
random
.
manual_seed
(
seed
)
torch
.
random
.
manual_seed
(
seed
)
torch
.
cuda
.
manual_seed
(
seed
)
torch
.
cuda
.
manual_seed
(
seed
)
...
@@ -225,8 +227,10 @@ def test_reshape_and_cache_flash(
...
@@ -225,8 +227,10 @@ def test_reshape_and_cache_flash(
# Create a random slot mapping.
# Create a random slot mapping.
num_slots
=
block_size
*
num_blocks
num_slots
=
block_size
*
num_blocks
slot_mapping
=
random
.
sample
(
range
(
num_slots
),
num_tokens
)
slot_mapping_lst
=
random
.
sample
(
range
(
num_slots
),
num_tokens
)
slot_mapping
=
torch
.
tensor
(
slot_mapping
,
dtype
=
torch
.
long
,
device
=
device
)
slot_mapping
=
torch
.
tensor
(
slot_mapping_lst
,
dtype
=
torch
.
long
,
device
=
device
)
qkv
=
torch
.
randn
(
num_tokens
,
qkv
=
torch
.
randn
(
num_tokens
,
3
,
3
,
...
@@ -247,29 +251,57 @@ def test_reshape_and_cache_flash(
...
@@ -247,29 +251,57 @@ def test_reshape_and_cache_flash(
dtype
,
dtype
,
device
=
device
,
device
=
device
,
)
)
key_cache
,
value_cache
=
key_caches
[
0
],
value_caches
[
0
]
key_cache
,
value_cache
=
key_caches
[
0
].
contiguous
(
),
value_caches
[
0
].
contiguous
()
del
key_caches
del
value_caches
# Clone the KV caches.
# Clone the KV caches.
cloned_key_cache
=
key_cache
.
clone
()
if
kv_cache_dtype
==
"fp8"
:
cloned_value_cache
=
value_cache
.
clone
()
cloned_key_cache
=
torch
.
empty_like
(
key_cache
,
dtype
=
torch
.
float16
)
ops
.
convert_fp8
(
cloned_key_cache
,
key_cache
)
cloned_value_cache
=
torch
.
empty_like
(
value_cache
,
dtype
=
torch
.
float16
)
ops
.
convert_fp8
(
cloned_value_cache
,
value_cache
)
else
:
cloned_key_cache
=
key_cache
.
clone
()
cloned_value_cache
=
value_cache
.
clone
()
# Using default kv_scale
k_scale
=
v_scale
=
1.0
# Call the reshape_and_cache kernel.
# Call the reshape_and_cache kernel.
ops
.
reshape_and_cache_flash
(
key
,
value
,
key_cache
,
value_cache
,
ops
.
reshape_and_cache_flash
(
key
,
value
,
key_cache
,
value_cache
,
slot_mapping
,
kv_cache_dtype
)
slot_mapping
,
kv_cache_dtype
,
k_scale
,
v_scale
)
if
kv_cache_dtype
==
"fp8"
:
result_key_cache
=
torch
.
empty_like
(
key_cache
,
dtype
=
torch
.
float16
)
ops
.
convert_fp8
(
result_key_cache
,
key_cache
)
result_value_cache
=
torch
.
empty_like
(
value_cache
,
dtype
=
torch
.
float16
)
ops
.
convert_fp8
(
result_value_cache
,
value_cache
)
# Run the reference implementation.
# Run the reference implementation.
block_indicies
=
torch
.
div
(
slot_mapping
,
block_size
,
rounding_mode
=
'
floor
'
)
block_indicies
=
torch
.
div
(
slot_mapping
,
block_size
,
rounding_mode
=
"
floor
"
)
block_indicies
=
block_indicies
.
cpu
().
tolist
()
block_indicies
_lst
=
block_indicies
.
cpu
().
tolist
()
block_offsets
=
slot_mapping
%
block_size
block_offsets
=
slot_mapping
%
block_size
block_offsets
=
block_offsets
.
cpu
().
tolist
()
block_offsets
_lst
=
block_offsets
.
cpu
().
tolist
()
for
i
in
range
(
num_tokens
):
for
i
in
range
(
num_tokens
):
block_idx
=
block_indicies
[
i
]
block_idx
=
block_indicies
_lst
[
i
]
block_offset
=
block_offsets
[
i
]
block_offset
=
block_offsets
_lst
[
i
]
cloned_key_cache
[
block_idx
,
block_offset
,
:,
:]
=
key
[
i
]
cloned_key_cache
[
block_idx
,
block_offset
,
:,
:]
=
key
[
i
]
cloned_value_cache
[
block_idx
,
block_offset
,
:,
:]
=
value
[
i
]
cloned_value_cache
[
block_idx
,
block_offset
,
:,
:]
=
value
[
i
]
assert
torch
.
allclose
(
key_cache
,
cloned_key_cache
)
if
kv_cache_dtype
==
"fp8"
:
assert
torch
.
allclose
(
value_cache
,
cloned_value_cache
)
assert
torch
.
allclose
(
result_key_cache
,
cloned_key_cache
,
atol
=
0.001
,
rtol
=
0.1
)
assert
torch
.
allclose
(
result_value_cache
,
cloned_value_cache
,
atol
=
0.001
,
rtol
=
0.1
)
else
:
assert
torch
.
allclose
(
key_cache
,
cloned_key_cache
)
assert
torch
.
allclose
(
value_cache
,
cloned_value_cache
)
@
pytest
.
mark
.
parametrize
(
"direction"
,
COPYING_DIRECTION
)
@
pytest
.
mark
.
parametrize
(
"direction"
,
COPYING_DIRECTION
)
...
@@ -298,6 +330,8 @@ def test_swap_blocks(
...
@@ -298,6 +330,8 @@ def test_swap_blocks(
)
->
None
:
)
->
None
:
if
kv_cache_dtype
==
"fp8"
and
"cpu"
in
direction
:
if
kv_cache_dtype
==
"fp8"
and
"cpu"
in
direction
:
pytest
.
skip
()
pytest
.
skip
()
if
kv_cache_dtype
==
"fp8"
and
head_size
%
16
:
pytest
.
skip
()
random
.
seed
(
seed
)
random
.
seed
(
seed
)
torch
.
random
.
manual_seed
(
seed
)
torch
.
random
.
manual_seed
(
seed
)
if
torch
.
cuda
.
is_available
():
if
torch
.
cuda
.
is_available
():
...
...
tests/kernels/test_cutlass.py
View file @
e7c1b7f3
...
@@ -2,36 +2,53 @@
...
@@ -2,36 +2,53 @@
Run `pytest tests/kernels/test_cutlass.py`.
Run `pytest tests/kernels/test_cutlass.py`.
"""
"""
from
typing
import
Type
from
typing
import
Optional
,
Type
import
pytest
import
pytest
import
torch
import
torch
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
from
vllm.platforms
import
current_platform
CUDA_DEVICES
=
[
CUDA_DEVICES
=
[
f
"cuda:
{
i
}
"
for
i
in
range
(
1
if
torch
.
cuda
.
device_count
()
==
1
else
2
)
f
"cuda:
{
i
}
"
for
i
in
range
(
1
if
torch
.
cuda
.
device_count
()
==
1
else
2
)
]
]
capability
=
torch
.
cuda
.
get_device_capability
()
capability
=
current_platform
.
get_device_capability
()
capability
=
capability
[
0
]
*
10
+
capability
[
1
]
capability
=
capability
[
0
]
*
10
+
capability
[
1
]
def
to_fp8
(
tensor
:
torch
.
t
ensor
):
def
to_fp8
(
tensor
:
torch
.
T
ensor
):
finfo
=
torch
.
finfo
(
torch
.
float8_e4m3fn
)
finfo
=
torch
.
finfo
(
torch
.
float8_e4m3fn
)
return
torch
.
round
(
tensor
.
clamp
(
return
torch
.
round
(
tensor
.
clamp
(
min
=
finfo
.
min
,
max
=
finfo
.
max
)).
to
(
dtype
=
torch
.
float8_e4m3fn
)
min
=
finfo
.
min
,
max
=
finfo
.
max
)).
to
(
dtype
=
torch
.
float8_e4m3fn
)
def
to_int8
(
tensor
:
torch
.
t
ensor
):
def
to_int8
(
tensor
:
torch
.
T
ensor
):
return
torch
.
round
(
tensor
.
clamp
(
min
=-
128
,
max
=
127
)).
to
(
dtype
=
torch
.
int8
)
return
torch
.
round
(
tensor
.
clamp
(
min
=-
128
,
max
=
127
)).
to
(
dtype
=
torch
.
int8
)
def baseline_scaled_mm(a: torch.Tensor,
                       b: torch.Tensor,
                       scale_a: torch.Tensor,
                       scale_b: torch.Tensor,
                       out_dtype: Type[torch.dtype],
                       bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Reference scaled matrix multiply built from plain torch ops.

    Both operands are converted to float32 and multiplied with torch.mm;
    the product is multiplied by scale_b, then by scale_a, and the result
    is converted to out_dtype. When a bias is given it is added after that
    conversion.

    Args:
        a: Left-hand matrix.
        b: Right-hand matrix.
        scale_a: Scale multiplied into the product last.
        scale_b: Scale multiplied into the product first.
        out_dtype: Dtype the scaled product is converted to.
        bias: Optional tensor added to the converted result.

    Returns:
        The scaled (and optionally biased) product.
    """
    lhs = a.to(dtype=torch.float32)
    rhs = b.to(dtype=torch.float32)
    product = torch.mm(lhs, rhs)

    # Keep the original association: scale_b is applied before scale_a.
    scaled = (scale_a * (scale_b * product)).to(out_dtype)

    if bias is None:
        return scaled
    return scaled + bias
def
cutlass_fp8_gemm_helper
(
m
:
int
,
def
cutlass_fp8_gemm_helper
(
m
:
int
,
n
:
int
,
n
:
int
,
k
:
int
,
k
:
int
,
per_token_act_quant
:
bool
,
per_token_act_quant
:
bool
,
per_out_channel_weight_quant
:
bool
,
per_out_channel_weight_quant
:
bool
,
use_bias
:
bool
,
out_dtype
:
Type
[
torch
.
dtype
]
=
torch
.
bfloat16
,
out_dtype
:
Type
[
torch
.
dtype
]
=
torch
.
bfloat16
,
device
:
str
=
"cuda"
):
device
:
str
=
"cuda"
):
# Test for a cutlass kernel with per-token activation quantization
# Test for a cutlass kernel with per-token activation quantization
...
@@ -42,16 +59,19 @@ def cutlass_fp8_gemm_helper(m: int,
...
@@ -42,16 +59,19 @@ def cutlass_fp8_gemm_helper(m: int,
m_a_scales
=
m
if
per_token_act_quant
else
1
m_a_scales
=
m
if
per_token_act_quant
else
1
n_b_scales
=
n
if
per_out_channel_weight_quant
else
1
n_b_scales
=
n
if
per_out_channel_weight_quant
else
1
scale_a
=
(
torch
.
randn
(
scale_a
=
(
torch
.
randn
((
m_a_scales
,
1
),
device
=
device
,
(
m_a_scales
,
1
),
device
=
device
,
dtype
=
torch
.
float32
)
/
10
)
dtype
=
torch
.
float32
))
scale_b
=
(
torch
.
randn
(
scale_b
=
(
torch
.
randn
((
1
,
n_b_scales
),
device
=
device
,
(
1
,
n_b_scales
),
device
=
device
,
dtype
=
torch
.
float32
)
/
10
)
dtype
=
torch
.
float32
))
if
use_bias
:
bias
=
torch
.
rand
((
n
,
),
device
=
device
,
dtype
=
out_dtype
)
*
10
else
:
bias
=
None
out
=
ops
.
cutlass_scaled_mm
(
a
,
b
,
scale_a
,
scale_b
,
out_dtype
)
out
=
ops
.
cutlass_scaled_mm
(
a
,
b
,
scale_a
,
scale_b
,
out_dtype
,
bias
)
baseline
=
torch
.
mm
(
scale_a
*
a
.
to
(
dtype
=
torch
.
float32
),
baseline
=
baseline_scaled_mm
(
a
,
b
,
scale_a
,
scale_b
,
out_dtype
,
bias
)
scale_b
*
b
.
to
(
dtype
=
torch
.
float32
)).
to
(
out_dtype
)
assert
torch
.
allclose
(
out
,
baseline
,
rtol
=
1e-2
,
atol
=
1
e-
1
)
assert
torch
.
allclose
(
out
,
baseline
,
rtol
=
1e-2
,
atol
=
5
e-
2
)
def
cutlass_int8_gemm_helper
(
m
:
int
,
def
cutlass_int8_gemm_helper
(
m
:
int
,
...
@@ -59,6 +79,7 @@ def cutlass_int8_gemm_helper(m: int,
...
@@ -59,6 +79,7 @@ def cutlass_int8_gemm_helper(m: int,
k
:
int
,
k
:
int
,
per_token_act_quant
:
bool
,
per_token_act_quant
:
bool
,
per_out_channel_weight_quant
:
bool
,
per_out_channel_weight_quant
:
bool
,
use_bias
:
bool
,
out_dtype
:
Type
[
torch
.
dtype
]
=
torch
.
bfloat16
,
out_dtype
:
Type
[
torch
.
dtype
]
=
torch
.
bfloat16
,
device
:
str
=
"cuda"
):
device
:
str
=
"cuda"
):
# Test for a cutlass kernel with per-token activation quantization
# Test for a cutlass kernel with per-token activation quantization
...
@@ -69,79 +90,106 @@ def cutlass_int8_gemm_helper(m: int,
...
@@ -69,79 +90,106 @@ def cutlass_int8_gemm_helper(m: int,
m_a_scales
=
m
if
per_token_act_quant
else
1
m_a_scales
=
m
if
per_token_act_quant
else
1
n_b_scales
=
n
if
per_out_channel_weight_quant
else
1
n_b_scales
=
n
if
per_out_channel_weight_quant
else
1
scale_a
=
(
torch
.
randn
(
scale_a
=
(
torch
.
randn
(
(
m_a_scales
,
1
),
device
=
device
,
(
m_a_scales
,
1
),
device
=
device
,
dtype
=
torch
.
float32
)
/
10
)
dtype
=
torch
.
float32
))
scale_b
=
(
torch
.
randn
(
scale_b
=
(
torch
.
randn
(
(
1
,
n_b_scales
),
device
=
device
,
(
1
,
n_b_scales
),
device
=
device
,
dtype
=
torch
.
float32
)
/
10
)
dtype
=
torch
.
float32
))
out
=
ops
.
cutlass_scaled_mm
(
a
,
b
,
scale_a
,
scale_b
,
out_dtype
)
if
use_bias
:
baseline
=
torch
.
mm
(
scale_a
*
a
.
to
(
dtype
=
torch
.
float32
),
bias
=
torch
.
rand
((
n
,
),
device
=
device
,
dtype
=
out_dtype
)
*
10
scale_b
*
else
:
b
.
to
(
dtype
=
torch
.
float32
)).
to
(
dtype
=
out_dtype
)
bias
=
None
out
=
ops
.
cutlass_scaled_mm
(
a
,
b
,
scale_a
,
scale_b
,
out_dtype
,
bias
)
baseline
=
baseline_scaled_mm
(
a
,
b
,
scale_a
,
scale_b
,
out_dtype
,
bias
)
assert
torch
.
allclose
(
out
,
baseline
,
rtol
=
1e-1
,
atol
=
1e0
)
assert
torch
.
allclose
(
out
,
baseline
,
rtol
=
1e-1
,
atol
=
1e0
)
@
pytest
.
mark
.
parametrize
(
"m"
,
[
512
,
222
,
100
,
33
,
1
])
@
pytest
.
mark
.
parametrize
(
"m"
,
[
1
,
16
,
32
,
64
,
128
,
256
,
512
,
222
,
100
,
33
])
@
pytest
.
mark
.
parametrize
(
"n"
,
[
2048
,
256
,
1024
])
@
pytest
.
mark
.
parametrize
(
"n"
,
[
2048
,
4096
,
8192
,
16384
,
24576
,
256
,
1024
])
@
pytest
.
mark
.
parametrize
(
"k"
,
[
128
,
496
,
1024
])
@
pytest
.
mark
.
parametrize
(
"k"
,
[
128
,
496
,
1024
])
@
pytest
.
mark
.
parametrize
(
"per_act_token"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"per_act_token"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"per_out_ch"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"per_out_ch"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"use_bias"
,
[
True
,
False
])
@
pytest
.
mark
.
skipif
(
capability
<
89
,
@
pytest
.
mark
.
skipif
(
capability
<
89
,
reason
=
"FP8 is not supported on this GPU type."
)
reason
=
"FP8 is not supported on this GPU type."
)
def
test_cutlass_fp8_gemm
(
m
:
int
,
n
:
int
,
k
:
int
,
per_act_token
:
bool
,
def
test_cutlass_fp8_gemm
(
m
:
int
,
n
:
int
,
k
:
int
,
per_act_token
:
bool
,
per_out_ch
:
bool
):
per_out_ch
:
bool
,
use_bias
:
bool
):
cutlass_fp8_gemm_helper
(
m
,
n
,
k
,
per_act_token
,
per_out_ch
)
cutlass_fp8_gemm_helper
(
m
,
n
,
k
,
per_act_token
,
per_out_ch
,
use_bias
)
@
pytest
.
mark
.
parametrize
(
"m"
,
[
512
,
222
,
33
,
1
])
@
pytest
.
mark
.
parametrize
(
"m"
,
[
1
,
16
,
32
,
64
,
128
,
256
,
512
,
222
,
33
,
1
])
@
pytest
.
mark
.
parametrize
(
"n"
,
[
2048
,
256
,
1024
])
@
pytest
.
mark
.
parametrize
(
"n"
,
[
2048
,
8192
,
16384
,
256
,
1024
])
@
pytest
.
mark
.
parametrize
(
"k"
,
[
128
,
496
,
1024
])
@
pytest
.
mark
.
parametrize
(
"k"
,
[
128
,
496
,
1024
])
@
pytest
.
mark
.
parametrize
(
"per_act_token"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"per_act_token"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"per_out_ch"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"per_out_ch"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"use_bias"
,
[
True
,
False
])
def
test_cutlass_int8_gemm
(
m
:
int
,
n
:
int
,
k
:
int
,
per_act_token
:
bool
,
def
test_cutlass_int8_gemm
(
m
:
int
,
n
:
int
,
k
:
int
,
per_act_token
:
bool
,
per_out_ch
:
bool
):
per_out_ch
:
bool
,
use_bias
:
bool
):
cutlass_int8_gemm_helper
(
m
,
n
,
k
,
per_act_token
,
per_out_ch
)
cutlass_int8_gemm_helper
(
m
,
n
,
k
,
per_act_token
,
per_out_ch
,
use_bias
)
@
pytest
.
mark
.
parametrize
(
"per_act_token"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"per_act_token"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"per_out_ch"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"per_out_ch"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"out_dtype"
,
[
torch
.
bfloat16
,
torch
.
float16
])
@
pytest
.
mark
.
parametrize
(
"out_dtype"
,
[
torch
.
bfloat16
,
torch
.
float16
])
@
pytest
.
mark
.
parametrize
(
"use_bias"
,
[
True
,
False
])
def
test_cutlass_int8_gemm_output_dtype
(
per_act_token
:
bool
,
per_out_ch
:
bool
,
def
test_cutlass_int8_gemm_output_dtype
(
per_act_token
:
bool
,
per_out_ch
:
bool
,
out_dtype
:
Type
[
torch
.
dtype
]):
out_dtype
:
Type
[
torch
.
dtype
],
cutlass_int8_gemm_helper
(
512
,
512
,
512
,
per_act_token
,
per_out_ch
,
use_bias
:
bool
):
out_dtype
)
cutlass_int8_gemm_helper
(
512
,
512
,
512
,
per_act_token
,
per_out_ch
,
use_bias
,
out_dtype
=
out_dtype
)
@
pytest
.
mark
.
parametrize
(
"per_act_token"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"per_act_token"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"per_out_ch"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"per_out_ch"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"out_dtype"
,
[
torch
.
bfloat16
,
torch
.
float16
])
@
pytest
.
mark
.
parametrize
(
"out_dtype"
,
[
torch
.
bfloat16
,
torch
.
float16
])
@
pytest
.
mark
.
parametrize
(
"use_bias"
,
[
True
,
False
])
@
pytest
.
mark
.
skipif
(
capability
<
89
,
@
pytest
.
mark
.
skipif
(
capability
<
89
,
reason
=
"FP8 is not supported on this GPU type."
)
reason
=
"FP8 is not supported on this GPU type."
)
def
test_cutlass_fp8_gemm_output_dtype
(
per_act_token
:
bool
,
per_out_ch
:
bool
,
def
test_cutlass_fp8_gemm_output_dtype
(
per_act_token
:
bool
,
per_out_ch
:
bool
,
out_dtype
:
Type
[
torch
.
dtype
]):
out_dtype
:
Type
[
torch
.
dtype
],
cutlass_fp8_gemm_helper
(
512
,
512
,
512
,
per_act_token
,
per_out_ch
,
use_bias
:
bool
):
out_dtype
)
cutlass_fp8_gemm_helper
(
512
,
512
,
512
,
per_act_token
,
per_out_ch
,
use_bias
,
out_dtype
=
out_dtype
)
@
pytest
.
mark
.
parametrize
(
"per_act_token"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"per_act_token"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"per_out_ch"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"per_out_ch"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"use_bias"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
pytest
.
mark
.
skipif
(
capability
<
89
,
@
pytest
.
mark
.
skipif
(
capability
<
89
,
reason
=
"FP8 is not supported on this GPU type."
)
reason
=
"FP8 is not supported on this GPU type."
)
def
test_cutlass_fp8_gemm_devices
(
per_act_token
:
bool
,
per_out_ch
:
bool
,
def
test_cutlass_fp8_gemm_devices
(
per_act_token
:
bool
,
per_out_ch
:
bool
,
device
:
str
):
use_bias
:
bool
,
device
:
str
):
cutlass_fp8_gemm_helper
(
512
,
512
,
512
,
per_act_token
,
per_out_ch
,
cutlass_fp8_gemm_helper
(
512
,
512
,
512
,
per_act_token
,
per_out_ch
,
use_bias
,
torch
.
bfloat16
,
device
)
torch
.
bfloat16
,
device
)
@
pytest
.
mark
.
parametrize
(
"per_act_token"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"per_act_token"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"per_out_ch"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"per_out_ch"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"use_bias"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
def
test_cutlass_int8_gemm_devices
(
per_act_token
:
bool
,
per_out_ch
:
bool
,
def
test_cutlass_int8_gemm_devices
(
per_act_token
:
bool
,
per_out_ch
:
bool
,
device
:
str
):
use_bias
:
bool
,
device
:
str
):
cutlass_int8_gemm_helper
(
512
,
512
,
512
,
per_act_token
,
per_out_ch
,
cutlass_int8_gemm_helper
(
512
,
torch
.
bfloat16
,
device
)
512
,
512
,
per_act_token
,
per_out_ch
,
use_bias
,
out_dtype
=
torch
.
bfloat16
,
device
=
device
)
# For the following two tests:
# For the following two tests:
...
@@ -151,20 +199,26 @@ def test_cutlass_int8_gemm_devices(per_act_token: bool, per_out_ch: bool,
...
@@ -151,20 +199,26 @@ def test_cutlass_int8_gemm_devices(per_act_token: bool, per_out_ch: bool,
# kernel must handle any M thrown at it.
# kernel must handle any M thrown at it.
@
pytest
.
mark
.
parametrize
(
"per_act_token"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"per_act_token"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"per_out_ch"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"per_out_ch"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"use_bias"
,
[
True
,
False
])
@
pytest
.
mark
.
skipif
(
capability
<
89
,
@
pytest
.
mark
.
skipif
(
capability
<
89
,
reason
=
"FP8 is not supported on this GPU type."
)
reason
=
"FP8 is not supported on this GPU type."
)
def
test_cutlass_fp8_gemm_m_sweep
(
per_act_token
:
bool
,
per_out_ch
:
bool
):
def
test_cutlass_fp8_gemm_m_sweep
(
per_act_token
:
bool
,
per_out_ch
:
bool
,
use_bias
:
bool
):
for
nk
in
range
(
32
,
128
,
32
):
for
nk
in
range
(
32
,
128
,
32
):
for
m
in
range
(
1
,
128
):
for
m
in
range
(
1
,
128
):
cutlass_fp8_gemm_helper
(
m
,
nk
,
nk
,
per_act_token
,
per_out_ch
)
cutlass_fp8_gemm_helper
(
m
,
nk
,
nk
,
per_act_token
,
per_out_ch
,
use_bias
)
@
pytest
.
mark
.
parametrize
(
"per_act_token"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"per_act_token"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"per_out_ch"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"per_out_ch"
,
[
True
,
False
])
def
test_cutlass_int8_gemm_m_sweep
(
per_act_token
:
bool
,
per_out_ch
:
bool
):
@
pytest
.
mark
.
parametrize
(
"use_bias"
,
[
True
,
False
])
def
test_cutlass_int8_gemm_m_sweep
(
per_act_token
:
bool
,
per_out_ch
:
bool
,
use_bias
:
bool
):
for
nk
in
range
(
32
,
128
,
32
):
for
nk
in
range
(
32
,
128
,
32
):
for
m
in
range
(
1
,
128
):
for
m
in
range
(
1
,
128
):
cutlass_int8_gemm_helper
(
m
,
nk
,
nk
,
per_act_token
,
per_out_ch
)
cutlass_int8_gemm_helper
(
m
,
nk
,
nk
,
per_act_token
,
per_out_ch
,
use_bias
)
# Test working with a subset of A and B
# Test working with a subset of A and B
...
@@ -185,9 +239,11 @@ def test_cutlass_subset():
...
@@ -185,9 +239,11 @@ def test_cutlass_subset():
scale_a
,
scale_a
,
scale_b
,
scale_b
,
out_dtype
=
torch
.
bfloat16
)
out_dtype
=
torch
.
bfloat16
)
baseline
=
torch
.
mm
(
scale_a
*
a
.
to
(
dtype
=
torch
.
float32
),
baseline
=
baseline_scaled_mm
(
a
,
scale_b
*
b
,
b
.
to
(
dtype
=
torch
.
float32
)).
to
(
dtype
=
torch
.
bfloat16
)
scale_a
,
scale_b
,
out_dtype
=
torch
.
bfloat16
)
assert
torch
.
allclose
(
out
,
baseline
,
rtol
=
1e-1
,
atol
=
1e0
)
assert
torch
.
allclose
(
out
,
baseline
,
rtol
=
1e-1
,
atol
=
1e0
)
...
...
tests/kernels/test_encoder_decoder_attn.py
0 → 100644
View file @
e7c1b7f3
"""
Tests:
* E2E test of Encoder attention + Decoder self-attention +
Encoder/decoder cross-attention (collectively
"encoder/decoder attention")
* Confirm enc/dec models will fail for chunked prefill
* Confirm enc/dec models will fail for prefix caching
"""
from
typing
import
NamedTuple
,
Optional
import
pytest
import
torch
from
tests.kernels.utils
import
*
from
tests.kernels.utils
import
make_causal_mask
,
maybe_make_long_tensor
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.attention.backends.abstract
import
AttentionBackend
,
AttentionType
from
vllm.attention.backends.utils
import
STR_NOT_IMPL_ENC_DEC_ROCM_HIP
from
vllm.utils
import
is_hip
HEAD_SIZES
=
[
64
,
256
]
NUM_HEADS
=
[
1
,
16
]
BATCH_SIZES
=
[
1
,
16
]
BLOCK_SIZES
=
[
16
]
BACKEND_NAMES
=
[
STR_XFORMERS_ATTN_VAL
]
CUDA_DEVICE
=
"cuda:0"
MAX_DEC_SEQ_LENS
=
[
128
]
MAX_ENC_SEQ_LENS
=
[
128
]
# Narrow test-cases for unsupported-scenario
# tests
HEAD_SIZES_FOR_UNSUPP
=
[
HEAD_SIZES
[
0
]]
class TestPoint(NamedTuple):
    """
    Parameter bundle describing one invocation of the
    test_e2e_enc_dec_attn() test.

    Field order is significant: consumers in this file unpack a
    TestPoint positionally.

    Attributes:
        num_heads: Number of attention heads in the model.
        head_size: Dimension of each head.
        backend_name: Name of the attention backend under test.
        batch_size: Number of samples in a batch.
        block_size: Size of each block of data processed.
        max_dec_seq_len: Upper bound on decoder sequence length.
        max_enc_seq_len: Upper bound on encoder sequence length.
        num_blocks: Number of blocks in the model.
    """

    num_heads: int
    head_size: int
    backend_name: str
    batch_size: int
    block_size: int
    max_dec_seq_len: int
    max_enc_seq_len: int
    num_blocks: int
class TestResources(NamedTuple):
    '''
    Encapsulates key components for performing an
    encoder/decoder attention test

    Note that
    (1) attn automatically selects an attention backend
        based on platform info & a set of canned
        heuristics
    (2) attn_backend is thus *not the same backend
        instance* used by attn, but rather it is
        intended to be a
        *different instance* of the *same backend class*;
        it is assumed that the user of TestResources
        will leverage attn_backend for the purpose of
        constructing backend-compatible attention
        metadata instances

    Attributes:

    * scale: 1/sqrt(d) scale factor for attn
    * attn_backend: implementation of abstraction
                    attention interface using
                    a particular kernel library
                    i.e. XFormers
    * attn: Attention layer instance
    * kv_cache: shared key/value cache for all attention;
                None when the test does not need a KV cache
    '''

    scale: float
    attn_backend: AttentionBackend
    attn: Attention
    # _make_test_resources() passes None here when no cache is required,
    # so the field must be Optional.
    kv_cache: Optional[torch.Tensor]
def _make_test_resources(test_pt: TestPoint) -> TestResources:
    '''
    Assemble the components needed for an encoder/decoder attention test.

    Notes:

    (1) The Attention layer built here picks its attention backend class
        automatically, based on platform info & a set of canned heuristics.
    (2) The backend returned alongside it is therefore *not the same
        backend instance* that the Attention layer uses; it is meant to be
        a *different instance* of the *same backend class*.
    (3) Consequently test_pt.backend_name must match the backend class
        that Attention selects automatically at construction time.

    Arguments:

    * test_pt: TestPoint data structure; the fields read here are
               num_heads, head_size, num_blocks, block_size and
               backend_name

    Returns:

    * TestResources data structure; its kv_cache is None when
      test_pt.num_blocks or test_pt.num_heads is None.
    '''

    # 1/sqrt(head_size) attention scale factor
    head_size_root = test_pt.head_size**0.5
    scale = float(1.0 / head_size_root)

    attn_backend = make_backend(test_pt.backend_name)
    attn = Attention(test_pt.num_heads, test_pt.head_size, scale=scale)

    # A KV cache is only built when both sizing fields are provided;
    # otherwise the caller does not require one.
    kv_cache = None
    if test_pt.num_blocks is not None and test_pt.num_heads is not None:
        kv_cache = make_kv_cache(test_pt.num_blocks,
                                 test_pt.num_heads,
                                 test_pt.head_size,
                                 test_pt.block_size,
                                 device=CUDA_DEVICE)

    return TestResources(scale, attn_backend, attn, kv_cache)
def _encoder_attn_setup(
    test_pt: TestPoint,
    test_rsrcs: TestResources,
) -> PhaseTestParameters:
    '''
    Prepare inputs and the expected output for the encoder attention test.

    Synthetic query/key/value tensors are generated with make_qkv() using
    the same maximum length for the query and key/value sequences. They are
    then run through the reference implementation (ref_masked_attention,
    without a custom mask) to obtain the ideal output. Both the inputs and
    the ideal output are packed before being returned.

    No KV-cache memory mapping is produced for the encoder.

    Arguments:

    * test_pt: TestPoint data structure; the fields read here are
               batch_size, num_heads, head_size and max_enc_seq_len
    * test_rsrcs: TestResources data structure; only the scale field is
                  read

    Returns:

    * PhaseTestParameters whose first field is PackedQKVO(packed q/k/v,
      packed ideal output) and whose second (KV memory map) field is None
    '''

    # Encoder attention: key/value length limit equals the query limit
    seq_len_cap = test_pt.max_enc_seq_len

    # Only the first of the three structures returned by make_qkv is needed
    enc_qkv, _, _ = make_qkv(test_pt.batch_size,
                             seq_len_cap,
                             seq_len_cap,
                             test_pt.num_heads,
                             test_pt.head_size,
                             attn_type=AttentionType.ENCODER,
                             device=CUDA_DEVICE)

    # Reference result from the naive implementation (no custom mask)
    expected = ref_masked_attention(enc_qkv.query,
                                    enc_qkv.key,
                                    enc_qkv.value,
                                    scale=test_rsrcs.scale,
                                    q_seq_lens=enc_qkv.q_seq_lens,
                                    kv_seq_lens=enc_qkv.kv_seq_lens)

    packed_expected, _ = pack_tensor(expected,
                                     enc_qkv.q_seq_lens,
                                     device=CUDA_DEVICE)
    packed_inputs = pack_qkv(enc_qkv, device=CUDA_DEVICE)

    qkvo = PackedQKVO(packed_inputs, packed_expected)
    no_kv_memory_map = None
    return PhaseTestParameters(qkvo, no_kv_memory_map)
def _decoder_attn_setup(
    test_pt: TestPoint,
    test_rsrcs: TestResources,
    block_base_addr: int = 0,
) -> Tuple[QKVInputs, PhaseTestParameters, PhaseTestParameters, int]:
    '''
    Prepare inputs, expected outputs and KV-cache memory maps for the
    decoder self-attention test.

    make_qkv() supplies three structures: the "baseline" query/key/value
    plus "prefill" and "decode" variants. This being self-attention, the
    key/value length limit equals the query length limit.

    The baseline tensors are run through the reference implementation with
    a causal mask. The resulting baseline output is split per sequence:

    * prefill ideal output: the first prefill_q_seq_len positions
    * decode ideal output: the single position right after them

    Both pieces are packed so they can be compared against the prefill and
    decode test results respectively.

    Memory mapping for the two phases:

    * Prefill: empty block-tables tensor + prefill slot mapping
    * Decode: block tables built from the baseline sequence lengths
      (starting from block_base_addr) + decode slot mapping

    Arguments:

    * test_pt: TestPoint data structure; the fields read here are
               batch_size, num_heads, head_size, block_size and
               max_dec_seq_len
    * test_rsrcs: TestResources data structure; only the scale field is
                  read
    * block_base_addr: decoder self-attention block-table base address

    Returns:

    * qkv: unpacked baseline query/key/value structure
    * Prefill-phase decoder self-attention PhaseTestParameters (packed
      q/k/v, packed ideal output, prefill memory map)
    * Decode-phase decoder self-attention PhaseTestParameters (packed
      q/k/v, packed ideal output, decode memory map)
    * max_block_idx: third value returned by
      make_block_tables_slot_mapping(); the end-to-end test passes it on as
      the base address of the cross-attention block table
    '''

    n_seqs = test_pt.batch_size
    # Self-attention: key/value length limit equals the query limit
    seq_len_cap = test_pt.max_dec_seq_len

    # Baseline, prefill-only and decode-only test tensors
    baseline_qkv, prefill_qkv, decode_qkv = make_qkv(
        n_seqs,
        seq_len_cap,
        seq_len_cap,
        test_pt.num_heads,
        test_pt.head_size,
        attn_type=AttentionType.DECODER,
        device=CUDA_DEVICE)

    # Reference answer: naive attention with a causal mask
    mask = make_causal_mask(seq_len_cap, seq_len_cap).to(CUDA_DEVICE)
    baseline_out = ref_masked_attention(baseline_qkv.query,
                                        baseline_qkv.key,
                                        baseline_qkv.value,
                                        scale=test_rsrcs.scale,
                                        custom_mask=mask,
                                        q_seq_lens=baseline_qkv.q_seq_lens,
                                        kv_seq_lens=baseline_qkv.kv_seq_lens)

    # Carve the baseline answer into its prefill and decode portions
    prefill_out = torch.zeros_like(baseline_out)
    decode_out = torch.zeros_like(baseline_out[:, 0:1])
    for seq_idx, n_prefill in enumerate(prefill_qkv.q_seq_lens):
        prefill_out[seq_idx, :n_prefill] = baseline_out[seq_idx, :n_prefill]
        decode_out[seq_idx, :] = baseline_out[seq_idx,
                                              n_prefill:(n_prefill + 1)]

    packed_prefill_out, _ = pack_tensor(prefill_out,
                                        prefill_qkv.q_seq_lens,
                                        device=CUDA_DEVICE)
    # Exactly one decoded position per sequence
    packed_decode_out, _ = pack_tensor(decode_out, [1] * n_seqs,
                                       device=CUDA_DEVICE)

    # Memory-mapping structures for decoder self-attention.
    #
    # Prefill-phase:
    # * Empty block-tables tensor
    # * Slot mapping for the prefill portion of each sequence
    #
    # Decode-phase:
    # * Block tables derived from the full baseline sequence lengths
    # * Slot mapping for the decode portion of each sequence
    #
    # Note: this layout is meant to mirror what ModelRunner produces
    empty_block_tables = make_empty_block_tables_tensor(device=CUDA_DEVICE)

    block_tables, all_slots, max_block_idx = make_block_tables_slot_mapping(
        test_pt.block_size,
        baseline_qkv.q_seq_lens,
        device=CUDA_DEVICE,
        block_base_addr=block_base_addr)

    prefill_slots, decode_slots = split_slot_mapping(all_slots,
                                                     baseline_qkv.q_seq_lens,
                                                     device=CUDA_DEVICE)

    packed_prefill_qkv = pack_qkv(prefill_qkv, device=CUDA_DEVICE)
    packed_decode_qkv = pack_qkv(decode_qkv, device=CUDA_DEVICE)

    prefill_phase = PhaseTestParameters(
        PackedQKVO(packed_prefill_qkv, packed_prefill_out),
        KVMemoryMap(empty_block_tables, prefill_slots))
    decode_phase = PhaseTestParameters(
        PackedQKVO(packed_decode_qkv, packed_decode_out),
        KVMemoryMap(block_tables, decode_slots))

    return (baseline_qkv, prefill_phase, decode_phase, max_block_idx)
def _enc_dec_cross_attn_setup_reuses_query(
    decoder_qkv: QKVInputs,
    encoder_test_params: PhaseTestParameters,
    prefill_decoder_phase_test_params: PhaseTestParameters,
    test_pt: TestPoint,
    test_rsrcs: TestResources,
    block_base_addr: int = 0,
) -> Tuple[PhaseTestParameters, PhaseTestParameters]:
    '''
    Prepare key/value inputs, expected outputs and KV-cache memory maps for
    the encoder/decoder cross-attention test.

    The query is not generated here: the decoder self-attention query from
    decoder_qkv is reused. Only cross-attention key/value tensors are
    synthesized, with their sequence lengths forced to the encoder sequence
    lengths, so they may differ in length from the queries. A single
    key/value set is produced and serves both the prefill and decode
    phases.

    The reused query and the new key/value are run through the reference
    implementation (no custom mask). The resulting output is split per
    sequence into a prefill part (the first prefill_q_seq_len positions)
    and a decode part (the single position right after them), and both
    parts are packed.

    Memory mapping for the two phases:

    * Prefill: empty block-tables tensor + slot mapping derived from the
      cross-attention key/value sequence lengths
    * Decode: block tables derived from the same key/value sequence
      lengths (starting from block_base_addr) + empty slot-mapping tensor

    Arguments:

    * decoder_qkv: pre-existing unpacked decoder self-attention inputs;
                   the query and q_seq_lens fields are read
    * encoder_test_params: PhaseTestParameters used for encoder inference;
                           its packed q/k/v must not be None, and only the
                           q_seq_lens of that structure is read
    * prefill_decoder_phase_test_params: PhaseTestParameters used for
                                         prefill-phase decoder
                                         self-attention; its packed q/k/v
                                         must not be None, and only the
                                         q_seq_lens of that structure is
                                         read
    * test_pt: TestPoint data structure; the fields read here are
               batch_size, num_heads, head_size, block_size,
               max_dec_seq_len and max_enc_seq_len
    * test_rsrcs: TestResources data structure; only the scale field is
                  read
    * block_base_addr: base address handed to
                       make_block_tables_slot_mapping() for the
                       cross-attention block table

    Returns:

    * Prefill-phase cross-attention PhaseTestParameters: packed
      cross-attention key/value structure, packed prefill ideal output and
      the prefill memory map
    * Decode-phase cross-attention PhaseTestParameters: no packed q/k/v
      (None), packed decode ideal output and the decode memory map
    '''

    assert encoder_test_params.packed_qkvo.packed_qkv is not None
    assert prefill_decoder_phase_test_params.packed_qkvo.packed_qkv is not None

    n_seqs = test_pt.batch_size

    query = decoder_qkv.query
    dec_seq_lens = decoder_qkv.q_seq_lens
    enc_seq_lens = encoder_test_params.packed_qkvo.packed_qkv.q_seq_lens
    prefill_lens = (
        prefill_decoder_phase_test_params.packed_qkvo.packed_qkv.q_seq_lens)
    assert prefill_lens is not None

    # Synthesize cross-attention K/V whose lengths follow the encoder
    # sequences; only the first returned structure is needed
    kv_inputs, _, _ = make_qkv(test_pt.batch_size,
                               test_pt.max_dec_seq_len,
                               test_pt.max_enc_seq_len,
                               test_pt.num_heads,
                               test_pt.head_size,
                               force_kv_seq_lens=enc_seq_lens,
                               attn_type=AttentionType.ENCODER_DECODER,
                               device=CUDA_DEVICE)

    # Reference answer: reused decoder query against the new K/V
    baseline_out = ref_masked_attention(query,
                                        kv_inputs.key,
                                        kv_inputs.value,
                                        scale=test_rsrcs.scale,
                                        q_seq_lens=dec_seq_lens,
                                        kv_seq_lens=kv_inputs.kv_seq_lens)

    # Carve the baseline answer into its prefill and decode portions
    prefill_out = torch.zeros_like(baseline_out)
    decode_out = torch.zeros_like(baseline_out[:, 0:1])
    for seq_idx, n_prefill in enumerate(prefill_lens):
        prefill_out[seq_idx, :n_prefill] = baseline_out[seq_idx, :n_prefill]
        decode_out[seq_idx, :] = baseline_out[seq_idx,
                                              n_prefill:(n_prefill + 1)]

    packed_prefill_out, _ = pack_tensor(prefill_out,
                                        prefill_lens,
                                        device=CUDA_DEVICE)
    # Exactly one decoded position per sequence
    packed_decode_out, _ = pack_tensor(decode_out, [1] * n_seqs,
                                       device=CUDA_DEVICE)

    # Memory-mapping structures for encoder/decoder cross-attention.
    #
    # In decoder self-attention Q/K/V grow together with every decoded
    # token. In cross-attention only Q grows; K and V are derived from the
    # encoder hidden states and keep a fixed length.
    #
    # Prefill-phase:
    # * Empty block-tables tensor
    # * Slot mapping with entries for the cross-attention K/V tokens
    #
    # Decode-phase:
    # * Block tables sized for K/V tensors whose length equals the encoder
    #   prompt length
    # * Empty slot-mapping tensor (K/V do not grow, so newly decoded tokens
    #   need no slot mapping)
    #
    # Note: this layout extends what ModelRunner produces for decoder-only
    # models
    empty_block_tables = make_empty_block_tables_tensor(device=CUDA_DEVICE)
    empty_slots = make_empty_slot_mapping_tensor(device=CUDA_DEVICE)

    block_tables, prefill_slot_list, _ = make_block_tables_slot_mapping(
        test_pt.block_size,
        kv_inputs.kv_seq_lens,
        block_base_addr=block_base_addr,
        device=CUDA_DEVICE)

    prefill_slots = maybe_make_long_tensor(prefill_slot_list,
                                           device=CUDA_DEVICE)

    # Pack the key/value structure (the query is supplied by the caller)
    packed_kv = pack_qkv(kv_inputs, device=CUDA_DEVICE)

    prefill_phase = PhaseTestParameters(
        PackedQKVO(packed_kv, packed_prefill_out),
        KVMemoryMap(empty_block_tables, prefill_slots))
    decode_phase = PhaseTestParameters(
        PackedQKVO(None, packed_decode_out),
        KVMemoryMap(block_tables, empty_slots))

    return (prefill_phase, decode_phase)
def _run_encoder_attention_test(
    attn: Attention,
    encoder_test_params: PhaseTestParameters,
    attn_metadata: AttentionMetadata,
) -> torch.Tensor:
    '''
    Invoke the attention layer in encoder mode.

    attn.forward() receives attn_type=AttentionType.ENCODER and None in
    place of a KV cache.

    Requires attn_metadata.num_decode_tokens == 0, since the encoder is
    not executed in the decode phase.

    Arguments:

    * attn: Attention wrapper instance
    * encoder_test_params: encoder PhaseTestParameters data structure;
                           its packed query/key/value fields are read and
                           must be present
    * attn_metadata: attention metadata passed through to attn.forward()

    Returns:

    * Result of attn.forward() on the packed query/key/value
    '''

    assert attn_metadata.num_decode_tokens == 0

    enc_qkv = encoder_test_params.packed_qkvo.packed_qkv
    assert enc_qkv is not None

    # The encoder pass is given no KV cache
    no_kv_cache = None
    return attn.forward(enc_qkv.query,
                        enc_qkv.key,
                        enc_qkv.value,
                        no_kv_cache,
                        attn_metadata,
                        attn_type=AttentionType.ENCODER)
def _run_decoder_self_attention_test(
    test_rsrcs: TestResources,
    decoder_test_params: PhaseTestParameters,
    attn_metadata: AttentionMetadata,
) -> torch.Tensor:
    '''
    Invoke the attention layer in decoder self-attention mode.

    attn.forward() receives attn_type=AttentionType.DECODER together with
    the KV cache held by test_rsrcs.

    Arguments:

    * test_rsrcs: TestResources instance; the attn and kv_cache fields are
                  read
    * decoder_test_params: decoder PhaseTestParameters data structure;
                           its packed query/key/value fields are read and
                           must be present
    * attn_metadata: attention metadata passed through to attn.forward()

    Returns:

    * Result of attn.forward() on the packed query/key/value and KV cache
    '''

    dec_qkv = decoder_test_params.packed_qkvo.packed_qkv
    assert dec_qkv is not None

    return test_rsrcs.attn.forward(dec_qkv.query,
                                   dec_qkv.key,
                                   dec_qkv.value,
                                   test_rsrcs.kv_cache,
                                   attn_metadata,
                                   attn_type=AttentionType.DECODER)
def _run_encoder_decoder_cross_attention_test(
    test_rsrcs: TestResources,
    decoder_test_params: PhaseTestParameters,
    cross_test_params: Optional[PhaseTestParameters],
    attn_metadata: AttentionMetadata,
) -> torch.Tensor:
    '''
    Invoke the attention layer in encoder/decoder cross-attention mode.

    The query comes from the decoder test parameters (the same query used
    for decoder self-attention); key and value come from the
    cross-attention test parameters.

    When cross_test_params is None, or its packed q/k/v structure is None,
    key and value are passed to attn.forward() as None. This is how the
    decode phase is driven, where the cross-attention key/value tensors do
    not grow.

    attn.forward() receives attn_type=AttentionType.ENCODER_DECODER
    together with the KV cache held by test_rsrcs.

    Arguments:

    * test_rsrcs: TestResources instance; the attn and kv_cache fields are
                  read
    * decoder_test_params: decoder PhaseTestParameters data structure;
                           its packed query field is read and the packed
                           q/k/v structure must be present
    * cross_test_params: encoder/decoder PhaseTestParameters data
                         structure (or None); its packed key/value fields
                         are read when available
    * attn_metadata: attention metadata passed through to attn.forward()

    Returns:

    * Result of attn.forward() on the query, optional key/value and KV
      cache
    '''

    assert decoder_test_params.packed_qkvo.packed_qkv is not None

    # Resolve the (possibly absent) packed cross-attention K/V structure
    if cross_test_params is None:
        cross_kv = None
    else:
        cross_kv = cross_test_params.packed_qkvo.packed_qkv

    if cross_kv is None:
        cross_key, cross_value = None, None
    else:
        cross_key, cross_value = cross_kv.key, cross_kv.value

    return test_rsrcs.attn.forward(
        decoder_test_params.packed_qkvo.packed_qkv.query,
        cross_key,
        cross_value,
        test_rsrcs.kv_cache,
        attn_metadata,
        attn_type=AttentionType.ENCODER_DECODER)
@pytest.mark.skipif(is_hip(), reason=STR_NOT_IMPL_ENC_DEC_ROCM_HIP)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("backend_name", BACKEND_NAMES)
@pytest.mark.parametrize("batch_size", BATCH_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("max_dec_seq_len", MAX_DEC_SEQ_LENS)
@pytest.mark.parametrize("max_enc_seq_len", MAX_ENC_SEQ_LENS)
def test_encoder_only(num_heads: int, head_size: int, backend_name: str,
                      batch_size: int, block_size: int, max_dec_seq_len: int,
                      max_enc_seq_len: int, monkeypatch):
    '''
    Encoder-only attention test.

    * Force the attention backend via override_backend_env_variable()
    * Build encoder test vectors and their ideal output
    * Build prefill attention metadata carrying only encoder parameters
    * Run encoder attention and compare against the ideal output

    Skipped if is_hip().
    '''

    # Pin the backend that the Attention wrapper will pick
    override_backend_env_variable(monkeypatch, backend_name)

    # The block count of 4096 is arbitrary and deliberately generous:
    # running out of KV cache is not what this test is about
    test_pt = TestPoint(num_heads=num_heads,
                        head_size=head_size,
                        backend_name=backend_name,
                        batch_size=batch_size,
                        block_size=block_size,
                        max_dec_seq_len=max_dec_seq_len,
                        max_enc_seq_len=max_enc_seq_len,
                        num_blocks=4096)

    # Scale factor, backend instance, Attention wrapper and KV cache
    test_rsrcs = _make_test_resources(test_pt)

    # Encoder test vectors (encoder attention runs during prefill only)
    enc_test_params = _encoder_attn_setup(test_pt, test_rsrcs)

    # Prefill metadata: encoder parameters only, no decoder or cross
    # parameters
    prefill_metadata: AttentionMetadata = make_test_metadata(
        test_rsrcs.attn_backend,
        True,
        None,
        decoder_test_params=None,
        encoder_test_params=enc_test_params,
        cross_test_params=None,
        device=CUDA_DEVICE)

    # PREFILL: encoder attention
    enc_actual_out: torch.Tensor = _run_encoder_attention_test(
        test_rsrcs.attn, enc_test_params, prefill_metadata)

    # - Does the encoder attention result match the reference?
    assert_actual_matches_ideal(enc_test_params, enc_actual_out)
@pytest.mark.skipif(is_hip(), reason=STR_NOT_IMPL_ENC_DEC_ROCM_HIP)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("backend_name", BACKEND_NAMES)
@pytest.mark.parametrize("batch_size", BATCH_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("max_dec_seq_len", MAX_DEC_SEQ_LENS)
@pytest.mark.parametrize("max_enc_seq_len", MAX_ENC_SEQ_LENS)
def test_e2e_enc_dec_attn(
    num_heads: int,
    head_size: int,
    backend_name: str,
    batch_size: int,
    block_size: int,
    max_dec_seq_len: int,
    max_enc_seq_len: int,
    monkeypatch,
) -> None:
    '''
    End-to-end encoder/decoder attention test.

    Steps:

    * Build synthetic test vectors for (1) encoder attention, (2) decoder
      self-attention and (3) encoder/decoder cross-attention
    * Build one attention metadata structure for the prefill phase and an
      analogous one for the decode phase
    * Run attention in this order, checking each result against the ideal
      reference output:

        * Encoder attention
        * Prefill self-attention
        * Prefill cross-attention
        * Decode self-attention
        * Decode cross-attention

      Besides resembling realistic usage, this order is meant to expose
      any accidental overlap between the self- and cross-attention block
      tables.

    The value returned by the decoder setup as its last element is used as
    the base address of the cross-attention block table, with the intent
    that the self- and cross-attention KV-cache regions do not intersect.

    Self- and cross-attention share the query tensor but not the key/value
    tensors: self-attention K/V have the same sequence length as Q, while
    cross-attention K/V follow the encoder sequence lengths.

    The attention backend is forced through PyTest monkey patching of an
    environment variable.

    Note on ROCm/HIP: the test is skipped if is_hip().

    Note on metadata: a single metadata structure is shared by all
    prefill-phase operations (encoder, decoder, cross) and another by all
    decode-phase operations (decoder, cross). This is intended to mirror
    ModelRunner, which builds one attention metadata structure per prefill
    or decode run, and so checks that the backend under test copes with
    shared metadata.
    '''

    # Pin the backend that the Attention wrapper will pick
    override_backend_env_variable(monkeypatch, backend_name)

    # The block count of 4096 is arbitrary and deliberately generous:
    # running out of KV cache is not what this test is about
    test_pt = TestPoint(num_heads=num_heads,
                        head_size=head_size,
                        backend_name=backend_name,
                        batch_size=batch_size,
                        block_size=block_size,
                        max_dec_seq_len=max_dec_seq_len,
                        max_enc_seq_len=max_enc_seq_len,
                        num_blocks=4096)

    # Scale factor, backend instance, Attention wrapper and KV cache
    test_rsrcs = _make_test_resources(test_pt)

    # Encoder test vectors (encoder attention runs during prefill only)
    enc_params = _encoder_attn_setup(test_pt, test_rsrcs)

    # Decoder self-attention vectors and memory maps for both phases. The
    # last returned value becomes the base address of the cross-attention
    # block table.
    (
        dec_qkv,
        prefill_dec_params,
        decode_dec_params,
        cross_base_addr,
    ) = _decoder_attn_setup(test_pt, test_rsrcs)

    # Cross-attention key/value vectors and memory maps for both phases
    (
        prefill_cross_params,
        decode_cross_params,
    ) = _enc_dec_cross_attn_setup_reuses_query(
        dec_qkv,
        enc_params,
        prefill_dec_params,
        test_pt,
        test_rsrcs,
        block_base_addr=cross_base_addr)

    # ---- Prefill phase ----

    # Metadata shared by every prefill-phase attention call
    assert prefill_dec_params.packed_qkvo.packed_qkv is not None
    prefill_metadata: AttentionMetadata = make_test_metadata(
        test_rsrcs.attn_backend,
        True,
        prefill_dec_params.packed_qkvo.packed_qkv.q_seq_lens,
        decoder_test_params=prefill_dec_params,
        encoder_test_params=enc_params,
        cross_test_params=prefill_cross_params,
        device=CUDA_DEVICE)

    # PREFILL: encoder attention
    enc_actual = _run_encoder_attention_test(test_rsrcs.attn, enc_params,
                                             prefill_metadata)
    assert_actual_matches_ideal(enc_params, enc_actual)

    # PREFILL: decoder self-attention
    prefill_dec_actual = _run_decoder_self_attention_test(
        test_rsrcs, prefill_dec_params, prefill_metadata)
    assert_actual_matches_ideal(prefill_dec_params, prefill_dec_actual)

    # PREFILL: encoder/decoder cross-attention
    prefill_cross_actual = _run_encoder_decoder_cross_attention_test(
        test_rsrcs, prefill_dec_params, prefill_cross_params,
        prefill_metadata)
    assert_actual_matches_ideal(prefill_cross_params, prefill_cross_actual)

    # ---- Decode phase ----

    # Metadata shared by every decode-phase attention call
    decode_metadata: AttentionMetadata = make_test_metadata(
        test_rsrcs.attn_backend,
        False,
        dec_qkv.q_seq_lens,
        decoder_test_params=decode_dec_params,
        encoder_test_params=enc_params,
        cross_test_params=decode_cross_params,
        device=CUDA_DEVICE)

    # DECODE: decoder self-attention
    decode_dec_actual = _run_decoder_self_attention_test(
        test_rsrcs, decode_dec_params, decode_metadata)
    assert_actual_matches_ideal(decode_dec_params, decode_dec_actual)

    # DECODE: encoder/decoder cross-attention. No cross parameters are
    # handed to the runner, so key and value are passed as None.
    decode_cross_actual = _run_encoder_decoder_cross_attention_test(
        test_rsrcs, decode_dec_params, None, decode_metadata)
    assert_actual_matches_ideal(decode_cross_params, decode_cross_actual)
tests/kernels/test_flash_attn.py
View file @
e7c1b7f3
...
@@ -20,12 +20,13 @@ def ref_paged_attn(
...
@@ -20,12 +20,13 @@ def ref_paged_attn(
block_tables
:
torch
.
Tensor
,
block_tables
:
torch
.
Tensor
,
scale
:
float
,
scale
:
float
,
sliding_window
:
Optional
[
int
]
=
None
,
sliding_window
:
Optional
[
int
]
=
None
,
soft_cap
:
Optional
[
float
]
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
num_seqs
=
len
(
query_lens
)
num_seqs
=
len
(
query_lens
)
block_tables
=
block_tables
.
cpu
().
numpy
()
block_tables
=
block_tables
.
cpu
().
numpy
()
_
,
block_size
,
num_kv_heads
,
head_size
=
key_cache
.
shape
_
,
block_size
,
num_kv_heads
,
head_size
=
key_cache
.
shape
outputs
=
[]
outputs
:
List
[
torch
.
Tensor
]
=
[]
start_idx
=
0
start_idx
=
0
for
i
in
range
(
num_seqs
):
for
i
in
range
(
num_seqs
):
query_len
=
query_lens
[
i
]
query_len
=
query_lens
[
i
]
...
@@ -53,6 +54,8 @@ def ref_paged_attn(
...
@@ -53,6 +54,8 @@ def ref_paged_attn(
(
query_len
+
sliding_window
)
+
(
query_len
+
sliding_window
)
+
1
).
bool
().
logical_not
()
1
).
bool
().
logical_not
()
mask
|=
sliding_window_mask
mask
|=
sliding_window_mask
if
soft_cap
is
not
None
:
attn
=
soft_cap
*
torch
.
tanh
(
attn
/
soft_cap
)
attn
.
masked_fill_
(
mask
,
float
(
"-inf"
))
attn
.
masked_fill_
(
mask
,
float
(
"-inf"
))
attn
=
torch
.
softmax
(
attn
,
dim
=-
1
).
to
(
v
.
dtype
)
attn
=
torch
.
softmax
(
attn
,
dim
=-
1
).
to
(
v
.
dtype
)
out
=
torch
.
einsum
(
"hqk,khd->qhd"
,
attn
,
v
)
out
=
torch
.
einsum
(
"hqk,khd->qhd"
,
attn
,
v
)
...
@@ -68,13 +71,15 @@ def ref_paged_attn(
...
@@ -68,13 +71,15 @@ def ref_paged_attn(
@
pytest
.
mark
.
parametrize
(
"head_size"
,
HEAD_SIZES
)
@
pytest
.
mark
.
parametrize
(
"head_size"
,
HEAD_SIZES
)
@
pytest
.
mark
.
parametrize
(
"block_size"
,
BLOCK_SIZES
)
@
pytest
.
mark
.
parametrize
(
"block_size"
,
BLOCK_SIZES
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPES
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPES
)
@
torch
.
inference_mode
@
pytest
.
mark
.
parametrize
(
"soft_cap"
,
[
None
,
10.0
,
50.0
])
@
torch
.
inference_mode
()
def
test_flash_attn_with_paged_kv
(
def
test_flash_attn_with_paged_kv
(
kv_lens
:
List
[
Tuple
[
int
,
int
]
],
kv_lens
:
List
[
int
],
num_heads
:
Tuple
[
int
,
int
],
num_heads
:
Tuple
[
int
,
int
],
head_size
:
int
,
head_size
:
int
,
dtype
:
torch
.
dtype
,
dtype
:
torch
.
dtype
,
block_size
:
int
,
block_size
:
int
,
soft_cap
:
Optional
[
float
],
)
->
None
:
)
->
None
:
torch
.
set_default_device
(
"cuda"
)
torch
.
set_default_device
(
"cuda"
)
torch
.
cuda
.
manual_seed_all
(
0
)
torch
.
cuda
.
manual_seed_all
(
0
)
...
@@ -108,6 +113,7 @@ def test_flash_attn_with_paged_kv(
...
@@ -108,6 +113,7 @@ def test_flash_attn_with_paged_kv(
causal
=
True
,
causal
=
True
,
block_table
=
block_tables
,
block_table
=
block_tables
,
cache_seqlens
=
kv_lens_tensor
,
cache_seqlens
=
kv_lens_tensor
,
softcap
=
soft_cap
if
soft_cap
is
not
None
else
0
,
).
squeeze
(
1
)
).
squeeze
(
1
)
ref_output
=
ref_paged_attn
(
ref_output
=
ref_paged_attn
(
...
@@ -118,6 +124,7 @@ def test_flash_attn_with_paged_kv(
...
@@ -118,6 +124,7 @@ def test_flash_attn_with_paged_kv(
kv_lens
=
kv_lens
,
kv_lens
=
kv_lens
,
block_tables
=
block_tables
,
block_tables
=
block_tables
,
scale
=
scale
,
scale
=
scale
,
soft_cap
=
soft_cap
,
)
)
assert
torch
.
allclose
(
output
,
ref_output
,
atol
=
1e-2
,
rtol
=
1e-2
),
\
assert
torch
.
allclose
(
output
,
ref_output
,
atol
=
1e-2
,
rtol
=
1e-2
),
\
f
"
{
torch
.
max
(
torch
.
abs
(
output
-
ref_output
))
}
"
f
"
{
torch
.
max
(
torch
.
abs
(
output
-
ref_output
))
}
"
...
@@ -129,7 +136,8 @@ def test_flash_attn_with_paged_kv(
...
@@ -129,7 +136,8 @@ def test_flash_attn_with_paged_kv(
@
pytest
.
mark
.
parametrize
(
"block_size"
,
BLOCK_SIZES
)
@
pytest
.
mark
.
parametrize
(
"block_size"
,
BLOCK_SIZES
)
@
pytest
.
mark
.
parametrize
(
"sliding_window"
,
[
None
])
@
pytest
.
mark
.
parametrize
(
"sliding_window"
,
[
None
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPES
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPES
)
@
torch
.
inference_mode
@
pytest
.
mark
.
parametrize
(
"soft_cap"
,
[
None
,
10.0
,
50.0
])
@
torch
.
inference_mode
()
def
test_varlen_with_paged_kv
(
def
test_varlen_with_paged_kv
(
seq_lens
:
List
[
Tuple
[
int
,
int
]],
seq_lens
:
List
[
Tuple
[
int
,
int
]],
num_heads
:
Tuple
[
int
,
int
],
num_heads
:
Tuple
[
int
,
int
],
...
@@ -137,6 +145,7 @@ def test_varlen_with_paged_kv(
...
@@ -137,6 +145,7 @@ def test_varlen_with_paged_kv(
sliding_window
:
Optional
[
int
],
sliding_window
:
Optional
[
int
],
dtype
:
torch
.
dtype
,
dtype
:
torch
.
dtype
,
block_size
:
int
,
block_size
:
int
,
soft_cap
:
Optional
[
float
],
)
->
None
:
)
->
None
:
torch
.
set_default_device
(
"cuda"
)
torch
.
set_default_device
(
"cuda"
)
torch
.
cuda
.
manual_seed_all
(
0
)
torch
.
cuda
.
manual_seed_all
(
0
)
...
@@ -163,10 +172,6 @@ def test_varlen_with_paged_kv(
...
@@ -163,10 +172,6 @@ def test_varlen_with_paged_kv(
head_size
,
head_size
,
dtype
=
dtype
)
dtype
=
dtype
)
value_cache
=
torch
.
randn_like
(
key_cache
)
value_cache
=
torch
.
randn_like
(
key_cache
)
# Normalize the scale of the key and value caches to mitigate
# numerical instability.
key_cache
/=
head_size
**
0.5
value_cache
/=
head_size
**
0.5
cu_query_lens
=
torch
.
tensor
([
0
]
+
query_lens
,
cu_query_lens
=
torch
.
tensor
([
0
]
+
query_lens
,
dtype
=
torch
.
int32
).
cumsum
(
dim
=
0
,
dtype
=
torch
.
int32
).
cumsum
(
dim
=
0
,
dtype
=
torch
.
int32
)
dtype
=
torch
.
int32
)
...
@@ -192,6 +197,7 @@ def test_varlen_with_paged_kv(
...
@@ -192,6 +197,7 @@ def test_varlen_with_paged_kv(
causal
=
True
,
causal
=
True
,
window_size
=
window_size
,
window_size
=
window_size
,
block_table
=
block_tables
,
block_table
=
block_tables
,
softcap
=
soft_cap
if
soft_cap
is
not
None
else
0
,
)
)
ref_output
=
ref_paged_attn
(
ref_output
=
ref_paged_attn
(
...
@@ -203,6 +209,7 @@ def test_varlen_with_paged_kv(
...
@@ -203,6 +209,7 @@ def test_varlen_with_paged_kv(
block_tables
=
block_tables
,
block_tables
=
block_tables
,
scale
=
scale
,
scale
=
scale
,
sliding_window
=
sliding_window
,
sliding_window
=
sliding_window
,
soft_cap
=
soft_cap
,
)
)
assert
torch
.
allclose
(
output
,
ref_output
,
atol
=
1e-2
,
rtol
=
1e-2
),
\
assert
torch
.
allclose
(
output
,
ref_output
,
atol
=
1e-2
,
rtol
=
1e-2
),
\
f
"
{
torch
.
max
(
torch
.
abs
(
output
-
ref_output
))
}
"
f
"
{
torch
.
max
(
torch
.
abs
(
output
-
ref_output
))
}
"
tests/kernels/test_flashinfer.py
0 → 100644
View file @
e7c1b7f3
from
typing
import
List
,
Optional
,
Tuple
import
flashinfer
import
pytest
import
torch
# Test-matrix constants shared by the FlashInfer kernel tests below.

# (num_query_heads, num_kv_heads) pairs. When the two counts differ, the
# reference implementation repeats each KV head across several query heads.
NUM_HEADS = [(16, 16), (32, 8), (64, 8)]
# Per-head embedding sizes.
HEAD_SIZES = [128, 256]
# Number of token slots in one KV-cache block (page).
BLOCK_SIZES = [16, 32]
DTYPES = [torch.float16, torch.bfloat16]
NUM_BLOCKS = 32768  # Large enough to test overflow in index calculation.
def ref_paged_attn(
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    query_lens: List[int],
    kv_lens: List[int],
    block_tables: torch.Tensor,
    scale: float,
    sliding_window: Optional[int] = None,
    soft_cap: Optional[float] = None,
) -> torch.Tensor:
    """Pure-PyTorch reference for causal attention over a paged KV cache.

    Sequences are processed one at a time: the blocks listed in the
    sequence's row of ``block_tables`` are gathered from the caches,
    flattened to a contiguous run of tokens, and truncated to ``kv_len``.

    Args:
        query: Packed queries for all sequences, indexed as
            ``(total_query_tokens, num_query_heads, head_size)``.
        key_cache: Key blocks, shaped
            ``(num_blocks, block_size, num_kv_heads, head_size)``.
        value_cache: Value blocks, same layout as ``key_cache``.
        query_lens: Number of query tokens of each sequence.
        kv_lens: Number of cached key/value tokens of each sequence.
        block_tables: ``(num_seqs, max_blocks_per_seq)`` block indices.
        scale: Multiplier applied to the queries before the dot product.
        sliding_window: If given, each query additionally ignores keys
            that lie further back than this window.
        soft_cap: If given, logits are squashed with
            ``soft_cap * tanh(logits / soft_cap)`` before masking.

    Returns:
        The attention outputs of all sequences concatenated along dim 0,
        in the same packed layout as ``query``. ``query`` itself is left
        unmodified.
    """
    block_tables = block_tables.cpu().numpy()
    _, block_size, num_kv_heads, head_size = key_cache.shape

    outputs: List[torch.Tensor] = []
    start_idx = 0
    for i, query_len in enumerate(query_lens):
        kv_len = kv_lens[i]
        # Scale out of place: the slice is a view, so `q *= scale` would
        # silently rescale the caller's query tensor.
        q = query[start_idx:start_idx + query_len] * scale

        # Ceil-divide to find how many blocks hold this sequence's tokens.
        num_kv_blocks = (kv_len + block_size - 1) // block_size
        block_indices = block_tables[i, :num_kv_blocks]

        k = key_cache[block_indices].view(-1, num_kv_heads, head_size)
        k = k[:kv_len]
        v = value_cache[block_indices].view(-1, num_kv_heads, head_size)
        v = v[:kv_len]

        # Expand KV heads so every query head has a matching KV head.
        if q.shape[1] != k.shape[1]:
            k = torch.repeat_interleave(k, q.shape[1] // k.shape[1], dim=1)
            v = torch.repeat_interleave(v, q.shape[1] // v.shape[1], dim=1)
        attn = torch.einsum("qhd,khd->hqk", q, k).float()

        # True entries are masked out. The diagonal offset aligns the last
        # query token with the last key token (causal masking).
        empty_mask = torch.ones(query_len, kv_len)
        mask = torch.triu(empty_mask, diagonal=kv_len - query_len + 1).bool()
        if sliding_window is not None:
            sliding_window_mask = torch.triu(
                empty_mask,
                diagonal=kv_len - (query_len + sliding_window) +
                1).bool().logical_not()
            mask |= sliding_window_mask
        if soft_cap is not None:
            attn = soft_cap * torch.tanh(attn / soft_cap)
        attn.masked_fill_(mask, float("-inf"))
        attn = torch.softmax(attn, dim=-1).to(v.dtype)
        out = torch.einsum("hqk,khd->qhd", attn, v)

        outputs.append(out)
        start_idx += query_len

    return torch.cat(outputs, dim=0)
@pytest.mark.parametrize("kv_lens", [[1328, 18, 463], [1, 54, 293, 70]])
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("soft_cap", [None, 30.0, 50.0])
@torch.inference_mode()
def test_flashinfer_decode_with_paged_kv(kv_lens: List[int],
                                         num_heads: Tuple[int, int],
                                         head_size: int, dtype: torch.dtype,
                                         block_size: int,
                                         soft_cap: Optional[float]) -> None:
    """Compare FlashInfer batched decode against ``ref_paged_attn``.

    Every sequence contributes exactly one query token. The KV cache is a
    single random tensor holding key and value planes side by side, and
    each sequence is assigned random blocks from it.

    Args:
        kv_lens: Cached key/value length of each sequence (all > 0).
        num_heads: ``(num_query_heads, num_kv_heads)``.
        head_size: Per-head embedding size.
        dtype: Element type of the query and the KV cache.
        block_size: Token slots per KV-cache block.
        soft_cap: Optional logits soft cap forwarded to both kernels.
    """
    torch.set_default_device("cuda")
    torch.cuda.manual_seed_all(0)
    num_seqs = len(kv_lens)
    num_query_heads = num_heads[0]
    num_kv_heads = num_heads[1]
    assert num_query_heads % num_kv_heads == 0
    max_kv_len = max(kv_lens)
    scale = head_size**-0.5

    query = torch.randn(num_seqs, num_query_heads, head_size, dtype=dtype)
    key_value_cache = torch.randn(NUM_BLOCKS,
                                  2,
                                  block_size,
                                  num_kv_heads,
                                  head_size,
                                  dtype=dtype)
    # Integer indexing already drops the K/V axis. No squeeze is applied, so
    # the (block, slot, head, dim) layout also holds when block_size == 1.
    key_cache = key_value_cache[:, 0]
    value_cache = key_value_cache[:, 1]

    max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size
    block_tables = torch.randint(0,
                                 NUM_BLOCKS,
                                 (num_seqs, max_num_blocks_per_seq),
                                 dtype=torch.int32)

    # Flatten the per-sequence block tables into the indptr / indices /
    # last-page-length triplet that is handed to the FlashInfer wrapper.
    kv_indptr = [0]
    kv_indices = []
    kv_last_page_lens = []
    for i, seq_len in enumerate(kv_lens):
        assert seq_len > 0
        num_blocks = (seq_len + block_size - 1) // block_size
        # One bulk copy to host instead of collecting 0-d tensors one by one.
        kv_indices.extend(block_tables[i, :num_blocks].tolist())
        kv_indptr.append(kv_indptr[-1] + num_blocks)
        # A sequence that exactly fills its last block uses the whole block.
        kv_last_page_len = seq_len % block_size
        if kv_last_page_len == 0:
            kv_last_page_len = block_size
        kv_last_page_lens.append(kv_last_page_len)

    kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32)
    kv_indices = torch.tensor(kv_indices, dtype=torch.int32)
    kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32)

    workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8)
    wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
        workspace_buffer, "NHD")
    wrapper.begin_forward(kv_indptr,
                          kv_indices,
                          kv_last_page_lens,
                          num_query_heads,
                          num_kv_heads,
                          head_size,
                          block_size,
                          "NONE",
                          data_type=dtype)

    output = wrapper.forward(query, key_value_cache, logits_soft_cap=soft_cap)

    ref_output = ref_paged_attn(query=query,
                                key_cache=key_cache,
                                value_cache=value_cache,
                                query_lens=[1] * num_seqs,
                                kv_lens=kv_lens,
                                block_tables=block_tables,
                                scale=scale,
                                soft_cap=soft_cap)
    assert torch.allclose(output, ref_output, atol=1e-2, rtol=1e-2), \
        f"{torch.max(torch.abs(output - ref_output))}"
@pytest.mark.parametrize("seq_lens", [[(1, 1328), (5, 18), (129, 463)]])
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("soft_cap", [None, 30.0, 50.0])
@torch.inference_mode()
def test_flashinfer_prefill_with_paged_kv(seq_lens: List[Tuple[int, int]],
                                          num_heads: Tuple[int, int],
                                          head_size: int, dtype: torch.dtype,
                                          block_size: int,
                                          soft_cap: Optional[float]) -> None:
    """Compare FlashInfer batched prefill against ``ref_paged_attn``.

    Queries of all sequences are packed along dim 0. The KV cache is a
    single random tensor holding key and value planes side by side, and
    each sequence is assigned random blocks from it.

    Args:
        seq_lens: ``(query_len, kv_len)`` of each sequence (kv_len > 0).
        num_heads: ``(num_query_heads, num_kv_heads)``.
        head_size: Per-head embedding size.
        dtype: Element type of the query and the KV cache.
        block_size: Token slots per KV-cache block.
        soft_cap: Optional logits soft cap forwarded to both kernels.
    """
    torch.set_default_device("cuda")
    torch.cuda.manual_seed_all(0)
    num_seqs = len(seq_lens)
    query_lens = [x[0] for x in seq_lens]
    kv_lens = [x[1] for x in seq_lens]
    num_query_heads = num_heads[0]
    num_kv_heads = num_heads[1]
    assert num_query_heads % num_kv_heads == 0
    max_kv_len = max(kv_lens)
    scale = head_size**-0.5

    query = torch.randn(sum(query_lens),
                        num_query_heads,
                        head_size,
                        dtype=dtype)
    key_value_cache = torch.randn(NUM_BLOCKS,
                                  2,
                                  block_size,
                                  num_kv_heads,
                                  head_size,
                                  dtype=dtype)
    # Integer indexing already drops the K/V axis. No squeeze is applied, so
    # the (block, slot, head, dim) layout also holds when block_size == 1.
    # Both results are views, so the in-place normalization below also
    # updates key_value_cache, which is what the FlashInfer wrapper reads.
    key_cache = key_value_cache[:, 0]
    value_cache = key_value_cache[:, 1]

    # Normalize the scale of the key and value caches to mitigate
    # numerical instability.
    key_cache /= head_size**0.5
    value_cache /= head_size**0.5

    max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size
    block_tables = torch.randint(0,
                                 NUM_BLOCKS,
                                 (num_seqs, max_num_blocks_per_seq),
                                 dtype=torch.int32)

    # Flatten the per-sequence block tables into the indptr / indices /
    # last-page-length triplet that is handed to the FlashInfer wrapper,
    # and build the query offsets (qo_indptr) for the packed query tensor.
    qo_indptr = [0]
    kv_indptr = [0]
    kv_indices = []
    kv_last_page_lens = []
    for i, seq_len in enumerate(kv_lens):
        assert seq_len > 0
        num_blocks = (seq_len + block_size - 1) // block_size
        # One bulk copy to host instead of collecting 0-d tensors one by one.
        kv_indices.extend(block_tables[i, :num_blocks].tolist())
        kv_indptr.append(kv_indptr[-1] + num_blocks)
        # A sequence that exactly fills its last block uses the whole block.
        kv_last_page_len = seq_len % block_size
        if kv_last_page_len == 0:
            kv_last_page_len = block_size
        kv_last_page_lens.append(kv_last_page_len)
        qo_indptr.append(qo_indptr[-1] + query_lens[i])

    qo_indptr = torch.tensor(qo_indptr, dtype=torch.int32)
    kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32)
    kv_indices = torch.tensor(kv_indices, dtype=torch.int32)
    kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32)

    workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8)
    wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
        workspace_buffer, "NHD")
    wrapper.begin_forward(
        qo_indptr,
        kv_indptr,
        kv_indices,
        kv_last_page_lens,
        num_query_heads,
        num_kv_heads,
        head_size,
        block_size,
    )

    output = wrapper.forward(
        query,
        key_value_cache,
        logits_soft_cap=soft_cap,
    )

    ref_output = ref_paged_attn(query=query,
                                key_cache=key_cache,
                                value_cache=value_cache,
                                query_lens=query_lens,
                                kv_lens=kv_lens,
                                block_tables=block_tables,
                                scale=scale,
                                soft_cap=soft_cap)
    assert torch.allclose(output, ref_output, atol=1e-2, rtol=1e-2), \
        f"{torch.max(torch.abs(output - ref_output))}"
tests/kernels/test_fp8_quant.py
0 → 100644
View file @
e7c1b7f3
import
pytest
import
torch
import
vllm._custom_ops
as
ops
from
tests.kernels.quant_utils
import
(
ref_dynamic_per_tensor_fp8_quant
,
ref_dynamic_per_token_quant
)
# Test-matrix constants for the FP8 quantization kernel tests below.

# Input activation dtypes fed to the quantization ops.
DTYPES = [torch.half, torch.bfloat16, torch.float]
HIDDEN_SIZES = [1, 2, 3, 4, 16, 67, 768, 2048, 5120, 5137, 8192,
                8193]  # Arbitrary values for testing
HIDDEN_SIZES += list(range(1024, 1033))  # vectorized conversion edge cases
NUM_TOKENS = [1, 7, 83, 4096]  # Arbitrary values for testing
# Whether the per-token test passes an upper bound for the scales.
SCALE_UBS = [True, False]
SEEDS = [0]
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("scale_ub", SCALE_UBS)
@pytest.mark.parametrize("seed", SEEDS)
@torch.inference_mode()
def test_dynamic_per_token_fp8_quant(num_tokens: int, hidden_size: int,
                                     dtype: torch.dtype, scale_ub: bool,
                                     seed: int) -> None:
    """Per-token dynamic FP8 quantization must match the reference.

    When ``scale_ub`` is set, the mean of the input is passed to both the
    reference and the custom op as an upper bound for the scales.
    """
    torch.random.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    shape = (num_tokens, hidden_size)
    # The small offset keeps every element strictly positive (avoid nans).
    x = torch.rand(*shape, dtype=dtype, device="cuda") + 1e-6

    if scale_ub:
        upper_bound = torch.mean(x).to(dtype=torch.float32, device="cuda")
    else:
        upper_bound = None

    expected_q, expected_scales = ref_dynamic_per_token_quant(
        x, torch.float8_e4m3fn, upper_bound)
    actual_q, actual_scales = ops.scaled_fp8_quant(
        x, scale_ub=upper_bound, use_per_token_if_dynamic=True)

    assert torch.allclose(expected_scales, actual_scales)
    # Upconvert the FP8 outputs to float32 before comparing them.
    expected_f32 = expected_q.to(dtype=torch.float32)
    actual_f32 = actual_q.to(dtype=torch.float32)
    assert torch.allclose(expected_f32, actual_f32)
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@torch.inference_mode()
def test_dynamic_per_tensor_fp8_quant(num_tokens: int, hidden_size: int,
                                      dtype: torch.dtype, seed: int) -> None:
    """Per-tensor dynamic FP8 quantization must match the reference."""
    torch.random.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    shape = (num_tokens, hidden_size)
    x = torch.rand(*shape, dtype=dtype, device="cuda")

    expected_q, expected_scale = ref_dynamic_per_tensor_fp8_quant(x)
    actual_q, actual_scale = ops.scaled_fp8_quant(x)

    assert torch.allclose(expected_scale, actual_scale)
    # Upconvert the FP8 outputs to float32 before comparing them.
    expected_f32 = expected_q.to(dtype=torch.float32)
    actual_f32 = actual_q.to(dtype=torch.float32)
    assert torch.allclose(expected_f32, actual_f32)
# Regression test: with activations this large, the total element count no
# longer fits into an int32 index.
@torch.inference_mode()
@pytest.mark.parametrize("seed", SEEDS)
def test_fp8_quant_large(seed: int) -> None:
    """Static-scale FP8 quantization of a very large activation tensor.

    The scale comes from the per-tensor reference and is passed to the
    custom op, so only the quantized values are compared.
    """
    torch.random.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    dtype = torch.bfloat16
    hidden_size = 1152  # Smallest hidden_size to reproduce the error
    num_tokens = 1024000  # Mistral-Nemo's max_position_embeddings

    x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda")
    expected, scale = ref_dynamic_per_tensor_fp8_quant(x)
    actual, _ = ops.scaled_fp8_quant(x, scale)

    # Keep the memory footprint low: drop the input first, then replace
    # each FP8 output by its upconverted copy (torch.allclose does not
    # support fp8).
    del x
    expected = expected.to(dtype=dtype)
    actual = actual.to(dtype=dtype)

    assert torch.allclose(expected, actual)
tests/kernels/test_int8_quant.py
View file @
e7c1b7f3
import
pytest
import
pytest
import
torch
import
torch
# ruff: noqa: F401
from
tests.kernels.quant_utils
import
ref_dynamic_per_token_quant
import
vllm._C
from
vllm._custom_ops
import
scaled_int8_quant
DTYPES
=
[
torch
.
half
,
torch
.
bfloat16
,
torch
.
float
]
DTYPES
=
[
torch
.
half
,
torch
.
bfloat16
,
torch
.
float
]
HIDDEN_SIZES
=
[
16
,
67
,
768
,
2048
,
5120
,
5137
,
8192
,
HIDDEN_SIZES
=
[
16
,
67
,
768
,
2048
,
5120
,
5137
,
8192
,
...
@@ -21,23 +21,16 @@ def test_dynamic_scaled_int8_quant(num_tokens: int, hidden_size: int,
...
@@ -21,23 +21,16 @@ def test_dynamic_scaled_int8_quant(num_tokens: int, hidden_size: int,
dtype
:
torch
.
dtype
,
seed
:
int
)
->
None
:
dtype
:
torch
.
dtype
,
seed
:
int
)
->
None
:
torch
.
random
.
manual_seed
(
seed
)
torch
.
random
.
manual_seed
(
seed
)
torch
.
cuda
.
manual_seed
(
seed
)
torch
.
cuda
.
manual_seed
(
seed
)
int8_traits
=
torch
.
iinfo
(
torch
.
int8
)
x
=
torch
.
rand
(
num_tokens
,
hidden_size
,
dtype
=
dtype
,
device
=
"cuda"
)
*
1000
x
=
torch
.
rand
(
num_tokens
,
hidden_size
,
dtype
=
dtype
,
device
=
"cuda"
)
*
1000
x_token_max
,
_
=
x
.
max
(
dim
=
1
)
# reference
x_token_max
=
x_token_max
.
to
(
dtype
=
torch
.
float32
)
ref_out
,
ref_scales
=
ref_dynamic_per_token_quant
(
x
,
torch
.
int8
)
scales
=
(
x_token_max
/
float
(
127.0
))[:,
None
].
to
(
device
=
"cuda"
,
# kernel
dtype
=
torch
.
float32
)
ops_out
,
ops_scales
=
scaled_int8_quant
(
x
)
torch_out
=
(
x
/
scales
).
round
().
clamp
(
int8_traits
.
min
,
int8_traits
.
max
).
to
(
torch
.
int8
)
ops_out
=
torch
.
empty_like
(
x
,
dtype
=
torch
.
int8
,
device
=
"cuda"
)
scales_out
=
torch
.
empty_like
(
scales
,
dtype
=
torch
.
float32
,
device
=
"cuda"
)
torch
.
ops
.
_C
.
dynamic_scaled_int8_quant
(
ops_out
,
x
,
scales_out
)
assert
torch
.
allclose
(
scales
_out
,
scales
)
assert
torch
.
allclose
(
ops_
scales
,
ref_
scales
)
assert
torch
.
allclose
(
torch
_out
,
ops
_out
,
assert
torch
.
allclose
(
ops
_out
,
ref
_out
,
atol
=
1
)
# big atol to account for rounding errors
atol
=
1
)
# big atol to account for rounding errors
...
@@ -55,12 +48,11 @@ def test_static_scaled_int8_quant(num_tokens: int, hidden_size: int,
...
@@ -55,12 +48,11 @@ def test_static_scaled_int8_quant(num_tokens: int, hidden_size: int,
int8_traits
=
torch
.
iinfo
(
torch
.
int8
)
int8_traits
=
torch
.
iinfo
(
torch
.
int8
)
x
=
torch
.
rand
(
num_tokens
,
hidden_size
,
dtype
=
dtype
,
device
=
"cuda"
)
*
1000
x
=
torch
.
rand
(
num_tokens
,
hidden_size
,
dtype
=
dtype
,
device
=
"cuda"
)
*
1000
scale
=
torch
.
tensor
([
scale
],
dtype
=
torch
.
float32
,
device
=
"cuda"
)
out1
=
(
x
/
scale
).
round
().
clamp
(
int8_traits
.
min
,
out1
=
(
x
/
scale
).
round
().
clamp
(
int8_traits
.
min
,
int8_traits
.
max
).
to
(
torch
.
int8
)
int8_traits
.
max
).
to
(
torch
.
int8
)
out2
=
torch
.
empty_like
(
x
,
dtype
=
torch
.
int8
)
out2
,
_
=
scaled_int8_quant
(
x
,
scale
)
scale_argument
=
torch
.
tensor
([
scale
],
dtype
=
torch
.
float32
,
device
=
"cuda"
)
torch
.
ops
.
_C
.
static_scaled_int8_quant
(
out2
,
x
,
scale_argument
)
assert
torch
.
allclose
(
out1
,
out2
,
assert
torch
.
allclose
(
out1
,
out2
,
atol
=
1
)
# big atol to account for rounding errors
atol
=
1
)
# big atol to account for rounding errors
tests/kernels/test_marlin_gemm.py
View file @
e7c1b7f3
...
@@ -5,23 +5,33 @@ Run `pytest tests/kernels/marlin/test_marlin_gemm.py`.
...
@@ -5,23 +5,33 @@ Run `pytest tests/kernels/marlin/test_marlin_gemm.py`.
import
pytest
import
pytest
import
torch
import
torch
from
tests.quantization.utils
import
is_quant_method_supported
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.layers.quantization.gptq_marlin
import
(
GPTQ_MARLIN_MAX_PARALLEL
,
GPTQ_MARLIN_MIN_THREAD_N
,
GPTQ_MARLIN_SUPPORTED_GROUP_SIZES
,
GPTQ_MARLIN_SUPPORTED_NUM_BITS
)
from
vllm.model_executor.layers.quantization.gptq_marlin_24
import
(
from
vllm.model_executor.layers.quantization.gptq_marlin_24
import
(
GPTQ_MARLIN_24_MAX_PARALLEL
,
GPTQ_MARLIN_24_MIN_THREAD_N
,
GPTQ_MARLIN_24_MAX_PARALLEL
,
GPTQ_MARLIN_24_MIN_THREAD_N
,
GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES
,
GPTQ_MARLIN_24_SUPPORTED_NUM_BITS
)
GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES
,
GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES
)
from
vllm.model_executor.layers.quantization.utils.marlin_perms
import
(
from
vllm.model_executor.layers.quantization.qqq
import
(
marlin_perm
)
MARLIN_QQQ_MAX_PARALLEL
,
MARLIN_QQQ_MIN_THREAD_N
,
MARLIN_QQQ_SUPPORTED_GROUP_SIZES
,
MARLIN_QQQ_SUPPORTED_NUM_BITS
)
from
vllm.model_executor.layers.quantization.utils.marlin_utils
import
(
from
vllm.model_executor.layers.quantization.utils.marlin_utils
import
(
MarlinWorkspace
,
compute_max_diff
,
is_marlin_supported
,
marlin_24_quantize
,
GPTQ_MARLIN_MAX_PARALLEL
,
GPTQ_MARLIN_MIN_THREAD_N
,
marlin_quantize
,
marlin_weights
)
MARLIN_SUPPORTED_GROUP_SIZES
,
marlin_make_empty_g_idx
,
marlin_permute_scales
,
query_marlin_supported_quant_types
)
from
vllm.model_executor.layers.quantization.utils.marlin_utils_fp8
import
(
pack_fp8_to_int32
)
from
vllm.model_executor.layers.quantization.utils.marlin_utils_test
import
(
MarlinWorkspace
,
awq_marlin_quantize
,
get_weight_perm
,
marlin_quantize
,
marlin_weights
)
from
vllm.model_executor.layers.quantization.utils.marlin_utils_test_24
import
(
marlin_24_quantize
)
from
vllm.model_executor.layers.quantization.utils.marlin_utils_test_qqq
import
(
# noqa: E501
marlin_qqq_quantize
)
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
gptq_pack
,
quantize_weights
,
sort_weights
)
awq_pack
,
gptq_pack
,
gptq_quantize_weights
,
quantize_weights
,
sort_weights
)
ACT_ORDER_OPTS
=
[
False
,
True
]
ACT_ORDER_OPTS
=
[
False
,
True
]
K_FULL_OPTS
=
[
False
,
True
]
K_FULL_OPTS
=
[
False
,
True
]
USE_FP32_REDUCE_OPTS
=
[
False
,
True
]
MARLIN_K_CHUNKS
=
[
128
]
MARLIN_K_CHUNKS
=
[
128
]
MARLIN_N_CHUNKS
=
[
64
,
128
,
256
]
MARLIN_N_CHUNKS
=
[
64
,
128
,
256
]
...
@@ -38,21 +48,29 @@ MNK_FACTORS = [
...
@@ -38,21 +48,29 @@ MNK_FACTORS = [
(
67
,
13
,
11
),
(
67
,
13
,
11
),
]
]
DTYPES
=
[
torch
.
float16
,
torch
.
bfloat16
]
def
compute_max_diff
(
output
,
output_ref
):
return
torch
.
mean
(
torch
.
abs
(
output
-
output_ref
))
/
torch
.
mean
(
torch
.
abs
(
output_ref
))
def
rand_data
(
shape
):
def
rand_data
(
shape
,
dtype
=
torch
.
float16
):
return
torch
.
randn
(
shape
,
dtype
=
torch
.
half
,
device
=
"cuda"
)
return
torch
.
randn
(
shape
,
dtype
=
dtype
,
device
=
"cuda"
)
@
pytest
.
mark
.
skipif
(
not
is_
marlin_supported
(
),
@
pytest
.
mark
.
skipif
(
not
is_
quant_method_supported
(
"gptq_marlin"
),
reason
=
"Marlin is not supported on this GPU type."
)
reason
=
"Marlin is not supported on this GPU type."
)
@
pytest
.
mark
.
parametrize
(
"k_chunk"
,
MARLIN_K_CHUNKS
)
@
pytest
.
mark
.
parametrize
(
"k_chunk"
,
MARLIN_K_CHUNKS
)
@
pytest
.
mark
.
parametrize
(
"n_chunk"
,
MARLIN_N_CHUNKS
)
@
pytest
.
mark
.
parametrize
(
"n_chunk"
,
MARLIN_N_CHUNKS
)
@
pytest
.
mark
.
parametrize
(
"num_bits"
,
GPTQ_MARLIN_SUPPORTED_NUM_BITS
)
@
pytest
.
mark
.
parametrize
(
"quant_type"
,
@
pytest
.
mark
.
parametrize
(
"group_size"
,
GPTQ_MARLIN_SUPPORTED_GROUP_SIZES
)
query_marlin_supported_quant_types
(
False
))
@
pytest
.
mark
.
parametrize
(
"group_size"
,
MARLIN_SUPPORTED_GROUP_SIZES
)
@
pytest
.
mark
.
parametrize
(
"act_order"
,
ACT_ORDER_OPTS
)
@
pytest
.
mark
.
parametrize
(
"act_order"
,
ACT_ORDER_OPTS
)
@
pytest
.
mark
.
parametrize
(
"mnk_factors"
,
MNK_FACTORS
)
@
pytest
.
mark
.
parametrize
(
"mnk_factors"
,
MNK_FACTORS
)
def
test_marlin_repack
(
k_chunk
,
n_chunk
,
num_bits
,
group_size
,
act_order
,
def
test_
gptq_
marlin_repack
(
k_chunk
,
n_chunk
,
quant_type
,
group_size
,
mnk_factors
):
act_order
,
mnk_factors
):
m_factor
,
n_factor
,
k_factor
=
mnk_factors
m_factor
,
n_factor
,
k_factor
=
mnk_factors
size_m
=
m_factor
size_m
=
m_factor
...
@@ -77,11 +95,11 @@ def test_marlin_repack(k_chunk, n_chunk, num_bits, group_size, act_order,
...
@@ -77,11 +95,11 @@ def test_marlin_repack(k_chunk, n_chunk, num_bits, group_size, act_order,
b_weight
=
rand_data
((
size_k
,
size_n
))
b_weight
=
rand_data
((
size_k
,
size_n
))
# Quantize (and apply act_order if provided)
# Quantize (and apply act_order if provided)
w_ref
,
q_w
,
s
,
g_idx
,
rand_perm
=
quantize_weights
(
b_weight
,
num_bits
,
w_ref
,
q_w
,
s
,
g_idx
,
rand_perm
=
gptq_
quantize_weights
(
group_size
,
act_order
)
b_weight
,
quant_type
,
group_size
,
act_order
)
# Pack to GPTQ format
# Pack to GPTQ format
q_w_gptq
=
gptq_pack
(
q_w
,
num
_bits
,
size_k
,
size_n
)
q_w_gptq
=
gptq_pack
(
q_w
,
quant_type
.
size
_bits
,
size_k
,
size_n
)
# For act_order, sort the "weights" and "g_idx" so that group ids are
# For act_order, sort the "weights" and "g_idx" so that group ids are
# increasing
# increasing
...
@@ -90,8 +108,9 @@ def test_marlin_repack(k_chunk, n_chunk, num_bits, group_size, act_order,
...
@@ -90,8 +108,9 @@ def test_marlin_repack(k_chunk, n_chunk, num_bits, group_size, act_order,
q_w
,
g_idx
,
sort_indices
=
sort_weights
(
q_w
,
g_idx
)
q_w
,
g_idx
,
sort_indices
=
sort_weights
(
q_w
,
g_idx
)
# Pack to Marlin format
# Pack to Marlin format
marlin_q_w_1
=
marlin_weights
(
q_w
,
size_k
,
size_n
,
num_bits
,
weight_perm
=
get_weight_perm
(
quant_type
.
size_bits
)
marlin_perm
[
num_bits
])
marlin_q_w_1
=
marlin_weights
(
q_w
,
size_k
,
size_n
,
quant_type
.
size_bits
,
weight_perm
)
# Run Marlin repack GPU kernel
# Run Marlin repack GPU kernel
marlin_q_w_2
=
ops
.
gptq_marlin_repack
(
marlin_q_w_2
=
ops
.
gptq_marlin_repack
(
...
@@ -99,30 +118,85 @@ def test_marlin_repack(k_chunk, n_chunk, num_bits, group_size, act_order,
...
@@ -99,30 +118,85 @@ def test_marlin_repack(k_chunk, n_chunk, num_bits, group_size, act_order,
sort_indices
,
sort_indices
,
size_k
,
size_k
,
size_n
,
size_n
,
num_bits
,
quant_type
.
size_bits
,
)
torch
.
cuda
.
synchronize
()
assert
torch
.
allclose
(
marlin_q_w_1
,
marlin_q_w_2
)
@
pytest
.
mark
.
skipif
(
not
is_quant_method_supported
(
"gptq_marlin"
),
reason
=
"Marlin is not supported on this GPU type."
)
@
pytest
.
mark
.
parametrize
(
"k_chunk"
,
MARLIN_K_CHUNKS
)
@
pytest
.
mark
.
parametrize
(
"n_chunk"
,
MARLIN_N_CHUNKS
)
@
pytest
.
mark
.
parametrize
(
"quant_type"
,
query_marlin_supported_quant_types
(
False
))
@
pytest
.
mark
.
parametrize
(
"group_size"
,
MARLIN_SUPPORTED_GROUP_SIZES
)
@
pytest
.
mark
.
parametrize
(
"mnk_factors"
,
MNK_FACTORS
)
def
test_awq_marlin_repack
(
k_chunk
,
n_chunk
,
quant_type
,
group_size
,
mnk_factors
):
m_factor
,
n_factor
,
k_factor
=
mnk_factors
size_m
=
m_factor
size_k
=
k_chunk
*
k_factor
size_n
=
n_chunk
*
n_factor
print
(
f
"MNK =
{
size_m
}
{
size_n
}
{
size_k
}
"
)
# Normalize group_size
if
group_size
==
-
1
:
group_size
=
size_k
assert
group_size
<=
size_k
# Create input
b_weight
=
rand_data
((
size_k
,
size_n
))
# Quantize
w_ref
,
q_w
,
s
,
zp
=
quantize_weights
(
b_weight
,
quant_type
,
group_size
,
zero_points
=
True
)
# Pack to AWQ format
q_w_awq
=
awq_pack
(
q_w
,
quant_type
.
size_bits
,
size_k
,
size_n
)
# Pack to Marlin format
weight_perm
=
get_weight_perm
(
quant_type
.
size_bits
)
marlin_q_w_1
=
marlin_weights
(
q_w
,
size_k
,
size_n
,
quant_type
.
size_bits
,
weight_perm
)
# Run Marlin repack GPU kernel
marlin_q_w_2
=
ops
.
awq_marlin_repack
(
q_w_awq
,
size_k
,
size_n
,
quant_type
.
size_bits
,
)
)
torch
.
cuda
.
synchronize
()
torch
.
cuda
.
synchronize
()
assert
torch
.
allclose
(
marlin_q_w_1
,
marlin_q_w_2
)
assert
torch
.
allclose
(
marlin_q_w_1
,
marlin_q_w_2
)
@
pytest
.
mark
.
skipif
(
not
is_
marlin_supported
(
),
@
pytest
.
mark
.
skipif
(
not
is_
quant_method_supported
(
"gptq_marlin"
),
reason
=
"Marlin is not supported on this GPU type."
)
reason
=
"Marlin is not supported on this GPU type."
)
@
pytest
.
mark
.
parametrize
(
"k_chunk"
,
MARLIN_K_CHUNKS
)
@
pytest
.
mark
.
parametrize
(
"k_chunk"
,
MARLIN_K_CHUNKS
)
@
pytest
.
mark
.
parametrize
(
"n_chunk"
,
MARLIN_N_CHUNKS
)
@
pytest
.
mark
.
parametrize
(
"n_chunk"
,
MARLIN_N_CHUNKS
)
@
pytest
.
mark
.
parametrize
(
"num_bits"
,
GPTQ_MARLIN_SUPPORTED_NUM_BITS
)
@
pytest
.
mark
.
parametrize
(
"quant_type"
,
@
pytest
.
mark
.
parametrize
(
"group_size"
,
GPTQ_MARLIN_SUPPORTED_GROUP_SIZES
)
query_marlin_supported_quant_types
(
False
))
@
pytest
.
mark
.
parametrize
(
"group_size"
,
MARLIN_SUPPORTED_GROUP_SIZES
)
@
pytest
.
mark
.
parametrize
(
"mnk_factors"
,
MNK_FACTORS
)
@
pytest
.
mark
.
parametrize
(
"mnk_factors"
,
MNK_FACTORS
)
@
pytest
.
mark
.
parametrize
(
"act_order"
,
ACT_ORDER_OPTS
)
@
pytest
.
mark
.
parametrize
(
"act_order"
,
ACT_ORDER_OPTS
)
@
pytest
.
mark
.
parametrize
(
"is_k_full"
,
K_FULL_OPTS
)
@
pytest
.
mark
.
parametrize
(
"is_k_full"
,
K_FULL_OPTS
)
def
test_marlin_gemm
(
@
pytest
.
mark
.
parametrize
(
"use_fp32_reduce"
,
USE_FP32_REDUCE_OPTS
)
def
test_gptq_marlin_gemm
(
k_chunk
,
k_chunk
,
n_chunk
,
n_chunk
,
num_bits
,
quant_type
,
group_size
,
group_size
,
mnk_factors
,
mnk_factors
,
act_order
,
act_order
,
is_k_full
,
is_k_full
,
use_fp32_reduce
,
):
):
m_factor
,
n_factor
,
k_factor
=
mnk_factors
m_factor
,
n_factor
,
k_factor
=
mnk_factors
...
@@ -143,7 +217,9 @@ def test_marlin_gemm(
...
@@ -143,7 +217,9 @@ def test_marlin_gemm(
b_weight
=
rand_data
((
size_k
,
size_n
))
b_weight
=
rand_data
((
size_k
,
size_n
))
w_ref
,
marlin_q_w
,
marlin_s
,
g_idx
,
sort_indices
,
_
=
marlin_quantize
(
w_ref
,
marlin_q_w
,
marlin_s
,
g_idx
,
sort_indices
,
_
=
marlin_quantize
(
b_weight
,
num_bits
,
group_size
,
act_order
)
b_weight
,
quant_type
,
group_size
,
act_order
)
marlin_zp
=
marlin_make_empty_g_idx
(
marlin_s
.
device
)
workspace
=
MarlinWorkspace
(
size_n
,
GPTQ_MARLIN_MIN_THREAD_N
,
workspace
=
MarlinWorkspace
(
size_n
,
GPTQ_MARLIN_MIN_THREAD_N
,
GPTQ_MARLIN_MAX_PARALLEL
)
GPTQ_MARLIN_MAX_PARALLEL
)
...
@@ -152,14 +228,17 @@ def test_marlin_gemm(
...
@@ -152,14 +228,17 @@ def test_marlin_gemm(
a_input
,
a_input
,
marlin_q_w
,
marlin_q_w
,
marlin_s
,
marlin_s
,
marlin_zp
,
g_idx
,
g_idx
,
sort_indices
,
sort_indices
,
workspace
.
scratch
,
workspace
.
scratch
,
num_bits
,
quant_type
,
a_input
.
shape
[
0
],
a_input
.
shape
[
0
],
b_weight
.
shape
[
1
],
b_weight
.
shape
[
1
],
a_input
.
shape
[
1
],
a_input
.
shape
[
1
],
is_k_full
,
is_k_full
=
is_k_full
,
has_zp
=
False
,
use_fp32_reduce
=
use_fp32_reduce
,
)
)
output_ref
=
torch
.
matmul
(
a_input
,
w_ref
)
output_ref
=
torch
.
matmul
(
a_input
,
w_ref
)
...
@@ -171,14 +250,15 @@ def test_marlin_gemm(
...
@@ -171,14 +250,15 @@ def test_marlin_gemm(
assert
max_diff
<
0.04
assert
max_diff
<
0.04
@
pytest
.
mark
.
skipif
(
not
is_
marlin_supported
(
),
@
pytest
.
mark
.
skipif
(
not
is_
quant_method_supported
(
"gptq_marlin"
),
reason
=
"Marlin is not supported on this GPU type."
)
reason
=
"Marlin is not supported on this GPU type."
)
@
pytest
.
mark
.
parametrize
(
"k_chunk"
,
MARLIN_24_K_CHUNKS
)
@
pytest
.
mark
.
parametrize
(
"k_chunk"
,
MARLIN_24_K_CHUNKS
)
@
pytest
.
mark
.
parametrize
(
"n_chunk"
,
MARLIN_24_N_CHUNKS
)
@
pytest
.
mark
.
parametrize
(
"n_chunk"
,
MARLIN_24_N_CHUNKS
)
@
pytest
.
mark
.
parametrize
(
"
num_bits
"
,
GPTQ_MARLIN_24_SUPPORTED_
NUM_BIT
S
)
@
pytest
.
mark
.
parametrize
(
"
quant_type
"
,
GPTQ_MARLIN_24_SUPPORTED_
QUANT_TYPE
S
)
@
pytest
.
mark
.
parametrize
(
"group_size"
,
GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES
)
@
pytest
.
mark
.
parametrize
(
"group_size"
,
GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES
)
@
pytest
.
mark
.
parametrize
(
"mnk_factors"
,
MNK_FACTORS
)
@
pytest
.
mark
.
parametrize
(
"mnk_factors"
,
MNK_FACTORS
)
def
test_marlin_24_gemm
(
k_chunk
,
n_chunk
,
num_bits
,
group_size
,
mnk_factors
):
def
test_gptq_marlin_24_gemm
(
k_chunk
,
n_chunk
,
quant_type
,
group_size
,
mnk_factors
):
m_factor
,
n_factor
,
k_factor
=
mnk_factors
m_factor
,
n_factor
,
k_factor
=
mnk_factors
size_m
=
m_factor
size_m
=
m_factor
...
@@ -192,7 +272,7 @@ def test_marlin_24_gemm(k_chunk, n_chunk, num_bits, group_size, mnk_factors):
...
@@ -192,7 +272,7 @@ def test_marlin_24_gemm(k_chunk, n_chunk, num_bits, group_size, mnk_factors):
b_weight
=
rand_data
((
size_k
,
size_n
))
b_weight
=
rand_data
((
size_k
,
size_n
))
(
w_24_ref
,
marlin_24_q_w_comp
,
marlin_24_meta
,
(
w_24_ref
,
marlin_24_q_w_comp
,
marlin_24_meta
,
marlin_24_s
)
=
marlin_24_quantize
(
b_weight
,
num_bits
,
group_size
)
marlin_24_s
)
=
marlin_24_quantize
(
b_weight
,
quant_type
,
group_size
)
workspace_24
=
MarlinWorkspace
(
size_n
,
GPTQ_MARLIN_24_MIN_THREAD_N
,
workspace_24
=
MarlinWorkspace
(
size_n
,
GPTQ_MARLIN_24_MIN_THREAD_N
,
GPTQ_MARLIN_24_MAX_PARALLEL
)
GPTQ_MARLIN_24_MAX_PARALLEL
)
...
@@ -205,7 +285,7 @@ def test_marlin_24_gemm(k_chunk, n_chunk, num_bits, group_size, mnk_factors):
...
@@ -205,7 +285,7 @@ def test_marlin_24_gemm(k_chunk, n_chunk, num_bits, group_size, mnk_factors):
marlin_24_meta
,
marlin_24_meta
,
marlin_24_s
,
marlin_24_s
,
workspace_24
.
scratch
,
workspace_24
.
scratch
,
num_bits
,
quant_type
,
a_input
.
shape
[
0
],
a_input
.
shape
[
0
],
b_weight
.
shape
[
1
],
b_weight
.
shape
[
1
],
a_input
.
shape
[
1
],
a_input
.
shape
[
1
],
...
@@ -217,3 +297,204 @@ def test_marlin_24_gemm(k_chunk, n_chunk, num_bits, group_size, mnk_factors):
...
@@ -217,3 +297,204 @@ def test_marlin_24_gemm(k_chunk, n_chunk, num_bits, group_size, mnk_factors):
print
(
"max_diff = {}"
.
format
(
max_diff
))
print
(
"max_diff = {}"
.
format
(
max_diff
))
assert
max_diff
<
0.04
assert
max_diff
<
0.04
@pytest.mark.skipif(not is_quant_method_supported("fp8"),
                    reason="Marlin is not supported on this GPU type.")
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
@pytest.mark.parametrize("num_bits", [8])
@pytest.mark.parametrize("group_size", [-1])
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
@pytest.mark.parametrize("dtype", DTYPES)
def test_fp8_marlin_gemm(
    k_chunk,
    n_chunk,
    num_bits,
    group_size,
    mnk_factors,
    dtype,
):
    """Check the FP8 Marlin GEMM against a dense ``torch.matmul``.

    The weight is quantized to FP8 with a single dynamic scale, packed
    into int32 words, repacked into the Marlin layout and multiplied with
    a random activation. The mean relative error w.r.t. the unquantized
    matmul has to stay below 4%.

    Args:
        k_chunk: Base size of the reduction dimension.
        n_chunk: Base size of the output dimension.
        num_bits: Weight bit width used for repacking and for the GEMM.
        group_size: Scale group size used when permuting the scales.
        mnk_factors: ``(m, n, k)`` multipliers; ``m`` is used directly as
            the number of rows, ``n`` and ``k`` scale the chunk sizes.
        dtype: Activation / weight dtype before quantization.
    """
    m_factor, n_factor, k_factor = mnk_factors

    size_m = m_factor
    size_k = k_chunk * k_factor
    size_n = n_chunk * n_factor

    print(f"MNK = {size_m} {size_n} {size_k}")
    print(f"groupsize = {group_size}")

    a_input = rand_data((size_m, size_k), dtype=dtype)
    b_weight = rand_data((size_k, size_n), dtype=dtype)

    # WEIGHTS
    fp8_weight, weight_scale = ops.scaled_fp8_quant(b_weight, scale=None)
    # Repack weights to gptq format (packed int32 elements)
    packed_gptq_qweight = pack_fp8_to_int32(fp8_weight)
    # Repack weights to marlin format. The parametrized bit width is used
    # here as well, so repacking and the GEMM below cannot drift apart.
    marlin_qweight = ops.gptq_marlin_repack(
        b_q_weight=packed_gptq_qweight,
        perm=torch.empty(0, dtype=torch.int, device="cuda"),
        size_k=size_k,
        size_n=size_n,
        num_bits=num_bits,
    )

    # WEIGHT SCALES
    # Currently Marlin doesn't support per-tensor scales, so we
    # expand it to channelwise
    scales = weight_scale.repeat(1, size_n).to(a_input.dtype).to("cuda")
    # Permute scales
    marlin_scales = marlin_permute_scales(s=scales,
                                          size_k=size_k,
                                          size_n=size_n,
                                          group_size=group_size)

    workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_MIN_THREAD_N,
                                GPTQ_MARLIN_MAX_PARALLEL)

    output = ops.fp8_marlin_gemm(
        a=a_input,
        b_q_weight=marlin_qweight,
        b_scales=marlin_scales,
        workspace=workspace.scratch,
        num_bits=num_bits,
        size_m=a_input.shape[0],
        size_n=b_weight.shape[1],
        size_k=a_input.shape[1],
    )
    output_ref = torch.matmul(a_input, b_weight)

    torch.cuda.synchronize()

    max_diff = compute_max_diff(output, output_ref)
    print(f"max_diff = {max_diff}")

    assert max_diff < 0.04
@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"),
                    reason="Marlin is not supported on this GPU type.")
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
@pytest.mark.parametrize("quant_type",
                         query_marlin_supported_quant_types(True))
@pytest.mark.parametrize("group_size", MARLIN_SUPPORTED_GROUP_SIZES)
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
@pytest.mark.parametrize("use_fp32_reduce", USE_FP32_REDUCE_OPTS)
def test_awq_marlin_gemm(
    k_chunk,
    n_chunk,
    quant_type,
    group_size,
    mnk_factors,
    use_fp32_reduce,
):
    '''
    Run ops.gptq_marlin_gemm on AWQ-quantized weights (zero points enabled)
    and compare it with torch.matmul on the reference weight returned by
    awq_marlin_quantize.
    '''
    m_factor, n_factor, k_factor = mnk_factors
    size_m, size_k, size_n = m_factor, k_chunk * k_factor, n_chunk * n_factor

    print(f"MNK = {size_m} {size_n} {size_k}")
    print(f"groupsize = {group_size}")

    a_input = rand_data((size_m, size_k))
    b_weight = rand_data((size_k, size_n))

    quantized = awq_marlin_quantize(b_weight, quant_type, group_size)
    w_ref, marlin_q_w, marlin_s, marlin_zp = quantized

    # No activation reordering in this test: both index tensors are empty
    empty_idx_kwargs = dict(dtype=torch.int, device=marlin_q_w.device)
    g_idx = torch.empty(0, **empty_idx_kwargs)
    sort_indices = torch.empty(0, **empty_idx_kwargs)

    workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_MIN_THREAD_N,
                                GPTQ_MARLIN_MAX_PARALLEL)

    problem_m = a_input.shape[0]
    problem_n = b_weight.shape[1]
    problem_k = a_input.shape[1]
    output = ops.gptq_marlin_gemm(
        a_input,
        marlin_q_w,
        marlin_s,
        marlin_zp,
        g_idx,
        sort_indices,
        workspace.scratch,
        quant_type,
        problem_m,
        problem_n,
        problem_k,
        is_k_full=True,
        has_zp=True,
        use_fp32_reduce=use_fp32_reduce,
    )
    output_ref = torch.matmul(a_input, w_ref)

    torch.cuda.synchronize()

    max_diff = compute_max_diff(output, output_ref)
    print("max_diff = {}".format(max_diff))

    assert max_diff < 0.04
@pytest.mark.skipif(not is_quant_method_supported("qqq"),
                    reason="Marlin is not supported on this GPU type.")
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
@pytest.mark.parametrize("num_bits", MARLIN_QQQ_SUPPORTED_NUM_BITS)
@pytest.mark.parametrize("group_size", MARLIN_QQQ_SUPPORTED_GROUP_SIZES)
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
def test_marlin_qqq_gemm(
    k_chunk,
    n_chunk,
    num_bits,
    group_size,
    mnk_factors,
):
    '''
    Run ops.marlin_qqq_gemm on int8-quantized activations and QQQ-quantized
    weights, and compare it with torch.matmul of the dequantized activations
    against the reference weight returned by marlin_qqq_quantize.
    '''
    int8_traits = torch.iinfo(torch.int8)
    m_factor, n_factor, k_factor = mnk_factors
    size_m, size_k, size_n = m_factor, k_chunk * k_factor, n_chunk * n_factor

    print(f"MNK = {size_m} {size_n} {size_k}")
    print(f"groupsize = {group_size}")

    a_input = rand_data((size_m, size_k))
    b_weight = rand_data((size_k, size_n))

    # Quantize activations: one scale per row, taken from the row's abs-max
    row_abs_max = a_input.abs().max(dim=-1, keepdim=True)[0]
    s_a = row_abs_max.div(int8_traits.max).to(torch.float)
    scaled_rounded = (a_input / s_a).round()
    q_a = scaled_rounded.clamp(int8_traits.min,
                               int8_traits.max).to(torch.int8)

    # Quantize weights
    quantized = marlin_qqq_quantize(b_weight, num_bits, group_size)
    w_ref, marlin_qqq_q_w, marlin_qqq_s_group, marlin_qqq_s_channel = \
        quantized

    workspace = MarlinWorkspace(size_n, MARLIN_QQQ_MIN_THREAD_N,
                                MARLIN_QQQ_MAX_PARALLEL)

    problem_m = a_input.shape[0]
    problem_n = b_weight.shape[1]
    problem_k = a_input.shape[1]
    output = ops.marlin_qqq_gemm(
        q_a,
        marlin_qqq_q_w,
        s_a,
        marlin_qqq_s_channel,
        marlin_qqq_s_group,
        workspace.scratch,
        problem_m,
        problem_n,
        problem_k,
    )
    # Reference uses the dequantized activations (q_a * s_a) in half precision
    dequant_a = q_a.half() * s_a.half()
    output_ref = torch.matmul(dequant_a, w_ref)

    torch.cuda.synchronize()

    max_diff = compute_max_diff(output, output_ref)
    print("max_diff = {}".format(max_diff))

    assert max_diff < 0.04
\ No newline at end of file
tests/kernels/test_moe.py
View file @
e7c1b7f3
...
@@ -29,7 +29,7 @@ def torch_moe(a, w1, w2, score, topk):
...
@@ -29,7 +29,7 @@ def torch_moe(a, w1, w2, score, topk):
topk_weight
.
view
(
B
,
-
1
,
1
).
to
(
out
.
dtype
)).
sum
(
dim
=
1
)
topk_weight
.
view
(
B
,
-
1
,
1
).
to
(
out
.
dtype
)).
sum
(
dim
=
1
)
@
pytest
.
mark
.
parametrize
(
"m"
,
[
512
,
222
,
33
,
1
])
@
pytest
.
mark
.
parametrize
(
"m"
,
[
1024
*
128
,
512
,
222
,
33
,
1
])
@
pytest
.
mark
.
parametrize
(
"n"
,
[
2048
,
256
,
1024
])
@
pytest
.
mark
.
parametrize
(
"n"
,
[
2048
,
256
,
1024
])
@
pytest
.
mark
.
parametrize
(
"k"
,
[
128
,
511
,
1024
])
@
pytest
.
mark
.
parametrize
(
"k"
,
[
128
,
511
,
1024
])
@
pytest
.
mark
.
parametrize
(
"e"
,
[
8
,
64
])
@
pytest
.
mark
.
parametrize
(
"e"
,
[
8
,
64
])
...
@@ -77,8 +77,8 @@ def test_mixtral_moe(dtype: torch.dtype):
...
@@ -77,8 +77,8 @@ def test_mixtral_moe(dtype: torch.dtype):
for
i
in
range
(
config
.
num_local_experts
):
for
i
in
range
(
config
.
num_local_experts
):
weights
=
(
hf_moe
.
experts
[
i
].
w1
.
weight
.
data
,
weights
=
(
hf_moe
.
experts
[
i
].
w1
.
weight
.
data
,
hf_moe
.
experts
[
i
].
w3
.
weight
.
data
)
hf_moe
.
experts
[
i
].
w3
.
weight
.
data
)
vllm_moe
.
w13_weight
[
i
][:]
=
torch
.
cat
(
weights
,
dim
=
0
)
vllm_moe
.
experts
.
w13_weight
[
i
][:]
=
torch
.
cat
(
weights
,
dim
=
0
)
vllm_moe
.
w2_weight
[
i
][:]
=
hf_moe
.
experts
[
i
].
w2
.
weight
.
data
vllm_moe
.
experts
.
w2_weight
[
i
][:]
=
hf_moe
.
experts
[
i
].
w2
.
weight
.
data
# Generate input batch of dimensions [batch_size, seq_len, hidden_dim]
# Generate input batch of dimensions [batch_size, seq_len, hidden_dim]
hf_inputs
=
torch
.
randn
((
1
,
64
,
config
.
hidden_size
)).
to
(
dtype
).
to
(
"cuda"
)
hf_inputs
=
torch
.
randn
((
1
,
64
,
config
.
hidden_size
)).
to
(
dtype
).
to
(
"cuda"
)
...
...
tests/kernels/test_pos_encoding.py
View file @
e7c1b7f3
from
itertools
import
accumulate
,
product
from
itertools
import
accumulate
,
product
from
typing
import
List
,
Optional
from
typing
import
Dict
,
List
,
Optional
import
pytest
import
pytest
import
torch
import
torch
...
@@ -10,7 +10,7 @@ from .allclose_default import get_default_atol, get_default_rtol
...
@@ -10,7 +10,7 @@ from .allclose_default import get_default_atol, get_default_rtol
IS_NEOX_STYLE
=
[
True
,
False
]
IS_NEOX_STYLE
=
[
True
,
False
]
DTYPES
=
[
torch
.
half
,
torch
.
bfloat16
,
torch
.
float
]
DTYPES
=
[
torch
.
half
,
torch
.
bfloat16
,
torch
.
float
]
HEAD_SIZES
=
[
64
,
80
,
96
,
112
,
128
,
192
,
256
]
HEAD_SIZES
=
[
64
,
80
,
96
,
112
,
120
,
128
,
192
,
256
]
ROTARY_DIMS
=
[
None
,
32
]
# None means rotary dim == head size
ROTARY_DIMS
=
[
None
,
32
]
# None means rotary dim == head size
NUM_HEADS
=
[
7
,
17
]
# Arbitrary values for testing
NUM_HEADS
=
[
7
,
17
]
# Arbitrary values for testing
BATCH_SIZES
=
[
1
,
5
]
# Arbitrary values for testing
BATCH_SIZES
=
[
1
,
5
]
# Arbitrary values for testing
...
@@ -126,7 +126,7 @@ def test_batched_rotary_embedding(
...
@@ -126,7 +126,7 @@ def test_batched_rotary_embedding(
query
,
query
,
key
,
key
,
offsets
=
torch
.
zeros
(
batch_size
*
seq_len
,
offsets
=
torch
.
zeros
(
batch_size
*
seq_len
,
dtype
=
int
,
dtype
=
torch
.
long
,
device
=
device
))
device
=
device
))
# Compare the results.
# Compare the results.
assert
torch
.
allclose
(
out_query
,
assert
torch
.
allclose
(
out_query
,
...
@@ -214,20 +214,16 @@ def test_batched_rotary_embedding_multi_lora(
...
@@ -214,20 +214,16 @@ def test_batched_rotary_embedding_multi_lora(
def
test_rope_module_cache
():
def
test_rope_module_cache
():
MAX_POSITIONS
=
[
123
,
1234
]
MAX_POSITIONS
=
[
123
,
1234
]
BASES
=
[
10000
,
1000000
]
BASES
=
[
10000
,
1000000
]
ROPE_SCALINGS
=
[
ROPE_SCALINGS
=
(
None
,
{
None
,
{
"type"
:
"linear"
,
"type"
:
"linear"
,
"factor"
:
(
1
,
)
"factor"
:
(
1
,
)
},
{
},
{
"type"
:
"dynamic"
,
"type"
:
"dynamic"
,
"factor"
:
1
"factor"
:
1
})
}
settings
=
(
HEAD_SIZES
,
ROTARY_DIMS
,
MAX_POSITIONS
,
BASES
,
IS_NEOX_STYLE
,
]
ROPE_SCALINGS
,
DTYPES
)
settings
=
[
rope_setting_id_map
:
Dict
[
str
,
int
]
=
{}
HEAD_SIZES
,
ROTARY_DIMS
,
MAX_POSITIONS
,
BASES
,
IS_NEOX_STYLE
,
ROPE_SCALINGS
,
DTYPES
]
rope_setting_id_map
=
{}
for
setting
in
product
(
*
settings
):
for
setting
in
product
(
*
settings
):
head_size
,
rotary_dim
,
max_position
,
base
,
\
head_size
,
rotary_dim
,
max_position
,
base
,
\
is_neox_stype
,
rope_scaling
,
dtype
=
setting
is_neox_stype
,
rope_scaling
,
dtype
=
setting
...
...
tests/kernels/test_sampler.py
View file @
e7c1b7f3
import
gc
import
gc
from
unittest.mock
import
patch
import
pytest
import
pytest
import
torch
import
torch
import
triton
import
triton
import
triton.language
as
tl
import
triton.language
as
tl
from
vllm.model_executor.layers.ops.sample
import
(
from
vllm.model_executor.layers.ops.sample
import
(
_sample_triton
,
MAX_TRITON_N_COLS
,
_uniform_to_exponential
,
get_num_triton_sampler_splits
,
_uniform_to_exponential
,
sample
)
sample
)
from
vllm.model_executor.sampling_metadata
import
SamplingTensors
from
vllm.model_executor.sampling_metadata
import
SamplingTensors
from
vllm.model_executor.utils
import
set_random_seed
from
vllm.model_executor.utils
import
set_random_seed
from
vllm.triton_utils.libentry
import
LibEntry
from
vllm.triton_utils.sample
import
(
MAX_TRITON_N_COLS
,
get_num_triton_sampler_splits
)
SINGLE_SPLIT_VOCAB_SIZE
=
32000
# llama/mistral/mixtral vocab size
SINGLE_SPLIT_VOCAB_SIZE
=
32000
# llama/mistral/mixtral vocab size
MULTI_SPLIT_VOCAB_SIZE
=
MAX_TRITON_N_COLS
+
100
MULTI_SPLIT_VOCAB_SIZE
=
MAX_TRITON_N_COLS
+
100
...
@@ -75,15 +79,20 @@ def test_sample_decoding_only(random_sampling, max_best_of,
...
@@ -75,15 +79,20 @@ def test_sample_decoding_only(random_sampling, max_best_of,
seeds
=
torch
.
randint
(
1
,
seeds
=
torch
.
randint
(
1
,
torch
.
iinfo
(
torch
.
long
).
max
,
(
n_splits
,
bs
),
torch
.
iinfo
(
torch
.
long
).
max
,
(
n_splits
,
bs
),
device
=
"cuda"
).
mul_
(
random_sampling_mask
)
device
=
"cuda"
).
mul_
(
random_sampling_mask
)
sampled_tokens
,
sampled_logprobs
,
sampled_modified_probs
=
sample
(
#The current _sample_triton does not utilize the
probs
=
probs
,
# libentry decoration. The purpose of adding this patch is to test
logprobs
=
logprobs
,
# the correctness of libentry.
sample_indices
=
sample_indices
,
with
patch
(
"vllm.model_executor.layers.ops.sample._sample_triton"
,
seeds
=
seeds
,
LibEntry
(
_sample_triton
)):
max_best_of
=
max_best_of
,
sampled_tokens
,
sampled_logprobs
,
sampled_modified_probs
=
sample
(
modify_greedy_probs
=
modify_greedy_probs
,
probs
=
probs
,
save_logprobs
=
save_logprobs
,
logprobs
=
logprobs
,
_save_modified_probs
=
True
)
sample_indices
=
sample_indices
,
seeds
=
seeds
,
max_best_of
=
max_best_of
,
modify_greedy_probs
=
modify_greedy_probs
,
save_logprobs
=
save_logprobs
,
_save_modified_probs
=
True
)
assert
sampled_tokens
.
shape
==
(
bs
,
max_best_of
)
assert
sampled_tokens
.
shape
==
(
bs
,
max_best_of
)
for
i
in
range
(
bs
):
for
i
in
range
(
bs
):
assert
torch
.
all
(
sampled_tokens
[
i
]
==
i
*
(
vocab_size
//
bs
))
assert
torch
.
all
(
sampled_tokens
[
i
]
==
i
*
(
vocab_size
//
bs
))
...
@@ -129,6 +138,7 @@ def test_sample_decoding_only(random_sampling, max_best_of,
...
@@ -129,6 +138,7 @@ def test_sample_decoding_only(random_sampling, max_best_of,
[
SINGLE_SPLIT_VOCAB_SIZE
,
MULTI_SPLIT_VOCAB_SIZE
])
[
SINGLE_SPLIT_VOCAB_SIZE
,
MULTI_SPLIT_VOCAB_SIZE
])
def
test_sample_prompt_logprobs
(
random_sampling
,
max_best_of
,
def
test_sample_prompt_logprobs
(
random_sampling
,
max_best_of
,
modify_greedy_probs
,
seed
,
vocab_size
):
modify_greedy_probs
,
seed
,
vocab_size
):
set_random_seed
(
seed
)
set_random_seed
(
seed
)
prompt_sizes
=
[
16
,
32
,
64
,
128
]
*
2
prompt_sizes
=
[
16
,
32
,
64
,
128
]
*
2
samples
=
8
samples
=
8
...
@@ -156,14 +166,17 @@ def test_sample_prompt_logprobs(random_sampling, max_best_of,
...
@@ -156,14 +166,17 @@ def test_sample_prompt_logprobs(random_sampling, max_best_of,
seeds
=
torch
.
randint
(
1
,
seeds
=
torch
.
randint
(
1
,
torch
.
iinfo
(
torch
.
long
).
max
,
(
n_splits
,
samples
),
torch
.
iinfo
(
torch
.
long
).
max
,
(
n_splits
,
samples
),
device
=
"cuda"
).
mul_
(
random_sampling_mask
)
device
=
"cuda"
).
mul_
(
random_sampling_mask
)
sampled_tokens
,
sampled_logprobs
,
_
=
sample
(
#ditto
probs
=
probs
,
with
patch
(
"vllm.model_executor.layers.ops.sample._sample_triton"
,
logprobs
=
logprobs
,
LibEntry
(
_sample_triton
)):
sample_indices
=
sample_indices
,
sampled_tokens
,
sampled_logprobs
,
_
=
sample
(
seeds
=
seeds
,
probs
=
probs
,
max_best_of
=
max_best_of
,
logprobs
=
logprobs
,
modify_greedy_probs
=
modify_greedy_probs
,
sample_indices
=
sample_indices
,
save_logprobs
=
True
)
seeds
=
seeds
,
max_best_of
=
max_best_of
,
modify_greedy_probs
=
modify_greedy_probs
,
save_logprobs
=
True
)
assert
sampled_tokens
.
shape
==
(
samples
,
max_best_of
)
assert
sampled_tokens
.
shape
==
(
samples
,
max_best_of
)
assert
sampled_logprobs
.
shape
==
(
samples
,
max_best_of
)
assert
sampled_logprobs
.
shape
==
(
samples
,
max_best_of
)
for
i
,
t
in
enumerate
(
sample_indices
):
for
i
,
t
in
enumerate
(
sample_indices
):
...
...
tests/kernels/utils.py
View file @
e7c1b7f3
"""Kernel test utils"""
"""Kernel test utils"""
import
itertools
import
random
from
numbers
import
Number
from
typing
import
Any
,
List
,
NamedTuple
,
Optional
,
Tuple
,
Union
import
pytest
import
pytest
import
torch
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionMetadata
,
AttentionType
)
from
vllm.attention.backends.xformers
import
XFormersBackend
from
vllm.utils
import
make_tensor_with_pad
# String name of register which may be set in order to
# force auto-selection of attention backend by Attention
# wrapper
STR_BACKEND_ENV_VAR
:
str
=
"VLLM_ATTENTION_BACKEND"
STR_BACKEND_ENV_VAR
:
str
=
"VLLM_ATTENTION_BACKEND"
# Possible string values of STR_BACKEND_ENV_VAR
# register, corresponding to possible backends
STR_FLASHINFER_ATTN_VAL
:
str
=
"FLASHINFER"
STR_TORCH_SDPA_ATTN_VAL
:
str
=
"TORCH_SDPA"
STR_ROCM_FLASH_ATTN_VAL
:
str
=
"ROCM_FLASH"
STR_XFORMERS_ATTN_VAL
:
str
=
"XFORMERS"
STR_FLASH_ATTN_VAL
:
str
=
"FLASH_ATTN"
STR_FLASH_ATTN_VAL
:
str
=
"FLASH_ATTN"
STR_INVALID_VAL
:
str
=
"INVALID"
STR_INVALID_VAL
:
str
=
"INVALID"
class QKVInputs(NamedTuple):
    '''
    Unpacked attention inputs: the query/key/value tensors together with
    their per-sequence lengths.

    Attributes:

        * {query,key,value}: unpacked (batch_size x padded_seq_len x
                             num_heads x head_size) attention inputs
        * q_seq_lens: list of query sequence lengths
        * kv_seq_lens: list of sequence lengths shared by key and value
    '''

    query: torch.Tensor
    key: torch.Tensor
    value: torch.Tensor
    q_seq_lens: List[int]
    kv_seq_lens: List[int]
class QKVO(NamedTuple):
    '''
    Unpacked attention inputs paired with the unpacked known-correct
    attention output.

    Attributes:

        * qkv: unpacked (batch_size x padded_seq_len x
               num_heads x head_size) attention inputs
        * ideal_output: unpacked (batch_size x padded_seq_len x
                        num_heads x head_size) known-correct attention output
    '''

    qkv: QKVInputs
    ideal_output: torch.Tensor
class PackedQKVInputs(NamedTuple):
    '''
    Packed attention inputs plus the bookkeeping lists that locate each
    sequence inside the packed tensors.

    Attributes:

        * {query,key,value}: packed (number_of_tokens x num_heads
                             x head_size) attention inputs
        * q_start_loc_list: start locations of each query sequence within
                            the packed tensor
        * kv_start_loc_list: start locations within the packed tensor,
                             shared by key and value
        * q_seq_lens: list of query sequence lengths
        * kv_seq_lens: list of sequence lengths shared by key and value
    '''

    query: torch.Tensor
    key: torch.Tensor
    value: torch.Tensor
    q_start_loc_list: Optional[List[int]]
    kv_start_loc_list: Optional[List[int]]
    q_seq_lens: Optional[List[int]]
    kv_seq_lens: Optional[List[int]]
class PackedQKVO(NamedTuple):
    '''
    Packed attention inputs paired with the packed known-correct
    attention output.

    Attributes:

        * packed_qkv: packed (number_of_tokens x num_heads
                      x head_size) attention inputs; may be None
        * ideal_output: packed (number_of_tokens x num_heads
                        x head_size) known-correct attention output
    '''

    packed_qkv: Optional[PackedQKVInputs]
    ideal_output: torch.Tensor
class KVMemoryMap(NamedTuple):
    '''
    Container for the KV cache memory mapping.

    Attributes:

        * block_tables: KV cache block tables
        * slot_mapping: mapping of sequence offset to physical address
    '''

    block_tables: torch.Tensor
    slot_mapping: torch.Tensor
class PhaseTestParameters(NamedTuple):
    '''
    Test parameters for one test "phase" (prefill or decode) and one
    attention scenario (encoder, decoder-self, encoder/decoder-cross).

    Attributes:

        * packed_qkvo: packed (number_of_tokens x num_heads
                       x head_size) attention inputs & known-correct
                       output
        * kv_mmap: KV cache memory mapping, specific to this test phase &
                   attention scenario; may be None
    '''

    packed_qkvo: PackedQKVO
    kv_mmap: Optional[KVMemoryMap]
def maybe_make_int_tensor(
    _list: Optional[List[int]],
    device: Union[torch.device, str],
) -> Optional[torch.Tensor]:
    '''
    Convert Python int list to a 1D int torch.Tensor on `device`

    Arguments:

    * _list: list of ints to convert, or None
    * device: CPU or CUDA device

    Returns:

    * If _list is not None: 1D int torch.Tensor on `device`
    * None otherwise
    '''
    if _list is None:
        return None
    return torch.tensor(_list, dtype=torch.int, device=device)
def maybe_make_long_tensor(
    _list: Optional[List[int]],
    device: Union[torch.device, str],
) -> Optional[torch.Tensor]:
    '''
    Convert Python int list to a 1D long torch.Tensor on `device`

    Arguments:

    * _list: list of ints to convert, or None
    * device: CPU or CUDA device

    Returns:

    * If _list is not None: 1D long torch.Tensor on `device`
    * None otherwise
    '''
    if _list is None:
        return None
    return torch.tensor(_list, dtype=torch.long, device=device)
def maybe_max(_list: Optional[List]) -> Optional[Number]:
    '''
    Largest element of an optional list.

    Returns:

    * If _list is not None: max(_list)
    * None otherwise
    '''
    if _list is not None:
        return max(_list)
    return None
def make_causal_mask(
    q_max_seq_len: int,
    kv_max_seq_len: int,
) -> torch.Tensor:
    '''
    Create a q_max_seq_len x kv_max_seq_len additive causal mask.

    Entry (i, j) is float('-inf') where j > i (strictly above the main
    diagonal) and 0 everywhere else.

    Arguments:

    * q_max_seq_len: query max seq len
    * kv_max_seq_len: key/value max seq len

    Returns:

    * 2D tensor, q_max_seq_len x kv_max_seq_len
    '''

    shape = (q_max_seq_len, kv_max_seq_len)
    # Boolean matrix that is True strictly above the main diagonal (j > i)
    above_diagonal = torch.ones(shape, dtype=torch.bool).triu(diagonal=1)
    # Start from zeros and put -inf on the masked-out positions
    return torch.zeros(shape).masked_fill(above_diagonal, float('-inf'))
def
override_backend_env_variable
(
mpatch
:
pytest
.
MonkeyPatch
,
def
override_backend_env_variable
(
mpatch
:
pytest
.
MonkeyPatch
,
backend_name
:
str
)
->
None
:
backend_name
:
str
)
->
None
:
'''
'''
...
@@ -20,3 +219,724 @@ def override_backend_env_variable(mpatch: pytest.MonkeyPatch,
...
@@ -20,3 +219,724 @@ def override_backend_env_variable(mpatch: pytest.MonkeyPatch,
* backend_name: attention backend name to force
* backend_name: attention backend name to force
'''
'''
mpatch
.
setenv
(
STR_BACKEND_ENV_VAR
,
backend_name
)
mpatch
.
setenv
(
STR_BACKEND_ENV_VAR
,
backend_name
)
def ref_masked_attention(query: torch.Tensor,
                         key: torch.Tensor,
                         value: torch.Tensor,
                         scale: float,
                         custom_mask: Optional[torch.Tensor] = None,
                         q_seq_lens: Optional[List] = None,
                         kv_seq_lens: Optional[List] = None) -> torch.Tensor:
    '''
    "Golden" masked attention reference. Two kinds of masking are applied:

    * Padding mask, built from {q,kv}_seq_lens so that padding elements
      are masked out
    * Optional custom mask, an arbitrary additive mask tensor (e.g. causal)

    Arguments:

    * query: batch_size x q_padded_seq_len x num_heads x head_size
    * key: batch_size x kv_padded_seq_len x num_heads x head_size
    * value: batch_size x kv_padded_seq_len x num_heads x head_size
    * scale: Attention scale factor
    * custom_mask: custom attention mask; good place to inject a causal
      attention mask
    * q_seq_lens: list of unpadded query seq_lens for each batch index
    * kv_seq_lens: list of unpadded key/value seq_lens for each batch index

    Returns:

    * Attention result, batch_size x q_padded_seq_len x num_heads x head_size
    '''

    assert q_seq_lens is not None
    assert kv_seq_lens is not None

    num_seqs = query.shape[0]
    assert (len(q_seq_lens) == num_seqs)
    assert (len(kv_seq_lens) == num_seqs)

    # Scaled dot-product scores, laid out batch x heads x q_len x kv_len
    scores = scale * torch.einsum("bqhd,bkhd->bhqk", query, key).float()

    # Padding mask derived from the sequence lengths
    have_q_lens = q_seq_lens is not None
    have_kv_lens = kv_seq_lens is not None
    if have_q_lens or have_kv_lens:
        padding_mask = torch.zeros_like(scores)
        if have_q_lens:
            # Mask every query row past the sequence's real length
            for seq_idx, valid_len in enumerate(q_seq_lens):
                padding_mask[seq_idx, :, valid_len:, :] = -torch.inf
        if have_kv_lens:
            # Mask every key column past the sequence's real length
            for seq_idx, valid_len in enumerate(kv_seq_lens):
                padding_mask[seq_idx, :, :, valid_len:] = -torch.inf

        scores = scores + padding_mask.float()

    # Custom attention mask
    if custom_mask is not None:
        scores = scores + custom_mask.float()

    probs = torch.softmax(scores, dim=-1).to(value.dtype)
    return torch.einsum("bhqk,bkhd->bqhd", probs, value)
def make_qkv(
    batch_size: int,
    max_q_seq_len: int,
    max_kv_seq_len: Optional[int],
    num_heads: int,
    head_size: int,
    device: Union[torch.device, str],
    force_kv_seq_lens: Optional[List[int]] = None,
    attn_type: AttentionType = AttentionType.ENCODER_DECODER,
    force_max_len: bool = False,
) -> Tuple[QKVInputs, QKVInputs, QKVInputs]:
    '''
    Construct QKV test tensors for self- and cross-attention.

    Generates three query/key/value triplets:

    * "Baseline" query/key/value (for input to reference attention function)
    * "Prefill" query/key/value (last sequence offset zero'd out, for use as
      input to prefill kernel)
    * "Decode" query/key/value (only the last sequence offset from baseline,
      for use as input to decode kernel)

    Each Q/K/V triplet is associated with a list of q seqlens and a list of k/v
    seqlens

    Arguments:

    * batch_size
    * max_q_seq_len: max query seq len; also the padded query length
    * max_kv_seq_len: max key/value seq len; also the padded key/value
      length. If None, the padded length falls back to the longest
      key/value sequence length in use. Must not be None when key/value
      seqlens are drawn at random for encoder/decoder cross-attention
    * num_heads
    * head_size
    * device: CPU or CUDA device
    * force_kv_seq_lens: if not None, overrides kv sequence lengths
    * attn_type: encoder, decoder self, or enc/dec cross attention. For
      anything other than ENCODER_DECODER the key/value seqlens equal the
      query seqlens at each batch index (unless force_kv_seq_lens is given)
    * force_max_len: if True, all query seqlens are max_q_seq_len; o/w query
      seqlens are random in [2,max_q_seq_len]. Same for key/value seqlens
      and max_kv_seq_len in the cross-attention case

    Returns:

    * Overall QKVInputs structure (containing full unpacked Q/K/V tensors)
    * Prefill QKVInputs structure (containing all but the last sequence offset)
    * Decode QKVInputs structure (containing only the last sequence offset)
    '''

    if force_max_len:
        q_seq_lens = [max_q_seq_len for _ in range(batch_size)]
    else:
        q_seq_lens = [
            random.randint(2, max_q_seq_len) for _ in range(batch_size)
        ]

    if force_kv_seq_lens is not None:
        kv_seq_lens = force_kv_seq_lens
    elif attn_type != AttentionType.ENCODER_DECODER:
        # K,V seq lens match Q for self-attention
        kv_seq_lens = q_seq_lens
    else:
        # K,V seq lens are distinct from Q seq lens & random
        assert max_kv_seq_len is not None
        if force_max_len:
            kv_seq_lens = [max_kv_seq_len] * batch_size
        else:
            kv_seq_lens = [
                random.randint(2, max_kv_seq_len) for _ in range(batch_size)
            ]

    if max_kv_seq_len is None:
        # No padded K/V length was supplied; previously this crashed when
        # building the tensors below. Use the longest sequence instead.
        max_kv_seq_len = max(kv_seq_lens, default=0)

    q_shape = (batch_size, max_q_seq_len, num_heads, head_size)
    kv_shape = (batch_size, max_kv_seq_len, num_heads, head_size)
    decode_shape = (batch_size, 1, num_heads, head_size)

    # Random baseline tensors are drawn exactly as before (generated first,
    # then moved to `device`), in query/key/value order
    query = torch.rand(q_shape).to(device)
    key = torch.rand(kv_shape).to(device)
    value = torch.rand(kv_shape).to(device)

    prefill_query = torch.zeros(q_shape, device=device)
    prefill_key = torch.zeros(kv_shape, device=device)
    prefill_value = torch.zeros(kv_shape, device=device)

    decode_query = torch.zeros(decode_shape, device=device)
    decode_key = torch.zeros(decode_shape, device=device)
    decode_value = torch.zeros(decode_shape, device=device)

    for bdx, (q_seq_len, kv_seq_len) in enumerate(zip(q_seq_lens,
                                                      kv_seq_lens)):
        # Zero the padding region of the baseline tensors
        query[bdx, q_seq_len:, :, :] = 0
        key[bdx, kv_seq_len:, :, :] = 0
        value[bdx, kv_seq_len:, :, :] = 0

        # Prefill holds everything except the last sequence offset
        prefill_query[bdx,
                      0:(q_seq_len - 1), :, :] = query[bdx,
                                                       0:(q_seq_len - 1), :, :]
        prefill_key[bdx,
                    0:(kv_seq_len - 1), :, :] = key[bdx,
                                                    0:(kv_seq_len - 1), :, :]
        prefill_value[bdx, 0:(kv_seq_len -
                              1), :, :] = value[bdx, 0:(kv_seq_len - 1), :, :]

        # Decode holds only the last sequence offset
        decode_query[bdx, :, :, :] = query[bdx,
                                           (q_seq_len - 1):q_seq_len, :, :]
        decode_key[bdx, :, :, :] = key[bdx,
                                       (kv_seq_len - 1):kv_seq_len, :, :]
        decode_value[bdx, :, :, :] = value[bdx,
                                           (kv_seq_len - 1):kv_seq_len, :, :]

    prefill_q_seq_lens = [plen - 1 for plen in q_seq_lens]
    prefill_kv_seq_lens = [plen - 1 for plen in kv_seq_lens]

    decode_q_seq_lens = [1 for _ in q_seq_lens]
    decode_kv_seq_lens = [1 for _ in kv_seq_lens]

    return (
        QKVInputs(
            query,  # Overall QKV inputs
            key,
            value,
            q_seq_lens,
            kv_seq_lens),
        QKVInputs(
            prefill_query,  # Prefill subset of QKV sequences
            prefill_key,
            prefill_value,
            prefill_q_seq_lens,
            prefill_kv_seq_lens),
        QKVInputs(
            decode_query,  # Decode subset of KV sequences
            decode_key,
            decode_value,
            decode_q_seq_lens,
            decode_kv_seq_lens))
def pack_tensor(
        unpacked_tensor: torch.Tensor, seq_lens: List[int],
        device: Union[torch.device, str]) -> Tuple[torch.Tensor, List[int]]:
    '''
    Pack a batch_size x padded_seq_len x num_heads x head_size tensor into an
    unpadded number_of_tokens x num_heads x head_size tensor, where
    number_of_tokens = sum(seq_lens)

    The packed tensor keeps the dtype of `unpacked_tensor`.

    Arguments:

    * unpacked_tensor: batch_size x padded_seq_len x num_heads x head_size
    * seq_lens: list of token counts for each seq
    * device: CPU or CUDA device

    Returns

    * packed_tensor: number_of_tokens x num_heads x head_size
    * start_loc_list: start idx of each batch elt in packed_tensor; [0] +
      list(itertools.accumulate(seq_lens))
    '''

    num_tok = sum(seq_lens)
    num_heads = unpacked_tensor.shape[-2]
    head_size = unpacked_tensor.shape[-1]
    start_loc_list = [0] + list(itertools.accumulate(seq_lens))
    # Allocate with the source dtype so packing never silently up- or
    # down-casts the data
    packed_tensor = torch.zeros((num_tok, num_heads, head_size),
                                dtype=unpacked_tensor.dtype,
                                device=device)

    # zip() stops at the shorter seq_lens, so the trailing total in
    # start_loc_list is never used as a start location
    for bdx, (seq_len, start_loc) in enumerate(zip(seq_lens, start_loc_list)):

        packed_tensor[start_loc:(
            start_loc + seq_len), :, :] = unpacked_tensor[bdx, :seq_len, :, :]

    return packed_tensor, start_loc_list
def pack_qkv(qkv: QKVInputs, device: Union[torch.device,
                                           str]) -> PackedQKVInputs:
    '''
    Pack Q, K and V one by one, each going from batch_size x
    padded_seq_len x num_heads x head_size to number_of_tokens x
    num_heads x head_size.

    For Q, number_of_tokens = sum(q_seq_lens).

    For K and V, number_of_tokens = sum(kv_seq_lens)

    A None query is passed through: the packed query, its start locations
    and its sequence lengths are then all None.

    Arguments:

    * qkv: Unpacked (batch_size x padded_seq_len x num_heads x head_size)
           attention inputs
    * device: CPU or CUDA device

    Returns

    * Packed (number_of_tokens x num_heads x head_size) QKV inputs
      derived from unpacked inputs
    '''

    packed_query, q_start_loc_list = None, None
    if qkv.query is not None:
        packed_query, q_start_loc_list = pack_tensor(qkv.query,
                                                     qkv.q_seq_lens,
                                                     device=device)

    # Key and value share kv_seq_lens, so one start-location list serves both
    packed_key, kv_start_loc_list = pack_tensor(qkv.key,
                                                qkv.kv_seq_lens,
                                                device=device)
    packed_value, _ = pack_tensor(qkv.value, qkv.kv_seq_lens, device=device)

    if q_start_loc_list is None:
        packed_q_seq_lens = None
    else:
        packed_q_seq_lens = qkv.q_seq_lens

    return PackedQKVInputs(packed_query, packed_key, packed_value,
                           q_start_loc_list, kv_start_loc_list,
                           packed_q_seq_lens, qkv.kv_seq_lens)
def make_backend(backend_name: str) -> AttentionBackend:
    '''
    Build the backend instance selected by the backend_name string.

    "XFORMERS" -> construct xformers backend

    TODO: other backends

    Note: at time of writing the Attention wrapper automatically selects
    its own backend for Attention.forward(); so the backend instance which
    you generate with this function is not meant to be used for *running*
    inference, but rather for generating compatible metadata structures
    using backend.make_metadata()

    Returns:

    * Backend instance

    Raises:

    * AssertionError if backend_name is not recognized
    '''
    if backend_name != STR_XFORMERS_ATTN_VAL:
        raise AssertionError(
            f"Unrecognized backend_name {backend_name} for unit test")
    return XFormersBackend()
def _make_metadata_tensors(
    seq_lens: Optional[List[int]], context_lens: Optional[List[int]],
    encoder_seq_lens: Optional[List[int]], device: Union[torch.device, str]
) -> Tuple[torch.Tensor, torch.Tensor, Any, Any, Optional[List[int]],
           torch.Tensor, Optional[int]]:
    '''
    Build scalar & tensor values required to build attention metadata structure.

    Each input list may be None, in which case the values derived from it
    are None as well.

    Arguments:

    * seq_lens: list of token-counts for each decoder input seq
    * context_lens: list of context length values for each seq
    * encoder_seq_lens: list of token-counts for each encoder input seq
    * device: CPU or CUDA device

    Returns (in this order):

    * seq_lens_tensor: decoder seq_lens list, as tensor
    * context_lens_tensor: context_lens list, as tensor
    * max_context_len: max(context_lens)
    * max_seq_len: max(seq_lens)
    * seq_start_loc: always None in this helper
    * encoder_seq_lens_tensor: encoder seq_lens list, as tensor
    * max_encoder_seq_len: max(encoder_seq_lens)
    '''
    seq_lens_tensor = maybe_make_int_tensor(seq_lens, device)
    context_lens_tensor = maybe_make_int_tensor(context_lens, device)
    max_context_len = maybe_max(context_lens)
    max_seq_len = maybe_max(seq_lens)

    encoder_seq_lens_tensor = maybe_make_int_tensor(encoder_seq_lens, device)
    # Same None-aware max as the two above, instead of a hand-rolled copy
    max_encoder_seq_len = maybe_max(encoder_seq_lens)

    seq_start_loc = None

    return (seq_lens_tensor, context_lens_tensor, max_context_len, max_seq_len,
            seq_start_loc, encoder_seq_lens_tensor, max_encoder_seq_len)
def make_kv_cache(num_blocks: int,
                  num_heads: int,
                  head_size: int,
                  block_size: int,
                  device: Union[torch.device, str],
                  default_val: Optional[float] = 0.0) -> torch.Tensor:
    '''
    Create a fake KV cache.

    Arguments:

    * num_blocks: number of blocks in the KV cache
    * num_heads: number of attention heads
    * head_size: head dimension
    * block_size: number of offsets within a block
    * device: CPU or CUDA device
    * default_val: initialization value for KV cache elements; if None the
      cache is left filled with uniform random values from torch.rand

    Returns:

    * kv_cache: 2 x num_blocks x (block_size * num_heads * head_size)
    '''

    shape = (2, num_blocks, block_size * num_heads * head_size)
    if default_val is None:
        return torch.rand(shape).to(device)
    # Constant-filled cache: build it directly rather than generating random
    # data that would be overwritten immediately
    return torch.full(shape, default_val).to(device)
def
_num_tokens_to_min_blocks
(
num_tokens
:
int
,
block_size
:
int
)
->
int
:
'''
Compute the minimum number of blocks required to hold num_tokens tokens,
given block_size
'''
return
(
num_tokens
+
block_size
)
//
block_size
def make_empty_slot_mapping_tensor(device: Union[torch.device, str]):
    '''
    Build a slot mapping with no entries.

    Arguments:

    * device: CPU or CUDA device

    Returns:

    * Result of maybe_make_long_tensor() applied to an empty list on
      the requested device
    '''
    no_slots: list = []
    return maybe_make_long_tensor(no_slots, device)
def make_empty_block_tables_tensor(device: Union[torch.device, str]):
    '''
    Build a block-tables tensor with no entries.

    Arguments:

    * device: CPU or CUDA device

    Returns:

    * Zero-element tensor on the requested device (no dtype is passed,
      so torch picks the dtype it uses for an empty list)
    '''
    empty_block_tables = torch.tensor([], device=device)
    return empty_block_tables
def split_slot_mapping(slot_mapping_list: torch.Tensor, seq_lens: List[int],
                       device: Union[torch.device, str]):
    '''
    Separate one combined slot mapping into a prefill-phase slot mapping
    and a decode-phase slot mapping.

    Context:

    * The scenario under test is (1) prefill of N prompts with prompt
      lengths K_0 ... K_(N-1), followed by (2) decoding one token for
      each of the N prompts. After the decode step, sequence i has
      length K_i + 1.

    * Such a test needs the prefill slot mapping covering all
      M = sum_i(K_i) prompt tokens, and the decode slot mapping covering
      the N decoded tokens.

    The input is a single 1D slot mapping formed by concatenating N
    per-sequence slot mappings, each of length K_i + 1, for a total
    length of P = sum_i(K_i + 1) = M + N.

    For each sequence, all entries but the last go to the prefill slot
    mapping, and the last entry (the decoded token's slot) goes to the
    decode slot mapping.

    Arguments:

    * slot_mapping_list: Length-P 1D slot mapping covering all N
                         post-decode sequences; only slicing and integer
                         indexing are applied to it
    * seq_lens: List of N post-decode sequence lengths (K_i + 1 in the
                description above)
    * device: cuda, cpu, etc.

    Returns:

    * prefill_slot_mapping: Length-M 1D slot mapping for the N prefill
                            prompts, as built by maybe_make_long_tensor()
    * decode_slot_mapping: Length-N 1D slot mapping for the N decoded
                           tokens, as built by maybe_make_long_tensor()
    '''

    prefill_slots = []
    decode_slots = []

    # Offset of the current sequence within the concatenated mapping
    seq_start = 0
    for seq_len in seq_lens:
        # Position of this sequence's final (decoded-token) entry
        decode_pos = seq_start + seq_len - 1
        prefill_slots.extend(slot_mapping_list[seq_start:decode_pos])
        decode_slots.append(slot_mapping_list[decode_pos])
        seq_start += seq_len

    prefill_slot_mapping = maybe_make_long_tensor(prefill_slots, device)
    decode_slot_mapping = maybe_make_long_tensor(decode_slots, device)
    return (prefill_slot_mapping, decode_slot_mapping)
def make_block_tables_slot_mapping(
        block_size: int,
        seq_lens: List[int],
        device: Union[torch.device, str],
        block_base_addr: int = 0) -> Tuple[torch.Tensor, List[int], int]:
    '''
    Build fake block tables and a matching slot mapping.

    Each sequence of num_tokens tokens is provisioned

    num_blocks = _num_tokens_to_min_blocks(num_tokens, block_size)

    KV cache blocks, and the whole batch uses

    total_cache_blocks = sum(num_blocks for all seqs)

    Block indices are handed out in descending order: the first
    sequence's first block is block_base_addr + total_cache_blocks, and
    every later block (within a sequence, then across sequences) is one
    lower than the previous one.

    The slot for token idx of a sequence is

    block_table[idx // block_size] * block_size + (idx % block_size)

    Block tables and slot mapping cover each sequence in its entirety
    (as given by seq_lens), i.e. prefill prompt tokens plus decoded
    tokens.

    Arguments:

    * block_size: number of offsets per block
    * seq_lens: list of token-counts for each sequence
    * device: CPU or CUDA device
    * block_base_addr: the block table base address

    Return:

    * block_tables_tensor: per-sequence block tables, zero-padded to the
                           longest table length plus 10 extra columns,
                           with dtype torch.int
    * slot_mapping_list: slot mapping for all sequences, concatenated in
                         sequence order
    * max_block_idx: the highest block address within the block tables
    '''

    # Blocks provisioned for each sequence
    blocks_per_seq = [
        _num_tokens_to_min_blocks(seq_len, block_size) for seq_len in seq_lens
    ]
    widest_table = max(blocks_per_seq)
    extra_pad_columns = 10

    # Uppermost block address; allocation walks downward from here
    max_block_idx = block_base_addr + sum(blocks_per_seq)

    block_tables = []
    slot_mapping_list = []
    next_top_block = max_block_idx
    for seq_len, blocks_needed in zip(seq_lens, blocks_per_seq):
        descending_blocks = list(
            range(next_top_block, next_top_block - blocks_needed, -1))

        # One slot per token: offset within block + start of that block
        slot_mapping_list.extend(
            (token_idx % block_size) +
            descending_blocks[token_idx // block_size] * block_size
            for token_idx in range(seq_len))

        block_tables.append(descending_blocks)
        next_top_block -= blocks_needed

    block_tables_tensor = make_tensor_with_pad(
        block_tables,
        max_len=widest_table + extra_pad_columns,
        pad=0,
        dtype=torch.int,
        device=device,
    )

    return (block_tables_tensor, slot_mapping_list, max_block_idx)
def make_test_metadata(
    attn_backend: AttentionBackend,
    is_prompt: bool,
    seq_lens: Optional[List[int]],
    decoder_test_params: Optional[PhaseTestParameters],
    device: Union[torch.device, str],
    encoder_test_params: Optional[PhaseTestParameters] = None,
    cross_test_params: Optional[PhaseTestParameters] = None
) -> AttentionMetadata:
    '''
    Build fake attention metadata for one test phase (prefill-phase or
    decode-phase) by calling attn_backend.make_metadata().

    encoder_test_params and cross_test_params let encoder attention and
    enc/dec cross-attention (respectively) use metadata values distinct
    from decoder self-attention (decoder_test_params.)

    When encoder_test_params and cross_test_params are both None, the
    encoder and cross-attention fields of the metadata are all None
    (decoder-only scenario.)

    Assumptions:

    * No chunked prefill -> a batch is 100% prefill or 100% decode, never both

    Arguments:

    * attn_backend: Backend for sourcing attention kernels
    * is_prompt: prefill if True, o/w decode
    * seq_lens: list of token counts for each sequence; may be None only
                for prefill (decode asserts it is present)
    * decoder_test_params: decoder self-attention test params; the
                           kv_mmap (memory mapping) field is read.
                           May be None only for prefill, in which case
                           slot mapping and block tables are None
    * device: CPU or CUDA device
    * encoder_test_params: encoder attention test params; the encoder
                           query sequence lengths
                           (packed_qkvo.packed_qkv.q_seq_lens) are read.
                           If None, encoder sequence lengths are
                           treated as None
    * cross_test_params: enc/dec cross-attention test params; the
                         kv_mmap field is read. If None,
                         cross-attention slot mapping and block tables
                         are None

    Return:

    * AttentionMetadata structure
    '''

    # Decoder self-attention memory mapping; stays None when no decoder
    # test params were supplied
    kv_mmap = None
    if decoder_test_params is not None:
        kv_mmap = decoder_test_params.kv_mmap

    # With no chunked prefill the batch is a single phase, so one pair of
    # counts describes it:
    # - prefill: number of prompts, total number of prompt tokens
    # - decode:  number of sequences, one token per sequence
    # Both remain None when seq_lens is None.
    if seq_lens is None:
        num_prefills_or_decodes = None
        num_prefill_or_decode_tokens = None
    else:
        num_prefills_or_decodes = len(seq_lens)
        num_prefill_or_decode_tokens = (sum(seq_lens)
                                        if is_prompt else len(seq_lens))

    # Context lengths are not supplied by this helper
    context_lens = None

    # Encoder input sequence lengths (encoder/decoder or encoder-only)
    encoder_seq_lens = None
    num_encoder_tokens = None
    if encoder_test_params is not None:
        assert encoder_test_params.packed_qkvo.packed_qkv is not None
        encoder_seq_lens = \
            encoder_test_params.packed_qkvo.packed_qkv.q_seq_lens
        if encoder_seq_lens is not None:
            num_encoder_tokens = sum(encoder_seq_lens)

    # Cross-attention slot mapping & block tables
    cross_kv_mmap = (None
                     if cross_test_params is None else
                     cross_test_params.kv_mmap)
    if cross_kv_mmap is None:
        cross_slot_mapping = None
        cross_block_tables = None
    else:
        cross_slot_mapping = cross_kv_mmap.slot_mapping
        cross_block_tables = cross_kv_mmap.block_tables

    if is_prompt:
        # Prefill-phase scenario
        num_prefills = num_prefills_or_decodes
        num_prefill_tokens = num_prefill_or_decode_tokens
        num_decode_tokens = 0
        max_prefill_seq_len = None if seq_lens is None else max(seq_lens)
        max_decode_seq_len = 0
        if kv_mmap is None:
            slot_mapping = None
            block_tables = None
        else:
            slot_mapping = kv_mmap.slot_mapping
            block_tables = kv_mmap.block_tables
    else:
        # Decode-phase scenario: decoder memory map and sequence lengths
        # are mandatory
        assert kv_mmap is not None
        assert num_prefill_or_decode_tokens is not None
        assert seq_lens is not None

        num_prefills = 0
        num_prefill_tokens = 0
        num_decode_tokens = num_prefill_or_decode_tokens
        max_prefill_seq_len = 0
        max_decode_seq_len = max(seq_lens)
        slot_mapping = kv_mmap.slot_mapping
        block_tables = kv_mmap.block_tables

    # Only the tensor forms and the max encoder length are consumed here;
    # the three middle return values are discarded
    (
        seq_lens_tensor,
        context_lens_tensor,
        _unused_a,
        _unused_b,
        _unused_c,
        encoder_seq_lens_tensor,
        max_encoder_seq_len,
    ) = _make_metadata_tensors(seq_lens,
                               context_lens,
                               encoder_seq_lens,
                               device=device)

    return attn_backend.make_metadata(
        num_prefills=num_prefills,
        slot_mapping=slot_mapping,
        num_prefill_tokens=num_prefill_tokens,
        num_decode_tokens=num_decode_tokens,
        seq_lens=seq_lens,
        seq_lens_tensor=seq_lens_tensor,
        max_prefill_seq_len=max_prefill_seq_len,
        max_decode_seq_len=max_decode_seq_len,
        context_lens_tensor=context_lens_tensor,
        block_tables=block_tables,
        use_cuda_graph=False,
        num_encoder_tokens=num_encoder_tokens,
        encoder_seq_lens=encoder_seq_lens,
        encoder_seq_lens_tensor=encoder_seq_lens_tensor,
        max_encoder_seq_len=max_encoder_seq_len,
        cross_slot_mapping=cross_slot_mapping,
        cross_block_tables=cross_block_tables)
def assert_actual_matches_ideal(test_params: PhaseTestParameters,
                                output_under_test: torch.Tensor) -> None:
    '''
    Check the observed output against the ideal output stored in the
    test parameters; raises AssertionError on mismatch.

    The observed output is viewed with the ideal output's shape before
    the two are compared with torch.allclose() at its default
    tolerances.

    Arguments:

    * test_params: Test parameters including packed ideal output
    * output_under_test: actually observed output value
    '''
    expected = test_params.packed_qkvo.ideal_output
    observed = output_under_test.view_as(expected)
    assert torch.allclose(expected, observed)
tests/lora/conftest.py
View file @
e7c1b7f3
...
@@ -2,6 +2,7 @@ import contextlib
...
@@ -2,6 +2,7 @@ import contextlib
import
gc
import
gc
import
tempfile
import
tempfile
from
collections
import
OrderedDict
from
collections
import
OrderedDict
from
typing
import
Dict
,
List
,
TypedDict
from
unittest.mock
import
MagicMock
,
patch
from
unittest.mock
import
MagicMock
,
patch
import
pytest
import
pytest
...
@@ -24,7 +25,18 @@ from vllm.model_executor.layers.sampler import Sampler
...
@@ -24,7 +25,18 @@ from vllm.model_executor.layers.sampler import Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
ParallelLMHead
from
vllm.model_executor.layers.vocab_parallel_embedding
import
ParallelLMHead
from
vllm.model_executor.model_loader
import
get_model
from
vllm.model_executor.model_loader
import
get_model
LONG_LORA_INFOS
=
[{
class
ContextIDInfo
(
TypedDict
):
lora_id
:
int
context_length
:
str
class
ContextInfo
(
TypedDict
):
lora
:
str
context_length
:
str
LONG_LORA_INFOS
:
List
[
ContextIDInfo
]
=
[{
"lora_id"
:
1
,
"lora_id"
:
1
,
"context_length"
:
"16k"
,
"context_length"
:
"16k"
,
},
{
},
{
...
@@ -147,13 +159,21 @@ def dummy_model_gate_up() -> nn.Module:
...
@@ -147,13 +159,21 @@ def dummy_model_gate_up() -> nn.Module:
@
pytest
.
fixture
(
scope
=
"session"
)
@
pytest
.
fixture
(
scope
=
"session"
)
def
sql_lora_files
():
def
sql_lora_huggingface_id
():
return
snapshot_download
(
repo_id
=
"yard1/llama-2-7b-sql-lora-test"
)
# huggingface repo id is used to test lora runtime downloading.
return
"yard1/llama-2-7b-sql-lora-test"
@
pytest
.
fixture
(
scope
=
"session"
)
def
sql_lora_files
(
sql_lora_huggingface_id
):
return
snapshot_download
(
repo_id
=
sql_lora_huggingface_id
)
@
pytest
.
fixture
(
scope
=
"session"
)
@
pytest
.
fixture
(
scope
=
"session"
)
def
mixtral_lora_files
():
def
mixtral_lora_files
():
return
snapshot_download
(
repo_id
=
"terrysun/mixtral-lora-adapter"
)
# Note: this module has incorrect adapter_config.json to test
# https://github.com/vllm-project/vllm/pull/5909/files.
return
snapshot_download
(
repo_id
=
"SangBinCho/mixtral-lora"
)
@
pytest
.
fixture
(
scope
=
"session"
)
@
pytest
.
fixture
(
scope
=
"session"
)
...
@@ -207,7 +227,7 @@ def long_context_infos(long_context_lora_files_16k_1,
...
@@ -207,7 +227,7 @@ def long_context_infos(long_context_lora_files_16k_1,
long_context_lora_files_16k_2
,
long_context_lora_files_16k_2
,
long_context_lora_files_32k
):
long_context_lora_files_32k
):
cleanup
()
cleanup
()
infos
=
{}
infos
:
Dict
[
int
,
ContextInfo
]
=
{}
for
lora_checkpoint_info
in
LONG_LORA_INFOS
:
for
lora_checkpoint_info
in
LONG_LORA_INFOS
:
lora_id
=
lora_checkpoint_info
[
"lora_id"
]
lora_id
=
lora_checkpoint_info
[
"lora_id"
]
if
lora_id
==
1
:
if
lora_id
==
1
:
...
@@ -226,7 +246,7 @@ def long_context_infos(long_context_lora_files_16k_1,
...
@@ -226,7 +246,7 @@ def long_context_infos(long_context_lora_files_16k_1,
@
pytest
.
fixture
@
pytest
.
fixture
def
llama_2_7b_engine_extra_embeddings
()
->
nn
.
Module
:
def
llama_2_7b_engine_extra_embeddings
():
cleanup
()
cleanup
()
get_model_old
=
get_model
get_model_old
=
get_model
...
@@ -244,7 +264,6 @@ def llama_2_7b_engine_extra_embeddings() -> nn.Module:
...
@@ -244,7 +264,6 @@ def llama_2_7b_engine_extra_embeddings() -> nn.Module:
@
pytest
.
fixture
@
pytest
.
fixture
def
llama_2_7b_model_extra_embeddings
(
def
llama_2_7b_model_extra_embeddings
(
llama_2_7b_engine_extra_embeddings
):
llama_2_7b_engine_extra_embeddings
)
->
nn
.
Module
:
yield
(
llama_2_7b_engine_extra_embeddings
.
model_executor
.
driver_worker
.
yield
(
llama_2_7b_engine_extra_embeddings
.
model_executor
.
driver_worker
.
model_runner
.
model
)
model_runner
.
model
)
Prev
1
…
12
13
14
15
16
17
18
19
20
…
23
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment