Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
7a985548
Commit
7a985548
authored
May 22, 2025
by
zhuwenwen
Browse files
Merge tag 'v0.9.0' into v0.9.0-ori
parents
45d3785c
dc1440cf
Changes
486
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1406 additions
and
254 deletions
+1406
-254
tests/entrypoints/openai/test_completion_with_prompt_embeds.py
.../entrypoints/openai/test_completion_with_prompt_embeds.py
+257
-0
tests/entrypoints/openai/test_embedding.py
tests/entrypoints/openai/test_embedding.py
+6
-6
tests/entrypoints/openai/test_embedding_dimensions.py
tests/entrypoints/openai/test_embedding_dimensions.py
+3
-2
tests/entrypoints/openai/test_openai_schema.py
tests/entrypoints/openai/test_openai_schema.py
+2
-2
tests/entrypoints/openai/test_serving_chat.py
tests/entrypoints/openai/test_serving_chat.py
+40
-0
tests/entrypoints/openai/test_tokenization.py
tests/entrypoints/openai/test_tokenization.py
+77
-0
tests/entrypoints/openai/test_truncation.py
tests/entrypoints/openai/test_truncation.py
+103
-0
tests/entrypoints/openai/tool_parsers/utils.py
tests/entrypoints/openai/tool_parsers/utils.py
+1
-1
tests/entrypoints/test_chat_utils.py
tests/entrypoints/test_chat_utils.py
+97
-16
tests/kernels/attention/test_attention_selector.py
tests/kernels/attention/test_attention_selector.py
+7
-3
tests/kernels/attention/test_flashmla.py
tests/kernels/attention/test_flashmla.py
+1
-1
tests/kernels/attention/test_rocm_attention_selector.py
tests/kernels/attention/test_rocm_attention_selector.py
+4
-2
tests/kernels/attention/test_triton_unified_attention.py
tests/kernels/attention/test_triton_unified_attention.py
+192
-0
tests/kernels/core/test_pos_encoding.py
tests/kernels/core/test_pos_encoding.py
+51
-16
tests/kernels/core/test_rotary_embedding.py
tests/kernels/core/test_rotary_embedding.py
+18
-4
tests/kernels/mamba/test_mamba_ssm_ssd.py
tests/kernels/mamba/test_mamba_ssm_ssd.py
+4
-3
tests/kernels/moe/test_batched_moe.py
tests/kernels/moe/test_batched_moe.py
+114
-0
tests/kernels/moe/test_cutlass_moe.py
tests/kernels/moe/test_cutlass_moe.py
+21
-25
tests/kernels/moe/test_moe.py
tests/kernels/moe/test_moe.py
+185
-173
tests/kernels/moe/test_moe_permute_unpermute.py
tests/kernels/moe/test_moe_permute_unpermute.py
+223
-0
No files found.
Too many changes to show.
To preserve performance only
486 of 486+
files are displayed.
Plain diff
Email patch
tests/entrypoints/openai/test_completion_with_prompt_embeds.py
0 → 100644
View file @
7a985548
# SPDX-License-Identifier: Apache-2.0
import
base64
import
io
import
shutil
from
tempfile
import
TemporaryDirectory
import
openai
# use the official client for correctness check
import
pytest
import
pytest_asyncio
import
torch
# downloading lora to test lora requests
from
huggingface_hub
import
snapshot_download
from
openai
import
BadRequestError
from
transformers
import
AutoConfig
,
AutoTokenizer
from
...utils
import
RemoteOpenAIServer
# any model with a chat template should work here
MODEL_NAME
=
"HuggingFaceH4/zephyr-7b-beta"
LORA_NAME
=
"typeof/zephyr-7b-beta-lora"
CONFIG
=
AutoConfig
.
from_pretrained
(
MODEL_NAME
)
@pytest.fixture(scope="module")
def zephyr_lora_files():
    """Fetch the LoRA adapter repo named by LORA_NAME once per test module.

    Returns whatever ``snapshot_download`` returns for that repo id.
    """
    lora_snapshot = snapshot_download(repo_id=LORA_NAME)
    return lora_snapshot
@pytest.fixture(scope="module")
def zephyr_lora_added_tokens_files(zephyr_lora_files):
    """Yield a temporary copy of the LoRA adapter with three extra tokens.

    The adapter files are copied into a fresh temporary directory, the base
    model tokenizer is extended with ``vllm1``/``vllm2``/``vllm3`` and saved
    next to the adapter, and the directory path is yielded to the tests.
    """
    # Using the context manager (instead of a manual cleanup() after the
    # yield) guarantees the directory is removed even when setup fails
    # before the yield, e.g. if the `added == 3` assertion trips.
    with TemporaryDirectory() as tmp_dir:
        tmp_model_dir = f"{tmp_dir}/zephyr"
        shutil.copytree(zephyr_lora_files, tmp_model_dir)

        tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
        # Copy tokenizer to adapter and add some unique tokens
        # 32000, 32001, 32002
        added = tokenizer.add_tokens(["vllm1", "vllm2", "vllm3"],
                                     special_tokens=True)
        assert added == 3
        tokenizer.save_pretrained(tmp_model_dir)

        yield tmp_model_dir
@pytest.fixture(scope="module")
def default_server_args(
    zephyr_lora_files,
    zephyr_lora_added_tokens_files,
) -> list[str]:
    """Return the CLI arguments shared by every server started in this module.

    The two LoRA fixtures are requested but not referenced by the returned
    arguments.
    """
    # use half precision for speed and memory savings in CI environment
    base_args = [
        "--dtype",
        "bfloat16",
        "--max-model-len",
        "8192",
        "--max-num-seqs",
        "128",
        "--enforce-eager",
    ]
    # Prompt Embeds server args
    prompt_embeds_args = [
        "--enable-prompt-embeds",
        "--no-enable-chunked-prefill",
    ]
    return base_args + prompt_embeds_args
@pytest.fixture(scope="module",
                params=["", "--disable-frontend-multiprocessing"])
def server_with_prompt_embeds(default_server_args, request):
    """Start a RemoteOpenAIServer for MODEL_NAME, once per fixture param.

    The empty-string param runs with the default arguments only; a non-empty
    param is passed to the server as one extra command-line flag.
    """
    # Work on a copy: default_server_args is a module-scoped fixture, so
    # appending to it in place would leak the extra flag into every other
    # consumer of that list for the rest of the module.
    server_args = list(default_server_args)
    if request.param:
        server_args.append(request.param)

    with RemoteOpenAIServer(MODEL_NAME, server_args) as remote_server:
        yield remote_server
@pytest_asyncio.fixture
async def client_with_prompt_embeds(server_with_prompt_embeds):
    """Yield an async client obtained from the prompt-embeds test server."""
    client_ctx = server_with_prompt_embeds.get_async_client()
    async with client_ctx as prompt_embeds_client:
        yield prompt_embeds_client
def create_dummy_embeds(num_tokens: int = 5) -> str:
    """Build random prompt embeddings and return them base64-encoded.

    A ``(num_tokens, CONFIG.hidden_size)`` tensor is drawn with
    ``torch.randn``, serialized with ``torch.save`` into an in-memory buffer,
    and the resulting bytes are returned as a UTF-8 base64 string.
    """
    embeds = torch.randn(num_tokens, CONFIG.hidden_size)
    stream = io.BytesIO()
    torch.save(embeds, stream)
    serialized = stream.getvalue()
    encoded = base64.b64encode(serialized)
    return encoded.decode('utf-8')
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_completions_with_prompt_embeds(
        client_with_prompt_embeds: openai.AsyncOpenAI, model_name: str):
    """Exercise the completions endpoint with ``prompt_embeds`` inputs.

    Covers a single embedding, a batch of embeddings, streaming (single and
    batch), and a request that mixes a text prompt with an embedding.
    """
    # Test case: Single prompt embeds input
    encoded_embeds = create_dummy_embeds()
    completion = await client_with_prompt_embeds.completions.create(
        model=model_name,
        prompt="",  # Add empty prompt as required parameter
        max_tokens=5,
        temperature=0.0,
        extra_body={"prompt_embeds": encoded_embeds})
    assert len(completion.choices[0].text) >= 1
    assert completion.choices[0].prompt_logprobs is None

    # Test case: batch completion with prompt_embeds
    encoded_embeds2 = create_dummy_embeds()
    completion = await client_with_prompt_embeds.completions.create(
        model=model_name,
        prompt="",  # Add empty prompt as required parameter
        max_tokens=5,
        temperature=0.0,
        extra_body={"prompt_embeds": [encoded_embeds, encoded_embeds2]})
    assert len(completion.choices) == 2
    assert len(completion.choices[0].text) >= 1
    assert len(completion.choices[1].text) >= 1

    # Test case: streaming with prompt_embeds
    encoded_embeds = create_dummy_embeds()
    single_completion = await client_with_prompt_embeds.completions.create(
        model=model_name,
        prompt="",  # Add empty prompt as required parameter
        max_tokens=5,
        temperature=0.0,
        extra_body={"prompt_embeds": encoded_embeds})
    single_output = single_completion.choices[0].text

    stream = await client_with_prompt_embeds.completions.create(
        model=model_name,
        prompt="",  # Add empty prompt as required parameter
        max_tokens=5,
        temperature=0.0,
        stream=True,
        extra_body={"prompt_embeds": encoded_embeds})
    chunks = []
    finish_reason_count = 0
    async for chunk in stream:
        chunks.append(chunk.choices[0].text)
        if chunk.choices[0].finish_reason is not None:
            finish_reason_count += 1
    # `chunk` below is the last chunk received from the stream.
    assert finish_reason_count == 1
    assert chunk.choices[0].finish_reason == "length"
    assert chunk.choices[0].text
    # The streamed pieces must add up to the non-streamed output.
    assert "".join(chunks) == single_output

    # Test case: batch streaming with prompt_embeds
    encoded_embeds2 = create_dummy_embeds()
    stream = await client_with_prompt_embeds.completions.create(
        model=model_name,
        prompt="",  # Add empty prompt as required parameter
        max_tokens=5,
        temperature=0.0,
        stream=True,
        extra_body={"prompt_embeds": [encoded_embeds, encoded_embeds2]})
    # One list of text pieces per choice index.
    chunks_stream_embeds: list[list[str]] = [[], []]
    finish_reason_count = 0
    async for chunk in stream:
        chunks_stream_embeds[chunk.choices[0].index].append(
            chunk.choices[0].text)
        if chunk.choices[0].finish_reason is not None:
            finish_reason_count += 1
    assert finish_reason_count == 2
    assert chunk.choices[0].finish_reason == "length"
    assert chunk.choices[0].text
    assert len(chunks_stream_embeds[0]) > 0
    assert len(chunks_stream_embeds[1]) > 0

    # Test case: mixed text and prompt_embeds
    encoded_embeds = create_dummy_embeds()
    completion_mixed = await client_with_prompt_embeds.completions.create(
        model=model_name,
        prompt="This is a prompt",
        max_tokens=5,
        temperature=0.0,
        extra_body={"prompt_embeds": encoded_embeds})
    # Check the mixed response itself. The previous code inspected the stale
    # batch `completion` from earlier in the test, which always has two
    # choices, so the assertion could never fail.
    assert len(completion_mixed.choices) == 2
    completion_text_only = await client_with_prompt_embeds.completions.create(
        model=model_name,
        prompt="This is a prompt",
        max_tokens=5,
        temperature=0.0,
    )
    completion_embeds_only = await client_with_prompt_embeds.completions.create(
        model=model_name,
        prompt="",
        max_tokens=5,
        temperature=0.0,
        extra_body={"prompt_embeds": encoded_embeds})
    # Embeddings responses should be handled first
    assert completion_mixed.choices[0].text == completion_embeds_only.choices[
        0].text
    assert completion_mixed.choices[1].text == completion_text_only.choices[
        0].text
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_completions_errors_with_prompt_embeds(
        client_with_prompt_embeds: openai.AsyncOpenAI, model_name: str):
    """Check that an undecodable ``prompt_embeds`` payload is rejected."""
    # Test error case: invalid prompt_embeds
    invalid_request = dict(
        prompt="",
        model=model_name,
        max_tokens=5,
        temperature=0.0,
        extra_body={"prompt_embeds": "invalid_base64"},
    )
    with pytest.raises(BadRequestError):
        await client_with_prompt_embeds.completions.create(**invalid_request)
@pytest.mark.asyncio
@pytest.mark.parametrize("logprobs_arg", [1, 0])
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_completions_with_logprobs_and_prompt_embeds(
        client_with_prompt_embeds: openai.AsyncOpenAI, logprobs_arg: int,
        model_name: str):
    """Check logprobs output for single and batched ``prompt_embeds`` requests."""

    def check_logprobs(logprobs):
        # Every per-token field must have one entry per generated token
        # (max_tokens=5 in both requests below).
        assert logprobs is not None
        assert len(logprobs.text_offset) == 5
        assert len(logprobs.token_logprobs) == 5
        assert len(logprobs.top_logprobs) == 5
        # The first top_logprobs entry is excluded from the size check.
        for entry in logprobs.top_logprobs[1:]:
            assert max(logprobs_arg, 1) <= len(entry) <= logprobs_arg + 1
        assert len(logprobs.tokens) == 5

    # Test case: Logprobs using prompt_embeds
    first_embeds = create_dummy_embeds()
    single = await client_with_prompt_embeds.completions.create(
        model=model_name,
        prompt="",  # Add empty prompt as required parameter
        max_tokens=5,
        temperature=0.0,
        echo=False,
        logprobs=logprobs_arg,
        extra_body={"prompt_embeds": first_embeds})
    check_logprobs(single.choices[0].logprobs)

    # Test case: Log probs with batch completion and prompt_embeds
    second_embeds = create_dummy_embeds()
    batch = await client_with_prompt_embeds.completions.create(
        model=model_name,
        prompt="",  # Add empty prompt as required parameter
        max_tokens=5,
        temperature=0.0,
        echo=False,
        logprobs=logprobs_arg,
        extra_body={"prompt_embeds": [first_embeds, second_embeds]})
    assert len(batch.choices) == 2
    for choice in batch.choices:
        check_logprobs(choice.logprobs)
tests/entrypoints/openai/test_embedding.py
View file @
7a985548
...
@@ -11,7 +11,7 @@ import requests
...
@@ -11,7 +11,7 @@ import requests
from
vllm.entrypoints.openai.protocol
import
EmbeddingResponse
from
vllm.entrypoints.openai.protocol
import
EmbeddingResponse
from
vllm.transformers_utils.tokenizer
import
get_tokenizer
from
vllm.transformers_utils.tokenizer
import
get_tokenizer
from
...models.
embedding.
utils
import
correctness_test
from
...models.utils
import
run_embedding_
correctness_test
from
...utils
import
RemoteOpenAIServer
from
...utils
import
RemoteOpenAIServer
MODEL_NAME
=
"intfloat/multilingual-e5-small"
MODEL_NAME
=
"intfloat/multilingual-e5-small"
...
@@ -76,7 +76,7 @@ async def test_single_embedding(hf_model, client: openai.AsyncOpenAI,
...
@@ -76,7 +76,7 @@ async def test_single_embedding(hf_model, client: openai.AsyncOpenAI,
assert
embeddings
.
usage
.
total_tokens
==
11
assert
embeddings
.
usage
.
total_tokens
==
11
vllm_outputs
=
[
d
.
embedding
for
d
in
embeddings
.
data
]
vllm_outputs
=
[
d
.
embedding
for
d
in
embeddings
.
data
]
correctness_test
(
hf_model
,
input_texts
,
vllm_outputs
)
run_embedding_
correctness_test
(
hf_model
,
input_texts
,
vllm_outputs
)
# test using token IDs
# test using token IDs
input_tokens
=
[
1
,
1
,
1
,
1
,
1
]
input_tokens
=
[
1
,
1
,
1
,
1
,
1
]
...
@@ -121,7 +121,7 @@ async def test_batch_embedding(hf_model, client: openai.AsyncOpenAI,
...
@@ -121,7 +121,7 @@ async def test_batch_embedding(hf_model, client: openai.AsyncOpenAI,
assert
embeddings
.
usage
.
total_tokens
==
33
assert
embeddings
.
usage
.
total_tokens
==
33
vllm_outputs
=
[
d
.
embedding
for
d
in
embeddings
.
data
]
vllm_outputs
=
[
d
.
embedding
for
d
in
embeddings
.
data
]
correctness_test
(
hf_model
,
input_texts
,
vllm_outputs
)
run_embedding_
correctness_test
(
hf_model
,
input_texts
,
vllm_outputs
)
# test list[list[int]]
# test list[list[int]]
input_tokens
=
[[
4
,
5
,
7
,
9
,
20
],
[
15
,
29
,
499
],
[
24
,
24
,
24
,
24
,
24
],
input_tokens
=
[[
4
,
5
,
7
,
9
,
20
],
[
15
,
29
,
499
],
[
24
,
24
,
24
,
24
,
24
],
...
@@ -208,7 +208,7 @@ async def test_batch_base64_embedding(hf_model, client: openai.AsyncOpenAI,
...
@@ -208,7 +208,7 @@ async def test_batch_base64_embedding(hf_model, client: openai.AsyncOpenAI,
model
=
model_name
,
model
=
model_name
,
encoding_format
=
"float"
)
encoding_format
=
"float"
)
float_data
=
[
d
.
embedding
for
d
in
responses_float
.
data
]
float_data
=
[
d
.
embedding
for
d
in
responses_float
.
data
]
correctness_test
(
hf_model
,
input_texts
,
float_data
)
run_embedding_
correctness_test
(
hf_model
,
input_texts
,
float_data
)
responses_base64
=
await
client
.
embeddings
.
create
(
input
=
input_texts
,
responses_base64
=
await
client
.
embeddings
.
create
(
input
=
input_texts
,
model
=
model_name
,
model
=
model_name
,
...
@@ -219,13 +219,13 @@ async def test_batch_base64_embedding(hf_model, client: openai.AsyncOpenAI,
...
@@ -219,13 +219,13 @@ async def test_batch_base64_embedding(hf_model, client: openai.AsyncOpenAI,
np
.
frombuffer
(
base64
.
b64decode
(
data
.
embedding
),
np
.
frombuffer
(
base64
.
b64decode
(
data
.
embedding
),
dtype
=
"float32"
).
tolist
())
dtype
=
"float32"
).
tolist
())
correctness_test
(
hf_model
,
input_texts
,
base64_data
)
run_embedding_
correctness_test
(
hf_model
,
input_texts
,
base64_data
)
# Default response is float32 decoded from base64 by OpenAI Client
# Default response is float32 decoded from base64 by OpenAI Client
responses_default
=
await
client
.
embeddings
.
create
(
input
=
input_texts
,
responses_default
=
await
client
.
embeddings
.
create
(
input
=
input_texts
,
model
=
model_name
)
model
=
model_name
)
default_data
=
[
d
.
embedding
for
d
in
responses_default
.
data
]
default_data
=
[
d
.
embedding
for
d
in
responses_default
.
data
]
correctness_test
(
hf_model
,
input_texts
,
default_data
)
run_embedding_
correctness_test
(
hf_model
,
input_texts
,
default_data
)
@
pytest
.
mark
.
asyncio
@
pytest
.
mark
.
asyncio
...
...
tests/entrypoints/openai/test_embedding_dimensions.py
View file @
7a985548
...
@@ -11,7 +11,7 @@ import pytest
...
@@ -11,7 +11,7 @@ import pytest
from
vllm.entrypoints.openai.protocol
import
EmbeddingResponse
from
vllm.entrypoints.openai.protocol
import
EmbeddingResponse
from
...conftest
import
HfRunner
from
...conftest
import
HfRunner
from
...models.
embedding.
utils
import
EmbedModelInfo
,
correctness_test
from
...models.utils
import
EmbedModelInfo
,
run_embedding_
correctness_test
from
...utils
import
RemoteOpenAIServer
from
...utils
import
RemoteOpenAIServer
MODELS
=
[
MODELS
=
[
...
@@ -95,7 +95,8 @@ async def test_matryoshka(model_info: EmbedModelInfo,
...
@@ -95,7 +95,8 @@ async def test_matryoshka(model_info: EmbedModelInfo,
assert
len
(
embeddings
.
data
[
0
].
embedding
)
==
dimensions
assert
len
(
embeddings
.
data
[
0
].
embedding
)
==
dimensions
vllm_outputs
=
[
d
.
embedding
for
d
in
embeddings
.
data
]
vllm_outputs
=
[
d
.
embedding
for
d
in
embeddings
.
data
]
correctness_test
(
hf_model
,
prompts
,
vllm_outputs
,
dimensions
)
run_embedding_correctness_test
(
hf_model
,
prompts
,
vllm_outputs
,
dimensions
)
if
model_info
.
is_matryoshka
:
if
model_info
.
is_matryoshka
:
valid_dimensions
:
list
[
Optional
[
int
]]
=
[
None
]
valid_dimensions
:
list
[
Optional
[
int
]]
=
[
None
]
...
...
tests/entrypoints/openai/test_openai_schema.py
View file @
7a985548
...
@@ -44,6 +44,6 @@ schema = schemathesis.from_pytest_fixture("get_schema")
...
@@ -44,6 +44,6 @@ schema = schemathesis.from_pytest_fixture("get_schema")
@
schema
.
parametrize
()
@
schema
.
parametrize
()
@
schema
.
override
(
headers
=
{
"Content-Type"
:
"application/json"
})
@
schema
.
override
(
headers
=
{
"Content-Type"
:
"application/json"
})
async
def
test_openapi_stateless
(
case
):
def
test_openapi_stateless
(
case
:
schemathesis
.
Case
):
#No need to verify SSL certificate for localhost
#No need to verify SSL certificate for localhost
await
case
.
call_and_validate
(
verify
=
False
)
case
.
call_and_validate
(
verify
=
False
)
tests/entrypoints/openai/test_serving_chat.py
View file @
7a985548
...
@@ -272,3 +272,43 @@ def test_serving_chat_could_load_correct_generation_config():
...
@@ -272,3 +272,43 @@ def test_serving_chat_could_load_correct_generation_config():
assert
mock_engine
.
generate
.
call_args
.
args
[
1
].
temperature
==
0.0
assert
mock_engine
.
generate
.
call_args
.
args
[
1
].
temperature
==
0.0
assert
mock_engine
.
generate
.
call_args
.
args
[
1
].
repetition_penalty
==
1.05
assert
mock_engine
.
generate
.
call_args
.
args
[
1
].
repetition_penalty
==
1.05
def test_serving_chat_did_set_correct_cache_salt():
    """Check that ``cache_salt`` reaches the engine prompt only when set."""
    model_config = MockModelConfig()

    engine = MagicMock(spec=MQLLMEngineClient)
    engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
    engine.errored = False

    # Initialize the serving chat
    serving_models = OpenAIServingModels(engine_client=engine,
                                         base_model_paths=BASE_MODEL_PATHS,
                                         model_config=model_config)
    chat_handler = OpenAIServingChat(engine,
                                     model_config,
                                     serving_models,
                                     response_role="assistant",
                                     chat_template=CHAT_TEMPLATE,
                                     chat_template_content_format="auto",
                                     request_logger=None)

    def engine_prompt_for(request):
        # Any exception from the call is suppressed; only the first
        # positional argument of the latest engine.generate call is inspected.
        with suppress(Exception):
            asyncio.run(chat_handler.create_chat_completion(request))
        return engine.generate.call_args.args[0]

    # Test cache_salt
    chat_request = ChatCompletionRequest(
        model=MODEL_NAME,
        messages=[{
            "role": "user",
            "content": "what is 1+1?"
        }],
    )

    # By default cache_salt in the engine prompt is not set
    assert "cache_salt" not in engine_prompt_for(chat_request)

    # Test with certain cache_salt
    chat_request.cache_salt = "test_salt"
    assert engine_prompt_for(chat_request)["cache_salt"] == "test_salt"
tests/entrypoints/openai/test_tokenization.py
View file @
7a985548
...
@@ -145,6 +145,83 @@ async def test_tokenize_chat(
...
@@ -145,6 +145,83 @@ async def test_tokenize_chat(
}
}
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "model_name,tokenizer_name",
    [(MODEL_NAME, MODEL_NAME), ("zephyr-lora2", "zephyr-lora2")],
    indirect=["tokenizer_name"],
)
async def test_tokenize_chat_with_tools(
    server: RemoteOpenAIServer,
    model_name: str,
    tokenizer_name: str,
):
    """Compare the server's ``tokenize`` result against a local tokenizer.

    A chat conversation plus a tool definition is rendered locally with the
    tokenizer's chat template and encoded; the server must return the same
    tokens, count, and ``max_model_len`` for each flag combination tried.
    """
    local_tokenizer = get_tokenizer(tokenizer_name=tokenizer_name,
                                    tokenizer_mode="fast")
    tokenize_url = server.url_for("tokenize")

    for add_generation in (False, True):
        for add_special in (False, True):
            # Fresh conversation/tools for every (add_generation, add_special)
            # pair; the conversation may be extended in the inner loop.
            conversation = [{
                "role": "user",
                "content": "What's the weather like in Paris today?",
            }]
            location_schema = {"type": "string"}
            tools = [{
                "type": "function",
                "function": {
                    "name": "get_weather",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "location": location_schema
                        },
                    },
                },
            }]

            for continue_final in (False, True):
                if continue_final:
                    # This flag combination is skipped by the test.
                    if add_generation:
                        continue
                    conversation.append({
                        "role": "assistant",
                        "content": "Sure,"
                    })

                rendered = local_tokenizer.apply_chat_template(
                    add_generation_prompt=add_generation,
                    continue_final_message=continue_final,
                    conversation=conversation,
                    tools=tools,
                    tokenize=False,
                )
                expected_tokens = local_tokenizer.encode(
                    rendered, add_special_tokens=add_special)

                payload = {
                    "add_generation_prompt": add_generation,
                    "continue_final_message": continue_final,
                    "add_special_tokens": add_special,
                    "messages": conversation,
                    "model": model_name,
                    "tools": tools,
                }
                reply = requests.post(tokenize_url, json=payload)
                reply.raise_for_status()

                assert reply.json() == {
                    "tokens": expected_tokens,
                    "count": len(expected_tokens),
                    "max_model_len": 8192,
                }
@
pytest
.
mark
.
asyncio
@
pytest
.
mark
.
asyncio
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
"model_name,tokenizer_name"
,
"model_name,tokenizer_name"
,
...
...
tests/entrypoints/openai/test_truncation.py
0 → 100644
View file @
7a985548
# SPDX-License-Identifier: Apache-2.0
from
typing
import
Any
import
openai
import
pytest
import
pytest_asyncio
from
tests.utils
import
RemoteOpenAIServer
MODEL_NAME
=
"sentence-transformers/all-MiniLM-L12-v2"
max_model_len
=
128
input
=
"""Immerse yourself in the enchanting chronicle of calculus, a
mathematical domain that has radically transformed our comprehension of
change and motion. Despite its roots in ancient civilizations, the
formal birth of calculus predominantly occurred in the 17th century,
primarily under the influential guidance of Sir Isaac Newton and Gottfried
Wilhelm Leibniz. The earliest traces of calculus concepts are found in
ancient Greek mathematics,most notably in the works of Eudoxus and
Archimedes, around 300 BCE. They utilized the 'method of exhaustion'—a
technique for computing areas and volumes through the use of finite sums.
This methodology laid crucial foundational work for integral calculus.
In the 17th century, both Newton and Leibniz independently pioneered
calculus, each contributing unique perspectives that would shape this new
field."""
@pytest.fixture(scope="module")
def server():
    """Start one embedding server per module, capped at ``max_model_len``."""
    launch_args = ["--task", "embed"]
    launch_args += ["--dtype", "bfloat16"]
    launch_args += ["--enforce-eager"]
    launch_args += ["--max-model-len", str(max_model_len)]

    with RemoteOpenAIServer(MODEL_NAME, launch_args) as running_server:
        yield running_server
@pytest_asyncio.fixture
async def client(server):
    """Yield an async client obtained from the module's embedding server."""
    client_ctx = server.get_async_client()
    async with client_ctx as embedding_client:
        yield embedding_client
@pytest.mark.asyncio
async def test_smaller_truncation_size(client: openai.AsyncOpenAI):
    """A truncation size of 10 must yield exactly 10 prompt tokens."""
    truncation_size = 10
    kwargs: dict[str, Any] = dict(
        model=MODEL_NAME,
        input=input,
        truncate_prompt_tokens=truncation_size,
    )

    # The request body is a shallow copy of kwargs.
    result = await client.post(path="embeddings",
                               cast_to=object,
                               body=dict(kwargs))

    reported_prompt_tokens = result["usage"]["prompt_tokens"]
    assert reported_prompt_tokens == truncation_size
@pytest.mark.asyncio
async def test_bigger_truncation_size(client: openai.AsyncOpenAI):
    """A truncation size above ``max_model_len`` must be rejected with a 400."""
    truncation_size = max_model_len + 1
    kwargs: dict[str, Any] = {
        "model": MODEL_NAME,
        "input": input,
        "truncate_prompt_tokens": truncation_size
    }

    with pytest.raises(openai.BadRequestError) as err:
        await client.post(path="embeddings", cast_to=object, body={**kwargs})

    # Inspect the captured exception after the `raises` block. Previously the
    # call result was assigned over `err`, and the message check compared the
    # str() of the ExceptionInfo wrapper against a multi-line literal, so it
    # was either never executed or could never match the real error text.
    assert err.value.status_code == 400
    error_text = str(err.value)
    # Fragments taken from the message the original assertion expected.
    assert "truncate_prompt_tokens value" in error_text
    assert f"({truncation_size})" in error_text
    assert f"max_model_len ({max_model_len})" in error_text
    assert "Please, select a smaller truncation size." in error_text
@pytest.mark.asyncio
async def test_max_truncation_size(client: openai.AsyncOpenAI):
    """A truncation size of -1 must yield ``max_model_len`` prompt tokens."""
    truncation_size = -1
    kwargs: dict[str, Any] = dict(
        model=MODEL_NAME,
        input=input,
        truncate_prompt_tokens=truncation_size,
    )

    # The request body is a shallow copy of kwargs.
    result = await client.post(path="embeddings",
                               cast_to=object,
                               body=dict(kwargs))

    reported_prompt_tokens = result["usage"]["prompt_tokens"]
    assert reported_prompt_tokens == max_model_len
tests/entrypoints/openai/tool_parsers/utils.py
View file @
7a985548
...
@@ -32,7 +32,7 @@ class StreamingToolReconstructor:
...
@@ -32,7 +32,7 @@ class StreamingToolReconstructor:
assert
len
(
delta
.
tool_calls
)
<
2
,
(
assert
len
(
delta
.
tool_calls
)
<
2
,
(
"Streaming should include only one tool call per update."
)
"Streaming should include only one tool call per update."
)
for
call_delta
in
delta
.
tool_calls
:
for
call_delta
in
delta
.
tool_calls
:
assert
call_delta
.
type
==
"function"
,
(
assert
call_delta
.
type
is
None
or
call_delta
.
type
==
"function"
,
(
"Streaming tool calls should only emit function calls. Got "
"Streaming tool calls should only emit function calls. Got "
f
"
{
call_delta
.
type
}
"
)
f
"
{
call_delta
.
type
}
"
)
current_tool_call
=
self
.
tool_calls
[
current_tool_call
=
self
.
tool_calls
[
...
...
tests/entrypoints/test_chat_utils.py
View file @
7a985548
...
@@ -4,8 +4,6 @@ import warnings
...
@@ -4,8 +4,6 @@ import warnings
from
typing
import
Optional
from
typing
import
Optional
import
pytest
import
pytest
from
packaging.version
import
Version
from
transformers
import
__version__
as
TRANSFORMERS_VERSION
from
vllm.assets.image
import
ImageAsset
from
vllm.assets.image
import
ImageAsset
from
vllm.config
import
ModelConfig
from
vllm.config
import
ModelConfig
...
@@ -19,6 +17,7 @@ from vllm.multimodal import MultiModalDataDict
...
@@ -19,6 +17,7 @@ from vllm.multimodal import MultiModalDataDict
from
vllm.multimodal.utils
import
encode_image_base64
from
vllm.multimodal.utils
import
encode_image_base64
from
vllm.transformers_utils.tokenizer_group
import
TokenizerGroup
from
vllm.transformers_utils.tokenizer_group
import
TokenizerGroup
from
..models.registry
import
HF_EXAMPLE_MODELS
from
..utils
import
VLLM_PATH
from
..utils
import
VLLM_PATH
EXAMPLES_DIR
=
VLLM_PATH
/
"examples"
EXAMPLES_DIR
=
VLLM_PATH
/
"examples"
...
@@ -772,6 +771,7 @@ def test_multimodal_image_parsing_matches_hf(model, image_url):
...
@@ -772,6 +771,7 @@ def test_multimodal_image_parsing_matches_hf(model, image_url):
enable_lora
=
False
,
enable_lora
=
False
,
max_num_seqs
=
5
,
max_num_seqs
=
5
,
max_input_length
=
None
,
max_input_length
=
None
,
trust_remote_code
=
model_config
.
trust_remote_code
,
)
)
tokenizer
=
tokenizer_group
.
tokenizer
tokenizer
=
tokenizer_group
.
tokenizer
...
@@ -793,10 +793,10 @@ def test_multimodal_image_parsing_matches_hf(model, image_url):
...
@@ -793,10 +793,10 @@ def test_multimodal_image_parsing_matches_hf(model, image_url):
)
)
vllm_result
=
apply_hf_chat_template
(
vllm_result
=
apply_hf_chat_template
(
tokenizer
,
tokenizer
=
tokenizer
,
trust_remote_code
=
model_config
.
trust_remote_code
,
conversation
=
conversation
,
conversation
=
conversation
,
chat_template
=
None
,
chat_template
=
None
,
model_config
=
model_config
,
tools
=
None
,
tools
=
None
,
add_generation_prompt
=
True
,
add_generation_prompt
=
True
,
)
)
...
@@ -813,6 +813,16 @@ def test_multimodal_image_parsing_matches_hf(model, image_url):
...
@@ -813,6 +813,16 @@ def test_multimodal_image_parsing_matches_hf(model, image_url):
@
pytest
.
mark
.
parametrize
(
"use_tools"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"use_tools"
,
[
True
,
False
])
def
test_resolve_hf_chat_template
(
sample_json_schema
,
model
,
use_tools
):
def
test_resolve_hf_chat_template
(
sample_json_schema
,
model
,
use_tools
):
"""checks that chat_template is a dict type for HF models."""
"""checks that chat_template is a dict type for HF models."""
model_info
=
HF_EXAMPLE_MODELS
.
find_hf_info
(
model
)
model_info
.
check_available_online
(
on_fail
=
"skip"
)
model_config
=
ModelConfig
(
model
,
tokenizer
=
model_info
.
tokenizer
or
model
,
tokenizer_mode
=
model_info
.
tokenizer_mode
,
trust_remote_code
=
model_info
.
trust_remote_code
,
hf_overrides
=
model_info
.
hf_overrides
,
)
# Build the tokenizer group and grab the underlying tokenizer
# Build the tokenizer group and grab the underlying tokenizer
tokenizer_group
=
TokenizerGroup
(
tokenizer_group
=
TokenizerGroup
(
...
@@ -820,6 +830,7 @@ def test_resolve_hf_chat_template(sample_json_schema, model, use_tools):
...
@@ -820,6 +830,7 @@ def test_resolve_hf_chat_template(sample_json_schema, model, use_tools):
enable_lora
=
False
,
enable_lora
=
False
,
max_num_seqs
=
5
,
max_num_seqs
=
5
,
max_input_length
=
None
,
max_input_length
=
None
,
trust_remote_code
=
model_config
.
trust_remote_code
,
)
)
tokenizer
=
tokenizer_group
.
tokenizer
tokenizer
=
tokenizer_group
.
tokenizer
...
@@ -837,7 +848,7 @@ def test_resolve_hf_chat_template(sample_json_schema, model, use_tools):
...
@@ -837,7 +848,7 @@ def test_resolve_hf_chat_template(sample_json_schema, model, use_tools):
tokenizer
,
tokenizer
,
chat_template
=
None
,
chat_template
=
None
,
tools
=
tools
,
tools
=
tools
,
trust_remote_code
=
True
,
model_config
=
model_config
,
)
)
assert
isinstance
(
chat_template
,
str
)
assert
isinstance
(
chat_template
,
str
)
...
@@ -857,15 +868,23 @@ def test_resolve_hf_chat_template(sample_json_schema, model, use_tools):
...
@@ -857,15 +868,23 @@ def test_resolve_hf_chat_template(sample_json_schema, model, use_tools):
)
)
# yapf: enable
# yapf: enable
def
test_resolve_content_format_hf_defined
(
model
,
expected_format
):
def
test_resolve_content_format_hf_defined
(
model
,
expected_format
):
if
model
==
QWEN25VL_MODEL_ID
and
Version
(
TRANSFORMERS_VERSION
)
<
Version
(
model_info
=
HF_EXAMPLE_MODELS
.
find_hf_info
(
model
)
"4.49.0"
):
model_info
.
check_available_online
(
on_fail
=
"skip"
)
pytest
.
skip
(
"Qwen2.5-VL requires transformers>=4.49.0"
)
model_config
=
ModelConfig
(
model
,
tokenizer
=
model_info
.
tokenizer
or
model
,
tokenizer_mode
=
model_info
.
tokenizer_mode
,
trust_remote_code
=
model_info
.
trust_remote_code
,
hf_overrides
=
model_info
.
hf_overrides
,
)
tokenizer_group
=
TokenizerGroup
(
tokenizer_group
=
TokenizerGroup
(
model
,
model
,
enable_lora
=
False
,
enable_lora
=
False
,
max_num_seqs
=
5
,
max_num_seqs
=
5
,
max_input_length
=
None
,
max_input_length
=
None
,
trust_remote_code
=
model_config
.
trust_remote_code
,
)
)
tokenizer
=
tokenizer_group
.
tokenizer
tokenizer
=
tokenizer_group
.
tokenizer
...
@@ -874,7 +893,7 @@ def test_resolve_content_format_hf_defined(model, expected_format):
...
@@ -874,7 +893,7 @@ def test_resolve_content_format_hf_defined(model, expected_format):
tokenizer
,
tokenizer
,
chat_template
=
None
,
chat_template
=
None
,
tools
=
None
,
tools
=
None
,
trust_remote_code
=
True
,
model_config
=
model_config
,
)
)
assert
isinstance
(
chat_template
,
str
)
assert
isinstance
(
chat_template
,
str
)
...
@@ -888,7 +907,66 @@ def test_resolve_content_format_hf_defined(model, expected_format):
...
@@ -888,7 +907,66 @@ def test_resolve_content_format_hf_defined(model, expected_format):
None
,
None
,
"auto"
,
"auto"
,
tokenizer
,
tokenizer
,
trust_remote_code
=
True
,
model_config
=
model_config
,
)
assert
resolved_format
==
expected_format
# yapf: disable
@
pytest
.
mark
.
parametrize
(
(
"model"
,
"expected_format"
),
[(
"Salesforce/blip2-opt-2.7b"
,
"string"
),
(
"facebook/chameleon-7b"
,
"string"
),
(
"deepseek-ai/deepseek-vl2-tiny"
,
"string"
),
(
"microsoft/Florence-2-base"
,
"string"
),
(
"adept/fuyu-8b"
,
"string"
),
(
"google/paligemma-3b-mix-224"
,
"string"
),
(
"Qwen/Qwen-VL"
,
"string"
),
(
"Qwen/Qwen-VL-Chat"
,
"string"
)],
)
# yapf: enable
def
test_resolve_content_format_fallbacks
(
model
,
expected_format
):
model_info
=
HF_EXAMPLE_MODELS
.
find_hf_info
(
model
)
model_info
.
check_available_online
(
on_fail
=
"skip"
)
model_config
=
ModelConfig
(
model
,
tokenizer
=
model_info
.
tokenizer
or
model
,
tokenizer_mode
=
model_info
.
tokenizer_mode
,
trust_remote_code
=
model_info
.
trust_remote_code
,
hf_overrides
=
model_info
.
hf_overrides
,
)
tokenizer_group
=
TokenizerGroup
(
model_config
.
tokenizer
,
enable_lora
=
False
,
max_num_seqs
=
5
,
max_input_length
=
None
,
trust_remote_code
=
model_config
.
trust_remote_code
,
)
tokenizer
=
tokenizer_group
.
tokenizer
# Test detecting the tokenizer's chat_template
chat_template
=
resolve_hf_chat_template
(
tokenizer
,
chat_template
=
None
,
tools
=
None
,
model_config
=
model_config
,
)
assert
isinstance
(
chat_template
,
str
)
print
(
"[TEXT]"
)
print
(
chat_template
)
print
(
"[AST]"
)
print
(
_try_extract_ast
(
chat_template
))
resolved_format
=
resolve_chat_template_content_format
(
None
,
# Test detecting the tokenizer's chat_template
None
,
"auto"
,
tokenizer
,
model_config
=
model_config
,
)
)
assert
resolved_format
==
expected_format
assert
resolved_format
==
expected_format
...
@@ -899,17 +977,13 @@ def test_resolve_content_format_hf_defined(model, expected_format):
...
@@ -899,17 +977,13 @@ def test_resolve_content_format_hf_defined(model, expected_format):
(
"template_path"
,
"expected_format"
),
(
"template_path"
,
"expected_format"
),
[(
"template_alpaca.jinja"
,
"string"
),
[(
"template_alpaca.jinja"
,
"string"
),
(
"template_baichuan.jinja"
,
"string"
),
(
"template_baichuan.jinja"
,
"string"
),
(
"template_blip2.jinja"
,
"string"
),
(
"template_chatglm.jinja"
,
"string"
),
(
"template_chatglm.jinja"
,
"string"
),
(
"template_chatglm2.jinja"
,
"string"
),
(
"template_chatglm2.jinja"
,
"string"
),
(
"template_chatml.jinja"
,
"string"
),
(
"template_chatml.jinja"
,
"string"
),
(
"template_deepseek_vl2.jinja"
,
"string"
),
(
"template_dse_qwen2_vl.jinja"
,
"openai"
),
(
"template_dse_qwen2_vl.jinja"
,
"openai"
),
(
"template_falcon_180b.jinja"
,
"string"
),
(
"template_falcon_180b.jinja"
,
"string"
),
(
"template_falcon.jinja"
,
"string"
),
(
"template_falcon.jinja"
,
"string"
),
(
"template_florence2.jinja"
,
"string"
),
(
"template_inkbot.jinja"
,
"string"
),
(
"template_inkbot.jinja"
,
"string"
),
(
"template_llava.jinja"
,
"string"
),
(
"template_teleflm.jinja"
,
"string"
),
(
"template_teleflm.jinja"
,
"string"
),
(
"template_vlm2vec.jinja"
,
"openai"
),
(
"template_vlm2vec.jinja"
,
"openai"
),
(
"tool_chat_template_granite_20b_fc.jinja"
,
"string"
),
(
"tool_chat_template_granite_20b_fc.jinja"
,
"string"
),
...
@@ -922,11 +996,18 @@ def test_resolve_content_format_hf_defined(model, expected_format):
...
@@ -922,11 +996,18 @@ def test_resolve_content_format_hf_defined(model, expected_format):
)
)
# yapf: enable
# yapf: enable
def
test_resolve_content_format_examples
(
template_path
,
expected_format
):
def
test_resolve_content_format_examples
(
template_path
,
expected_format
):
model_config
=
ModelConfig
(
PHI3V_MODEL_ID
,
# Dummy
tokenizer
=
PHI3V_MODEL_ID
,
# Dummy
trust_remote_code
=
True
,
)
tokenizer_group
=
TokenizerGroup
(
tokenizer_group
=
TokenizerGroup
(
PHI3V_MODEL_ID
,
PHI3V_MODEL_ID
,
# Dummy
enable_lora
=
False
,
enable_lora
=
False
,
max_num_seqs
=
5
,
max_num_seqs
=
5
,
max_input_length
=
None
,
max_input_length
=
None
,
trust_remote_code
=
model_config
.
trust_remote_code
,
)
)
dummy_tokenizer
=
tokenizer_group
.
tokenizer
dummy_tokenizer
=
tokenizer_group
.
tokenizer
dummy_tokenizer
.
chat_template
=
None
dummy_tokenizer
.
chat_template
=
None
...
@@ -944,7 +1025,7 @@ def test_resolve_content_format_examples(template_path, expected_format):
...
@@ -944,7 +1025,7 @@ def test_resolve_content_format_examples(template_path, expected_format):
None
,
None
,
"auto"
,
"auto"
,
dummy_tokenizer
,
dummy_tokenizer
,
trust_remote_code
=
True
,
model_config
=
model_config
,
)
)
assert
resolved_format
==
expected_format
assert
resolved_format
==
expected_format
tests/kernels/attention/test_attention_selector.py
View file @
7a985548
...
@@ -102,7 +102,10 @@ def test_env(
...
@@ -102,7 +102,10 @@ def test_env(
block_size
,
block_size
,
False
,
False
,
use_mla
=
use_mla
)
use_mla
=
use_mla
)
assert
backend
.
get_name
()
==
name
if
use_v1
and
name
!=
"TRITON_MLA"
:
assert
backend
.
get_name
()
==
f
"
{
name
}
_VLLM_V1"
else
:
assert
backend
.
get_name
()
==
name
else
:
else
:
with
pytest
.
raises
(
ValueError
)
as
exc_info
:
with
pytest
.
raises
(
ValueError
)
as
exc_info
:
get_attn_backend
(
16
,
get_attn_backend
(
16
,
...
@@ -185,8 +188,9 @@ def test_flash_attn(monkeypatch: pytest.MonkeyPatch):
...
@@ -185,8 +188,9 @@ def test_flash_attn(monkeypatch: pytest.MonkeyPatch):
m
.
setenv
(
STR_BACKEND_ENV_VAR
,
STR_FLASH_ATTN_VAL
)
m
.
setenv
(
STR_BACKEND_ENV_VAR
,
STR_FLASH_ATTN_VAL
)
# Unsupported CUDA arch
# Unsupported CUDA arch
monkeypatch
.
setattr
(
torch
.
cuda
,
"get_device_capability"
,
lambda
:
monkeypatch
.
setattr
(
torch
.
cuda
,
(
7
,
5
))
"get_device_capability"
,
lambda
_
=
None
:
(
7
,
5
))
backend
=
get_attn_backend
(
16
,
torch
.
float16
,
None
,
16
,
False
)
backend
=
get_attn_backend
(
16
,
torch
.
float16
,
None
,
16
,
False
)
assert
backend
.
get_name
()
!=
STR_FLASH_ATTN_VAL
assert
backend
.
get_name
()
!=
STR_FLASH_ATTN_VAL
...
...
tests/kernels/attention/test_flashmla.py
View file @
7a985548
...
@@ -5,11 +5,11 @@ import random
...
@@ -5,11 +5,11 @@ import random
import
pytest
import
pytest
import
torch
import
torch
import
triton
from
vllm.attention.ops.flashmla
import
(
flash_mla_with_kvcache
,
from
vllm.attention.ops.flashmla
import
(
flash_mla_with_kvcache
,
get_mla_metadata
,
get_mla_metadata
,
is_flashmla_supported
)
is_flashmla_supported
)
from
vllm.triton_utils
import
triton
def
cal_diff
(
x
:
torch
.
Tensor
,
y
:
torch
.
Tensor
,
name
:
str
)
->
None
:
def
cal_diff
(
x
:
torch
.
Tensor
,
y
:
torch
.
Tensor
,
name
:
str
)
->
None
:
...
...
tests/kernels/attention/test_rocm_attention_selector.py
View file @
7a985548
...
@@ -48,7 +48,8 @@ def test_selector(monkeypatch: pytest.MonkeyPatch):
...
@@ -48,7 +48,8 @@ def test_selector(monkeypatch: pytest.MonkeyPatch):
m
.
setenv
(
STR_BACKEND_ENV_VAR
,
"ROCM_AITER_MLA"
)
m
.
setenv
(
STR_BACKEND_ENV_VAR
,
"ROCM_AITER_MLA"
)
backend
=
get_attn_backend
(
576
,
torch
.
bfloat16
,
"auto"
,
1
,
False
,
backend
=
get_attn_backend
(
576
,
torch
.
bfloat16
,
"auto"
,
1
,
False
,
False
,
True
)
False
,
True
)
assert
backend
.
get_name
()
==
"ROCM_AITER_MLA"
assert
(
backend
.
get_name
()
==
"ROCM_AITER_MLA"
or
backend
.
get_name
()
==
"ROCM_AITER_MLA_VLLM_V1"
)
# If attention backend is None
# If attention backend is None
# If use_mla is true
# If use_mla is true
...
@@ -58,4 +59,5 @@ def test_selector(monkeypatch: pytest.MonkeyPatch):
...
@@ -58,4 +59,5 @@ def test_selector(monkeypatch: pytest.MonkeyPatch):
m
.
setenv
(
"VLLM_ROCM_USE_AITER"
,
"1"
)
m
.
setenv
(
"VLLM_ROCM_USE_AITER"
,
"1"
)
backend
=
get_attn_backend
(
576
,
torch
.
bfloat16
,
"auto"
,
1
,
False
,
backend
=
get_attn_backend
(
576
,
torch
.
bfloat16
,
"auto"
,
1
,
False
,
False
,
True
)
False
,
True
)
assert
backend
.
get_name
()
==
"ROCM_AITER_MLA"
assert
(
backend
.
get_name
()
==
"ROCM_AITER_MLA"
or
backend
.
get_name
()
==
"ROCM_AITER_MLA_VLLM_V1"
)
tests/kernels/attention/test_triton_unified_attention.py
0 → 100644
View file @
7a985548
# SPDX-License-Identifier: Apache-2.0
from
typing
import
Optional
import
pytest
import
torch
from
vllm.attention.ops.triton_unified_attention
import
unified_attention
from
vllm.platforms
import
current_platform
# Parametrization grids for test_triton_unified_attn.

# (num_query_heads, num_kv_heads) pairs; the test asserts that the first
# value is a multiple of the second.
NUM_HEADS = [(4, 4), (8, 2), (16, 2)]
HEAD_SIZES = [128, 256]
BLOCK_SIZES = [16, 32]

DTYPES = [torch.float16, torch.bfloat16]
# None runs the kernel on unquantized Q/K/V; otherwise the test casts
# Q/K/V to the given dtype before calling the kernel.
QDTYPES = [None, torch.float8_e4m3fn]
# one value large enough to test overflow in index calculation.
# one value small enough to test the schema op check
NUM_BLOCKS = [32768, 2048]
def ref_paged_attn(
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    query_lens: list[int],
    kv_lens: list[int],
    block_tables: torch.Tensor,
    scale: float,
    sliding_window: Optional[int] = None,
    soft_cap: Optional[float] = None,
) -> torch.Tensor:
    """Compute causal paged attention with plain PyTorch ops.

    Sequences are processed one at a time: the K/V rows of a sequence are
    gathered from the paged caches through its row of ``block_tables``,
    truncated to the sequence's KV length, and attended to by the matching
    slice of ``query``.

    Args:
        query: Packed queries of all sequences, indexed as
            ``[token, query_head, head_size]``.
        key_cache: Paged keys, indexed as
            ``[block, block_size, num_kv_heads, head_size]``.
        value_cache: Paged values, same layout as ``key_cache``.
        query_lens: Number of query tokens of each sequence.
        kv_lens: Number of KV tokens of each sequence (a list or any
            indexable of integer-like values, e.g. a 1-D tensor).
        block_tables: Per-sequence block indices into the caches.
        scale: Factor applied to the queries before the QK product.
        sliding_window: If given, keys further in the past than this many
            positions are masked out as well.
        soft_cap: If given and positive, logits are squashed with
            ``soft_cap * tanh(logits / soft_cap)`` before masking.

    Returns:
        The attention outputs of all sequences concatenated along dim 0,
        in the same packed layout as ``query``.
    """
    num_seqs = len(query_lens)
    block_tables = block_tables.cpu().numpy()
    _, block_size, num_kv_heads, head_size = key_cache.shape

    outputs: list[torch.Tensor] = []
    start_idx = 0
    for i in range(num_seqs):
        # Lengths may arrive as 0-d tensors; normalise them to Python ints
        # so slicing and the mask arithmetic below are plain integer ops.
        query_len = int(query_lens[i])
        kv_len = int(kv_lens[i])
        # Scale out of place: the slice is a view of `query`, so an
        # in-place multiply would silently rescale the caller's tensor.
        q = query[start_idx:start_idx + query_len] * scale

        # Ceil-divide to find how many cache blocks this sequence spans.
        num_kv_blocks = (kv_len + block_size - 1) // block_size
        block_indices = block_tables[i, :num_kv_blocks]

        k = key_cache[block_indices].view(-1, num_kv_heads, head_size)
        k = k[:kv_len]
        v = value_cache[block_indices].view(-1, num_kv_heads, head_size)
        v = v[:kv_len]

        # Grouped-query attention: replicate KV heads to match Q heads.
        if q.shape[1] != k.shape[1]:
            k = torch.repeat_interleave(k, q.shape[1] // k.shape[1], dim=1)
            v = torch.repeat_interleave(v, q.shape[1] // v.shape[1], dim=1)
        attn = torch.einsum("qhd,khd->hqk", q, k).float()
        empty_mask = torch.ones(query_len, kv_len)
        # Causal mask: the queries are the last `query_len` positions of
        # the sequence, hence the diagonal offset.
        mask = torch.triu(empty_mask, diagonal=kv_len - query_len + 1).bool()
        if sliding_window is not None:
            sliding_window_mask = torch.triu(
                empty_mask,
                diagonal=kv_len - (query_len + sliding_window) + 1,
            ).bool().logical_not()
            mask |= sliding_window_mask
        if soft_cap is not None and soft_cap > 0:
            attn = soft_cap * torch.tanh(attn / soft_cap)
        attn.masked_fill_(mask, float("-inf"))
        attn = torch.softmax(attn, dim=-1).to(v.dtype)
        out = torch.einsum("hqk,khd->qhd", attn, v)

        outputs.append(out)
        start_idx += query_len

    return torch.cat(outputs, dim=0)
@pytest.mark.parametrize("seq_lens",
                         [[(1, 1328), (5, 18), (129, 463)],
                          [(1, 523), (1, 37), (1, 2011)]])
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("sliding_window", [None, 256])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("soft_cap", [None, 10.0, 50.0])
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("q_dtype", QDTYPES)
@torch.inference_mode()
def test_triton_unified_attn(
    seq_lens: list[tuple[int, int]],
    num_heads: tuple[int, int],
    head_size: int,
    sliding_window: Optional[int],
    dtype: torch.dtype,
    block_size: int,
    soft_cap: Optional[float],
    num_blocks: int,
    q_dtype: Optional[torch.dtype],
) -> None:
    """Compare ``unified_attention`` against ``ref_paged_attn``.

    Args:
        seq_lens: One ``(query_len, kv_len)`` pair per sequence.
        num_heads: ``(num_query_heads, num_kv_heads)``.
        head_size: Size of each attention head.
        sliding_window: Optional sliding-window length.
        dtype: dtype of the unquantized Q/K/V tensors.
        block_size: Number of tokens per KV-cache block.
        soft_cap: Optional logit soft cap (``None`` is passed to the
            kernel as 0).
        num_blocks: Number of blocks in the KV caches.
        q_dtype: If not ``None``, Q/K/V are cast to this dtype before the
            kernel call and looser tolerances are used.
    """
    torch.set_default_device("cuda")

    if q_dtype is not None and q_dtype.itemsize < 2 and block_size < 32:
        pytest.skip("block size must be at least 32 for fp8")

    current_platform.seed_everything(0)
    num_seqs = len(seq_lens)
    query_lens = [x[0] for x in seq_lens]
    kv_lens = [x[1] for x in seq_lens]
    num_query_heads = num_heads[0]
    num_kv_heads = num_heads[1]
    assert num_query_heads % num_kv_heads == 0
    max_query_len = max(query_lens)
    max_kv_len = max(kv_lens)
    window_size = ((sliding_window - 1, 0) if sliding_window is not None else
                   (-1, -1))
    scale = head_size**-0.5

    query = torch.randn(sum(query_lens),
                        num_query_heads,
                        head_size,
                        dtype=dtype)
    key_cache = torch.randn(num_blocks,
                            block_size,
                            num_kv_heads,
                            head_size,
                            dtype=dtype)
    value_cache = torch.randn_like(key_cache)
    # Cumulative query offsets with a leading 0.
    cu_query_lens = torch.tensor([0] + query_lens,
                                 dtype=torch.int32).cumsum(dim=0,
                                                           dtype=torch.int32)
    kv_lens = torch.tensor(kv_lens, dtype=torch.int32)

    max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size
    block_tables = torch.randint(0,
                                 num_blocks,
                                 (num_seqs, max_num_blocks_per_seq),
                                 dtype=torch.int32)

    output = torch.empty_like(query)

    maybe_quantized_query = query
    maybe_quantized_key_cache = key_cache
    maybe_quantized_value_cache = value_cache
    q_descale = None
    k_descale = None
    v_descale = None
    if q_dtype is not None:
        # QKV are drawn from N(0, 1): no need for a fp8 scaling factor
        maybe_quantized_query = query.to(q_dtype)
        maybe_quantized_key_cache = key_cache.to(q_dtype)
        maybe_quantized_value_cache = value_cache.to(q_dtype)

        scale_shape = (num_seqs, num_kv_heads)
        q_descale = None  # Not yet supported
        # NOTE(review): these descale factors are random while the
        # reference below runs on the unscaled K/V -- confirm how the
        # kernel consumes them.
        k_descale = torch.rand(scale_shape, dtype=torch.float32)
        v_descale = torch.rand(scale_shape, dtype=torch.float32)

    unified_attention(
        q=maybe_quantized_query,
        k=maybe_quantized_key_cache,
        v=maybe_quantized_value_cache,
        out=output,
        cu_seqlens_q=cu_query_lens,
        seqused_k=kv_lens,
        max_seqlen_q=max_query_len,
        max_seqlen_k=max_kv_len,
        softmax_scale=scale,
        causal=True,
        window_size=window_size,
        block_table=block_tables,
        softcap=soft_cap if soft_cap is not None else 0,
        q_descale=q_descale,
        k_descale=k_descale,
        v_descale=v_descale,
    )

    ref_output = ref_paged_attn(
        query=query,
        key_cache=key_cache,
        value_cache=value_cache,
        query_lens=query_lens,
        kv_lens=kv_lens,
        block_tables=block_tables,
        scale=scale,
        sliding_window=sliding_window,
        soft_cap=soft_cap,
    )
    atol, rtol = 1.5e-2, 1e-2
    if q_dtype is not None:
        atol, rtol = 1.5e-1, 1.5e-1
    # assert_close raises on mismatch and already reports the greatest
    # absolute/relative difference; the f-string that used to trail this
    # call in a discarded tuple was never shown and has been removed.
    torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol)
tests/kernels/core/test_pos_encoding.py
View file @
7a985548
...
@@ -21,6 +21,7 @@ SEEDS = [0]
...
@@ -21,6 +21,7 @@ SEEDS = [0]
CUDA_DEVICES
=
[
CUDA_DEVICES
=
[
f
"cuda:
{
i
}
"
for
i
in
range
(
1
if
torch
.
cuda
.
device_count
()
==
1
else
2
)
f
"cuda:
{
i
}
"
for
i
in
range
(
1
if
torch
.
cuda
.
device_count
()
==
1
else
2
)
]
]
USE_KEY
=
[
True
,
False
]
def
_get_flat_tensor_shape
(
batch_size
:
int
,
seq_len
:
int
,
num_heads
:
int
,
def
_get_flat_tensor_shape
(
batch_size
:
int
,
seq_len
:
int
,
num_heads
:
int
,
...
@@ -28,12 +29,20 @@ def _get_flat_tensor_shape(batch_size: int, seq_len: int, num_heads: int,
...
@@ -28,12 +29,20 @@ def _get_flat_tensor_shape(batch_size: int, seq_len: int, num_heads: int,
return
(
batch_size
,
seq_len
,
num_heads
*
head_size
)
return
(
batch_size
,
seq_len
,
num_heads
*
head_size
)
# For testing sliced tensors
def
_get_padded_tensor_shape
(
batch_size
:
int
,
seq_len
:
int
,
num_heads
:
int
,
head_size
:
int
)
->
tuple
[
int
,
...]:
return
(
batch_size
,
seq_len
,
num_heads
,
head_size
+
64
)
def
_get_batch_tensor_shape
(
batch_size
:
int
,
seq_len
:
int
,
num_heads
:
int
,
def
_get_batch_tensor_shape
(
batch_size
:
int
,
seq_len
:
int
,
num_heads
:
int
,
head_size
:
int
)
->
tuple
[
int
,
...]:
head_size
:
int
)
->
tuple
[
int
,
...]:
return
(
batch_size
,
seq_len
,
num_heads
,
head_size
)
return
(
batch_size
,
seq_len
,
num_heads
,
head_size
)
TENSORS_SHAPES_FN
=
[
_get_batch_tensor_shape
,
_get_flat_tensor_shape
]
TENSORS_SHAPES_FN
=
[
_get_batch_tensor_shape
,
_get_flat_tensor_shape
,
_get_padded_tensor_shape
]
@
pytest
.
mark
.
parametrize
(
"is_neox_style"
,
IS_NEOX_STYLE
)
@
pytest
.
mark
.
parametrize
(
"is_neox_style"
,
IS_NEOX_STYLE
)
...
@@ -46,6 +55,7 @@ TENSORS_SHAPES_FN = [_get_batch_tensor_shape, _get_flat_tensor_shape]
...
@@ -46,6 +55,7 @@ TENSORS_SHAPES_FN = [_get_batch_tensor_shape, _get_flat_tensor_shape]
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPES
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPES
)
@
pytest
.
mark
.
parametrize
(
"seed"
,
SEEDS
)
@
pytest
.
mark
.
parametrize
(
"seed"
,
SEEDS
)
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
pytest
.
mark
.
parametrize
(
"use_key"
,
USE_KEY
)
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
def
test_rotary_embedding
(
def
test_rotary_embedding
(
is_neox_style
:
bool
,
is_neox_style
:
bool
,
...
@@ -58,6 +68,7 @@ def test_rotary_embedding(
...
@@ -58,6 +68,7 @@ def test_rotary_embedding(
dtype
:
torch
.
dtype
,
dtype
:
torch
.
dtype
,
seed
:
int
,
seed
:
int
,
device
:
str
,
device
:
str
,
use_key
:
bool
,
max_position
:
int
=
8192
,
max_position
:
int
=
8192
,
base
:
int
=
10000
,
base
:
int
=
10000
,
)
->
None
:
)
->
None
:
...
@@ -74,7 +85,11 @@ def test_rotary_embedding(
...
@@ -74,7 +85,11 @@ def test_rotary_embedding(
positions
=
torch
.
randint
(
0
,
max_position
,
(
batch_size
,
seq_len
))
positions
=
torch
.
randint
(
0
,
max_position
,
(
batch_size
,
seq_len
))
query_shape
=
tensor_shape_fn
(
batch_size
,
seq_len
,
num_heads
,
head_size
)
query_shape
=
tensor_shape_fn
(
batch_size
,
seq_len
,
num_heads
,
head_size
)
query
=
torch
.
randn
(
query_shape
,
dtype
=
dtype
)
query
=
torch
.
randn
(
query_shape
,
dtype
=
dtype
)
key
=
torch
.
randn_like
(
query
)
key
=
torch
.
randn_like
(
query
)
if
use_key
else
None
# slice tensor if required, noop otherwise
query
=
query
[...,
:
head_size
]
key
=
key
[...,
:
head_size
]
if
use_key
else
None
# NOTE(woosuk): The reference implementation should be executed first
# NOTE(woosuk): The reference implementation should be executed first
# because the custom kernel is in-place.
# because the custom kernel is in-place.
...
@@ -85,10 +100,14 @@ def test_rotary_embedding(
...
@@ -85,10 +100,14 @@ def test_rotary_embedding(
ref_query
,
ref_query
,
atol
=
get_default_atol
(
out_query
),
atol
=
get_default_atol
(
out_query
),
rtol
=
get_default_rtol
(
out_query
))
rtol
=
get_default_rtol
(
out_query
))
torch
.
testing
.
assert_close
(
out_key
,
if
use_key
:
ref_key
,
torch
.
testing
.
assert_close
(
out_key
,
atol
=
get_default_atol
(
out_key
),
ref_key
,
rtol
=
get_default_rtol
(
out_key
))
atol
=
get_default_atol
(
out_key
),
rtol
=
get_default_rtol
(
out_key
))
else
:
assert
ref_key
is
None
and
out_key
is
None
,
\
"expected returned key to be None"
@
pytest
.
mark
.
parametrize
(
"is_neox_style"
,
IS_NEOX_STYLE
)
@
pytest
.
mark
.
parametrize
(
"is_neox_style"
,
IS_NEOX_STYLE
)
...
@@ -101,6 +120,7 @@ def test_rotary_embedding(
...
@@ -101,6 +120,7 @@ def test_rotary_embedding(
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPES
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPES
)
@
pytest
.
mark
.
parametrize
(
"seed"
,
SEEDS
)
@
pytest
.
mark
.
parametrize
(
"seed"
,
SEEDS
)
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
pytest
.
mark
.
parametrize
(
"use_key"
,
USE_KEY
)
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
def
test_batched_rotary_embedding
(
def
test_batched_rotary_embedding
(
is_neox_style
:
bool
,
is_neox_style
:
bool
,
...
@@ -113,6 +133,7 @@ def test_batched_rotary_embedding(
...
@@ -113,6 +133,7 @@ def test_batched_rotary_embedding(
dtype
:
torch
.
dtype
,
dtype
:
torch
.
dtype
,
seed
:
int
,
seed
:
int
,
device
:
str
,
device
:
str
,
use_key
:
bool
,
max_position
:
int
=
8192
,
max_position
:
int
=
8192
,
base
:
int
=
10000
,
base
:
int
=
10000
,
)
->
None
:
)
->
None
:
...
@@ -129,7 +150,11 @@ def test_batched_rotary_embedding(
...
@@ -129,7 +150,11 @@ def test_batched_rotary_embedding(
positions
=
torch
.
randint
(
0
,
max_position
,
(
batch_size
,
seq_len
))
positions
=
torch
.
randint
(
0
,
max_position
,
(
batch_size
,
seq_len
))
query_shape
=
tensor_shape_fn
(
batch_size
,
seq_len
,
num_heads
,
head_size
)
query_shape
=
tensor_shape_fn
(
batch_size
,
seq_len
,
num_heads
,
head_size
)
query
=
torch
.
randn
(
query_shape
,
dtype
=
dtype
)
query
=
torch
.
randn
(
query_shape
,
dtype
=
dtype
)
key
=
torch
.
randn_like
(
query
)
key
=
torch
.
randn_like
(
query
)
if
use_key
else
None
# slice tensor if required, noop otherwise
query
=
query
[...,
:
head_size
]
key
=
key
[...,
:
head_size
]
if
use_key
else
None
# NOTE(woosuk): The reference implementation should be executed first
# NOTE(woosuk): The reference implementation should be executed first
# because the custom kernel is in-place.
# because the custom kernel is in-place.
...
@@ -145,10 +170,14 @@ def test_batched_rotary_embedding(
...
@@ -145,10 +170,14 @@ def test_batched_rotary_embedding(
ref_query
,
ref_query
,
atol
=
get_default_atol
(
out_query
),
atol
=
get_default_atol
(
out_query
),
rtol
=
get_default_rtol
(
out_query
))
rtol
=
get_default_rtol
(
out_query
))
torch
.
testing
.
assert_close
(
out_key
,
if
use_key
:
ref_key
,
torch
.
testing
.
assert_close
(
out_key
,
atol
=
get_default_atol
(
out_key
),
ref_key
,
rtol
=
get_default_rtol
(
out_key
))
atol
=
get_default_atol
(
out_key
),
rtol
=
get_default_rtol
(
out_key
))
else
:
assert
ref_key
is
None
and
out_key
is
None
,
\
"expected returned key to be None"
@
pytest
.
mark
.
parametrize
(
"is_neox_style"
,
IS_NEOX_STYLE
)
@
pytest
.
mark
.
parametrize
(
"is_neox_style"
,
IS_NEOX_STYLE
)
...
@@ -160,6 +189,7 @@ def test_batched_rotary_embedding(
...
@@ -160,6 +189,7 @@ def test_batched_rotary_embedding(
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPES
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPES
)
@
pytest
.
mark
.
parametrize
(
"seed"
,
SEEDS
)
@
pytest
.
mark
.
parametrize
(
"seed"
,
SEEDS
)
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
pytest
.
mark
.
parametrize
(
"use_key"
,
USE_KEY
)
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
def
test_batched_rotary_embedding_multi_lora
(
def
test_batched_rotary_embedding_multi_lora
(
is_neox_style
:
bool
,
is_neox_style
:
bool
,
...
@@ -171,6 +201,7 @@ def test_batched_rotary_embedding_multi_lora(
...
@@ -171,6 +201,7 @@ def test_batched_rotary_embedding_multi_lora(
dtype
:
torch
.
dtype
,
dtype
:
torch
.
dtype
,
seed
:
int
,
seed
:
int
,
device
:
str
,
device
:
str
,
use_key
:
bool
,
max_position
:
int
=
8192
,
max_position
:
int
=
8192
,
base
:
int
=
10000
,
base
:
int
=
10000
,
)
->
None
:
)
->
None
:
...
@@ -190,7 +221,7 @@ def test_batched_rotary_embedding_multi_lora(
...
@@ -190,7 +221,7 @@ def test_batched_rotary_embedding_multi_lora(
seq_len
,
seq_len
,
num_heads
*
head_size
,
num_heads
*
head_size
,
dtype
=
dtype
)
dtype
=
dtype
)
key
=
torch
.
randn_like
(
query
)
key
=
torch
.
randn_like
(
query
)
if
use_key
else
None
offset_map
=
torch
.
tensor
(
offset_map
=
torch
.
tensor
(
list
(
list
(
...
@@ -214,10 +245,14 @@ def test_batched_rotary_embedding_multi_lora(
...
@@ -214,10 +245,14 @@ def test_batched_rotary_embedding_multi_lora(
ref_query
,
ref_query
,
atol
=
get_default_atol
(
out_query
),
atol
=
get_default_atol
(
out_query
),
rtol
=
get_default_rtol
(
out_query
))
rtol
=
get_default_rtol
(
out_query
))
torch
.
testing
.
assert_close
(
out_key
,
if
use_key
:
ref_key
,
torch
.
testing
.
assert_close
(
out_key
,
atol
=
get_default_atol
(
out_key
),
ref_key
,
rtol
=
get_default_rtol
(
out_key
))
atol
=
get_default_atol
(
out_key
),
rtol
=
get_default_rtol
(
out_key
))
else
:
assert
ref_key
is
None
and
out_key
is
None
,
\
"expected returned key to be None"
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
...
...
tests/kernels/core/test_rotary_embedding.py
View file @
7a985548
...
@@ -15,7 +15,7 @@ from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
...
@@ -15,7 +15,7 @@ from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
def
rotary_embedding_opcheck
(
rot
,
def
rotary_embedding_opcheck
(
rot
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
key
:
Optional
[
torch
.
Tensor
]
=
None
,
offsets
:
Optional
[
torch
.
Tensor
]
=
None
):
offsets
:
Optional
[
torch
.
Tensor
]
=
None
):
cos_sin_cache
=
rot
.
cos_sin_cache
.
to
(
query
.
device
,
dtype
=
query
.
dtype
)
cos_sin_cache
=
rot
.
cos_sin_cache
.
to
(
query
.
device
,
dtype
=
query
.
dtype
)
...
@@ -37,9 +37,11 @@ def rotary_embedding_opcheck(rot,
...
@@ -37,9 +37,11 @@ def rotary_embedding_opcheck(rot,
@
pytest
.
mark
.
parametrize
(
"rotary_dim"
,
[
32
])
@
pytest
.
mark
.
parametrize
(
"rotary_dim"
,
[
32
])
@
pytest
.
mark
.
parametrize
(
"head_size"
,
[
32
,
108
])
@
pytest
.
mark
.
parametrize
(
"head_size"
,
[
32
,
108
])
@
pytest
.
mark
.
parametrize
(
"seq_len"
,
[
11
,
1024
])
@
pytest
.
mark
.
parametrize
(
"seq_len"
,
[
11
,
1024
])
@
pytest
.
mark
.
parametrize
(
"use_key"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"head_stride_is_contingous"
,
[
True
,
False
])
def
test_rotary_embedding_opcheck
(
dist_init
,
device
,
max_position
,
def
test_rotary_embedding_opcheck
(
dist_init
,
device
,
max_position
,
is_neox_style
,
rotary_dim
,
head_size
,
is_neox_style
,
rotary_dim
,
head_size
,
seq_len
):
seq_len
,
use_key
,
head_stride_is_contingous
):
batch_size
=
1
batch_size
=
1
base
=
10000
base
=
10000
num_heads
=
7
num_heads
=
7
...
@@ -49,15 +51,27 @@ def test_rotary_embedding_opcheck(dist_init, device, max_position,
...
@@ -49,15 +51,27 @@ def test_rotary_embedding_opcheck(dist_init, device, max_position,
positions
=
torch
.
randint
(
0
,
positions
=
torch
.
randint
(
0
,
max_position
,
(
batch_size
,
seq_len
),
max_position
,
(
batch_size
,
seq_len
),
device
=
device
)
device
=
device
)
head_stride
=
head_size
+
(
64
if
head_stride_is_contingous
else
0
)
query
=
torch
.
randn
(
batch_size
,
query
=
torch
.
randn
(
batch_size
,
seq_len
,
seq_len
,
num_heads
*
head_size
,
num_heads
,
head_stride
,
dtype
=
torch
.
float32
,
dtype
=
torch
.
float32
,
device
=
device
)
device
=
device
)
key
=
torch
.
randn_like
(
query
)
key
=
torch
.
randn_like
(
query
)
if
use_key
else
None
query
=
query
[...,
:
head_size
]
key
=
key
[...,
:
head_size
]
if
use_key
else
None
rotary_embedding_opcheck
(
rot
,
positions
,
query
,
key
)
rotary_embedding_opcheck
(
rot
,
positions
,
query
,
key
)
offsets
=
torch
.
zeros
(
batch_size
*
seq_len
,
offsets
=
torch
.
zeros
(
batch_size
*
seq_len
,
device
=
device
,
device
=
device
,
dtype
=
torch
.
long
)
dtype
=
torch
.
long
)
rotary_embedding_opcheck
(
rot
,
positions
,
query
,
key
,
offsets
)
rotary_embedding_opcheck
(
rot
,
positions
,
query
,
key
,
offsets
)
# if we have a contiguous head stride, test the alternate
# [..., num_heads * head_dim] shape/layout
if
head_stride_is_contingous
:
rotary_embedding_opcheck
(
rot
,
positions
,
query
.
flatten
(
start_dim
=-
2
),
key
.
flatten
(
start_dim
=-
2
)
if
use_key
else
None
)
tests/kernels/mamba/test_mamba_ssm_ssd.py
View file @
7a985548
...
@@ -6,7 +6,7 @@ import torch.nn.functional as F
...
@@ -6,7 +6,7 @@ import torch.nn.functional as F
from
einops
import
rearrange
,
repeat
from
einops
import
rearrange
,
repeat
from
vllm.model_executor.layers.mamba.mamba2_metadata
import
(
from
vllm.model_executor.layers.mamba.mamba2_metadata
import
(
_
seq_idx
_to_chunk_indices_offsets
)
_
query_start_loc
_to_chunk_indices_offsets
)
from
vllm.model_executor.layers.mamba.ops.ssd_combined
import
(
from
vllm.model_executor.layers.mamba.ops.ssd_combined
import
(
mamba_chunk_scan_combined
)
mamba_chunk_scan_combined
)
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
...
@@ -274,8 +274,9 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases,
...
@@ -274,8 +274,9 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases,
last_taken
,
exhausted
,
n_heads
,
last_taken
,
exhausted
,
n_heads
,
d_head
,
itype
):
d_head
,
itype
):
chunk_indices
,
chunk_offsets
=
_seq_idx_to_chunk_indices_offsets
(
chunk_indices
,
chunk_offsets
=
\
seq_idx
,
chunk_size
)
_query_start_loc_to_chunk_indices_offsets
(
cu_seqlens
,
chunk_size
,
cu_seqlens
[
-
1
])
Y
,
new_states
=
mamba_chunk_scan_combined
(
Y
,
new_states
=
mamba_chunk_scan_combined
(
X
,
X
,
...
...
tests/kernels/moe/test_batched_moe.py
0 → 100644
View file @
7a985548
# SPDX-License-Identifier: Apache-2.0
from
dataclasses
import
dataclass
import
pytest
import
torch
import
triton.language
as
tl
from
vllm.model_executor.layers.fused_moe.fused_batched_moe
import
(
invoke_moe_batched_triton_kernel
)
@dataclass
class BatchedMMConfig:
    """Element type and problem sizes of one batched-matmul test case."""
    # Element type used for the A, B and C tensors.
    dtype: torch.dtype
    # Number of experts, the leading dimension of every tensor.
    num_experts: int
    # Rows allocated per expert in A and C.
    max_tokens_per_expert: int
    # Inner (reduction) dimension of the matmul.
    K: int
    # Output feature dimension of the matmul.
    N: int
@dataclass
class BatchedMMTensors:
    """Input, output and per-expert token-count tensors for one test case."""
    A: torch.Tensor  # [E, max_tokens, K]
    B: torch.Tensor  # allocated as [E, N, K], i.e. column-major [E, K, N]
    C: torch.Tensor  # [E, max_tokens, N]
    num_expert_tokens: torch.Tensor  # [E]

    @staticmethod
    def make_tensors(config: BatchedMMConfig):
        """Allocate the tensors for ``config`` on the CUDA device.

        A is random and divided by 10, B is random, C is zero-filled, and
        each expert gets a random token count in
        ``[0, max_tokens_per_expert)``.
        """
        num_experts = config.num_experts
        max_tokens = config.max_tokens_per_expert
        tensor_kwargs = {"device": "cuda", "dtype": config.dtype}

        # Keep the A, B, token-count draw order so seeded runs stay
        # reproducible.
        lhs = torch.randn((num_experts, max_tokens, config.K),
                          **tensor_kwargs) / 10
        rhs = torch.randn((num_experts, config.N, config.K), **tensor_kwargs)
        out = torch.zeros((num_experts, max_tokens, config.N),
                          **tensor_kwargs)
        token_counts = torch.randint(low=0,
                                     high=max_tokens,
                                     size=(num_experts, ),
                                     device="cuda",
                                     dtype=torch.int32)
        return BatchedMMTensors(lhs, rhs, out, token_counts)
def ref_impl(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
             num_expert_tokens: torch.Tensor) -> torch.Tensor:
    """Reference batched matmul, computed one expert at a time.

    For expert ``e`` only the first ``num_expert_tokens[e]`` rows are
    computed, as ``A[e] @ B[e]^T``; the remaining rows of ``C`` are left
    untouched. ``C`` is written in place and also returned.
    """
    # Pull the per-expert counts to the host once, as plain Python ints.
    token_counts = num_expert_tokens.to(device="cpu").tolist()

    for expert, count in enumerate(token_counts):
        weight_t = B[expert].transpose(0, 1)
        C[expert, :count, :] = A[expert, :count, :] @ weight_t

    return C
@pytest.mark.parametrize("num_experts", [16, 32])
@pytest.mark.parametrize("max_tokens_per_expert",
                         [32, 64, 128, 192, 224, 256, 512])
@pytest.mark.parametrize("K", [128, 256, 1024])
@pytest.mark.parametrize("N", [128, 256, 512, 1024])
@pytest.mark.parametrize("dtype",
                         [torch.float32, torch.float16, torch.bfloat16])
def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
                    N: int, dtype: torch.dtype):
    """Check the batched Triton MoE matmul kernel against ``ref_impl``."""
    tensors = BatchedMMTensors.make_tensors(
        BatchedMMConfig(dtype, num_experts, max_tokens_per_expert, K, N))

    # The kernel writes into tensors.C; the reference gets its own copy.
    kernel_out = tensors.C
    expected = kernel_out.clone()

    tl_dtype_for = {
        torch.float16: tl.float16,
        torch.bfloat16: tl.bfloat16,
        torch.float32: tl.float32
    }
    block_config = {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 16,
        "BLOCK_SIZE_K": 16
    }
    invoke_moe_batched_triton_kernel(
        tensors.A,
        tensors.B,
        kernel_out,
        tensors.num_expert_tokens,
        tl_dtype_for[kernel_out.dtype],
        # Quantization data
        None,
        None,
        None,
        # Quantization schemes
        False,
        False,
        False,
        config=block_config)

    expected = ref_impl(tensors.A, tensors.B, expected,
                        tensors.num_expert_tokens)

    # (rtol, atol) per output dtype.
    tolerances = {
        torch.float16: (6e-2, 6e-2),
        torch.bfloat16: (6e-2, 6e-2),
        torch.float32: (1e-2, 1e-2),
    }
    rtol, atol = tolerances[kernel_out.dtype]

    torch.testing.assert_close(kernel_out, expected, atol=atol, rtol=rtol)
tests/kernels/moe/test_cutlass_moe.py
View file @
7a985548
...
@@ -30,6 +30,11 @@ MNK_FACTORS = [
...
@@ -30,6 +30,11 @@ MNK_FACTORS = [
(
224
,
3072
,
1536
),
(
224
,
3072
,
1536
),
]
]
vllm_config
=
VllmConfig
(
parallel_config
=
ParallelConfig
(
pipeline_parallel_size
=
1
))
vllm_config
.
scheduler_config
.
max_num_seqs
=
128
vllm_config
.
scheduler_config
.
max_model_len
=
8192
@
dataclasses
.
dataclass
@
dataclasses
.
dataclass
class
MOETensors
:
class
MOETensors
:
...
@@ -190,7 +195,7 @@ def run_8_bit(moe_tensors: MOETensors8Bit,
...
@@ -190,7 +195,7 @@ def run_8_bit(moe_tensors: MOETensors8Bit,
'w1_q'
:
moe_tensors
.
w1_q
.
transpose
(
1
,
2
),
# type: ignore[union-attr]
'w1_q'
:
moe_tensors
.
w1_q
.
transpose
(
1
,
2
),
# type: ignore[union-attr]
'w2_q'
:
moe_tensors
.
w2_q
.
transpose
(
1
,
2
),
# type: ignore[union-attr]
'w2_q'
:
moe_tensors
.
w2_q
.
transpose
(
1
,
2
),
# type: ignore[union-attr]
'topk_weights'
:
topk_weights
,
'topk_weights'
:
topk_weights
,
'topk_ids
_
'
:
topk_ids
,
'topk_ids'
:
topk_ids
,
'ab_strides1'
:
moe_tensors
.
ab_strides1
,
'ab_strides1'
:
moe_tensors
.
ab_strides1
,
'c_strides1'
:
moe_tensors
.
c_strides1
,
'c_strides1'
:
moe_tensors
.
c_strides1
,
'ab_strides2'
:
moe_tensors
.
ab_strides2
,
'ab_strides2'
:
moe_tensors
.
ab_strides2
,
...
@@ -231,18 +236,15 @@ def test_cutlass_moe_8_bit_no_graph(
...
@@ -231,18 +236,15 @@ def test_cutlass_moe_8_bit_no_graph(
per_out_ch
:
bool
,
per_out_ch
:
bool
,
):
):
current_platform
.
seed_everything
(
7
)
current_platform
.
seed_everything
(
7
)
with
set_current_vllm_config
(
with
set_current_vllm_config
(
vllm_config
):
VllmConfig
(
parallel_config
=
ParallelConfig
(
pipeline_parallel_size
=
1
))):
mt
=
MOETensors8Bit
.
make_moe_tensors_8bit
(
m
,
k
,
n
,
e
,
per_act_token
,
mt
=
MOETensors8Bit
.
make_moe_tensors_8bit
(
m
,
k
,
n
,
e
,
per_act_token
,
per_out_ch
)
per_out_ch
)
score
=
torch
.
randn
((
m
,
e
),
device
=
"cuda"
,
dtype
=
torch
.
half
)
score
=
torch
.
randn
((
m
,
e
),
device
=
"cuda"
,
dtype
=
torch
.
half
)
topk_weights
,
topk_ids
=
fused_topk
(
mt
.
a
,
topk_weights
,
topk_ids
,
_
=
fused_topk
(
mt
.
a
,
score
,
score
,
topk
,
topk
,
renormalize
=
False
)
renormalize
=
False
)
# Note that we are using the dequantized versions of the tensors.
# Note that we are using the dequantized versions of the tensors.
# Using a, w1 and w2 directly results in minor output differences.
# Using a, w1 and w2 directly results in minor output differences.
...
@@ -276,20 +278,17 @@ def test_cutlass_moe_8_bit_cuda_graph(
...
@@ -276,20 +278,17 @@ def test_cutlass_moe_8_bit_cuda_graph(
per_out_ch
:
bool
,
per_out_ch
:
bool
,
):
):
current_platform
.
seed_everything
(
7
)
current_platform
.
seed_everything
(
7
)
with
set_current_vllm_config
(
with
set_current_vllm_config
(
vllm_config
):
VllmConfig
(
parallel_config
=
ParallelConfig
(
pipeline_parallel_size
=
1
))):
dtype
=
torch
.
half
dtype
=
torch
.
half
mt
=
MOETensors8Bit
.
make_moe_tensors_8bit
(
m
,
k
,
n
,
e
,
per_act_token
,
mt
=
MOETensors8Bit
.
make_moe_tensors_8bit
(
m
,
k
,
n
,
e
,
per_act_token
,
per_out_ch
)
per_out_ch
)
score
=
torch
.
randn
((
m
,
e
),
device
=
"cuda"
,
dtype
=
dtype
)
score
=
torch
.
randn
((
m
,
e
),
device
=
"cuda"
,
dtype
=
dtype
)
topk_weights
,
topk_ids
=
fused_topk
(
mt
.
a
,
topk_weights
,
topk_ids
,
_
=
fused_topk
(
mt
.
a
,
score
,
score
,
topk
,
topk
,
renormalize
=
False
)
renormalize
=
False
)
# Note that we are using the dequantized versions of the tensors.
# Note that we are using the dequantized versions of the tensors.
# Using a, w1 and w2 directly results in minor output differences.
# Using a, w1 and w2 directly results in minor output differences.
...
@@ -334,18 +333,15 @@ def test_cutlass_moe_8_bit_EP(
...
@@ -334,18 +333,15 @@ def test_cutlass_moe_8_bit_EP(
ep_size
:
int
,
ep_size
:
int
,
):
):
current_platform
.
seed_everything
(
7
)
current_platform
.
seed_everything
(
7
)
with
set_current_vllm_config
(
with
set_current_vllm_config
(
vllm_config
):
VllmConfig
(
parallel_config
=
ParallelConfig
(
pipeline_parallel_size
=
1
))):
mt
=
MOETensors8Bit
.
make_moe_tensors_8bit
(
m
,
k
,
n
,
e
,
per_act_token
,
mt
=
MOETensors8Bit
.
make_moe_tensors_8bit
(
m
,
k
,
n
,
e
,
per_act_token
,
per_out_channel
)
per_out_channel
)
score
=
torch
.
randn
((
m
,
e
),
device
=
"cuda"
,
dtype
=
torch
.
half
)
score
=
torch
.
randn
((
m
,
e
),
device
=
"cuda"
,
dtype
=
torch
.
half
)
topk_weights
,
topk_ids
=
fused_topk
(
mt
.
a
,
topk_weights
,
topk_ids
,
_
=
fused_topk
(
mt
.
a
,
score
,
score
,
topk
,
topk
,
renormalize
=
False
)
renormalize
=
False
)
# Note that we are using the dequantized versions of the tensors.
# Note that we are using the dequantized versions of the tensors.
# Using a, w1 and w2 directly results in minor output differences.
# Using a, w1 and w2 directly results in minor output differences.
...
...
tests/kernels/moe/test_moe.py
View file @
7a985548
...
@@ -11,24 +11,32 @@ from transformers import MixtralConfig
...
@@ -11,24 +11,32 @@ from transformers import MixtralConfig
from
transformers.models.mixtral.modeling_mixtral
import
MixtralSparseMoeBlock
from
transformers.models.mixtral.modeling_mixtral
import
MixtralSparseMoeBlock
import
vllm.model_executor.layers.fused_moe
# noqa
import
vllm.model_executor.layers.fused_moe
# noqa
from
tests.kernels.utils
import
(
opcheck
,
stack_and_dev
,
torch_moe
,
from
tests.kernels.utils
import
opcheck
,
stack_and_dev
,
torch_moe
torch_moe_single
)
from
vllm.config
import
VllmConfig
,
set_current_vllm_config
from
vllm.model_executor.layers.fused_moe
import
fused_moe
from
vllm.model_executor.layers.fused_moe
import
fused_moe
from
vllm.model_executor.layers.fused_moe.fused_moe
import
fused_topk
from
vllm.model_executor.layers.fused_moe.fused_moe
import
fused_topk
from
vllm.model_executor.layers.fused_moe.moe_torch_iterative
import
(
from
vllm.model_executor.layers.fused_moe.moe_torch_iterative
import
(
fused_moe
as
iterative_moe
)
fused_moe
as
iterative_moe
)
from
vllm.model_executor.layers.quantization.utils.marlin_utils_fp4
import
(
rand_marlin_weight_fp4_like
)
from
vllm.model_executor.layers.quantization.utils.marlin_utils_fp8
import
(
marlin_quant_fp8_torch
)
from
vllm.model_executor.layers.quantization.utils.marlin_utils_test
import
(
from
vllm.model_executor.layers.quantization.utils.marlin_utils_test
import
(
awq_marlin_quantize
,
marlin_quantize
)
awq_marlin_quantize
,
marlin_quantize
)
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
quantize_weights
)
quantize_weights
)
from
vllm.model_executor.models.mixtral
import
MixtralMoE
from
vllm.model_executor.models.mixtral
import
MixtralMoE
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.scalar_type
import
scalar_types
from
vllm.scalar_type
import
ScalarType
,
scalar_types
NUM_EXPERTS
=
[
8
,
64
]
NUM_EXPERTS
=
[
8
,
64
]
EP_SIZE
=
[
1
,
4
]
EP_SIZE
=
[
1
,
4
]
TOP_KS
=
[
2
,
6
]
TOP_KS
=
[
2
,
6
]
vllm_config
=
VllmConfig
()
vllm_config
.
scheduler_config
.
max_num_seqs
=
128
vllm_config
.
scheduler_config
.
max_model_len
=
8192
@
pytest
.
mark
.
parametrize
(
"m"
,
[
1
,
33
,
64
,
222
,
1024
*
128
])
@
pytest
.
mark
.
parametrize
(
"m"
,
[
1
,
33
,
64
,
222
,
1024
*
128
])
@
pytest
.
mark
.
parametrize
(
"n"
,
[
128
,
1024
,
2048
])
@
pytest
.
mark
.
parametrize
(
"n"
,
[
128
,
1024
,
2048
])
...
@@ -67,31 +75,33 @@ def test_fused_moe(
...
@@ -67,31 +75,33 @@ def test_fused_moe(
else
:
else
:
e_map
=
None
e_map
=
None
torch_output
=
torch_moe
(
a
,
w1
,
w2
,
score
,
topk
,
e_map
)
with
set_current_vllm_config
(
vllm_config
):
iterative_output
=
iterative_moe
(
a
,
torch_output
=
torch_moe
(
a
,
w1
,
w2
,
score
,
topk
,
e_map
)
w1
,
iterative_output
=
iterative_moe
(
a
,
w2
,
w1
,
score
,
w2
,
topk
,
score
,
global_num_experts
=
e
,
topk
,
expert_map
=
e_map
,
global_num_experts
=
e
,
renormalize
=
False
)
expert_map
=
e_map
,
renormalize
=
False
)
# Pad the weight if moe padding is enabled
if
padding
:
# Pad the weight if moe padding is enabled
w1
=
F
.
pad
(
w1
,
(
0
,
128
),
"constant"
,
0
)[...,
0
:
-
128
]
if
padding
:
torch
.
cuda
.
empty_cache
()
w1
=
F
.
pad
(
w1
,
(
0
,
128
),
"constant"
,
0
)[...,
0
:
-
128
]
w2
=
F
.
pad
(
w2
,
(
0
,
128
),
"constant"
,
0
)[...,
0
:
-
128
]
torch
.
cuda
.
empty_cache
()
torch
.
cuda
.
empty_cache
()
w2
=
F
.
pad
(
w2
,
(
0
,
128
),
"constant"
,
0
)[...,
0
:
-
128
]
torch
.
cuda
.
empty_cache
()
triton_output
=
fused_moe
(
a
,
w1
,
w2
,
score
,
topk
,
global_num_experts
=
e
,
expert_map
=
e_map
,
renormalize
=
False
)
triton_output
=
fused_moe
(
a
,
w1
,
w2
,
score
,
topk
,
global_num_experts
=
e
,
expert_map
=
e_map
,
renormalize
=
False
)
torch
.
testing
.
assert_close
(
triton_output
,
torch_output
,
atol
=
2e-2
,
rtol
=
0
)
torch
.
testing
.
assert_close
(
triton_output
,
torch_output
,
atol
=
2e-2
,
rtol
=
0
)
torch
.
testing
.
assert_close
(
iterative_output
,
torch
.
testing
.
assert_close
(
iterative_output
,
torch_output
,
torch_output
,
...
@@ -112,7 +122,6 @@ def test_fused_moe(
...
@@ -112,7 +122,6 @@ def test_fused_moe(
def
test_fused_moe_wn16
(
m
:
int
,
n
:
int
,
k
:
int
,
e
:
int
,
topk
:
int
,
def
test_fused_moe_wn16
(
m
:
int
,
n
:
int
,
k
:
int
,
e
:
int
,
topk
:
int
,
ep_size
:
int
,
dtype
:
torch
.
dtype
,
group_size
:
int
,
ep_size
:
int
,
dtype
:
torch
.
dtype
,
group_size
:
int
,
has_zp
:
bool
,
weight_bits
:
int
):
has_zp
:
bool
,
weight_bits
:
int
):
print
(
m
,
n
,
k
,
e
,
topk
,
dtype
,
group_size
,
has_zp
,
weight_bits
)
a
=
torch
.
randn
((
m
,
k
),
device
=
"cuda"
,
dtype
=
dtype
)
/
10
a
=
torch
.
randn
((
m
,
k
),
device
=
"cuda"
,
dtype
=
dtype
)
/
10
w1
=
torch
.
randn
((
e
,
2
*
n
,
k
),
device
=
"cuda"
,
dtype
=
dtype
)
/
10
w1
=
torch
.
randn
((
e
,
2
*
n
,
k
),
device
=
"cuda"
,
dtype
=
dtype
)
/
10
w2
=
torch
.
randn
((
e
,
k
,
n
),
device
=
"cuda"
,
dtype
=
dtype
)
/
10
w2
=
torch
.
randn
((
e
,
k
,
n
),
device
=
"cuda"
,
dtype
=
dtype
)
/
10
...
@@ -191,22 +200,24 @@ def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int,
...
@@ -191,22 +200,24 @@ def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int,
else
:
else
:
e_map
=
None
e_map
=
None
triton_output
=
fused_moe
(
a
,
with
set_current_vllm_config
(
vllm_config
):
w1_qweight
,
triton_output
=
fused_moe
(
a
,
w2_qweight
,
w1_qweight
,
score
,
w2_qweight
,
topk
,
score
,
renormalize
=
False
,
topk
,
use_int4_w4a16
=
weight_bits
==
4
,
renormalize
=
False
,
use_int8_w8a16
=
weight_bits
==
8
,
use_int4_w4a16
=
weight_bits
==
4
,
global_num_experts
=
e
,
use_int8_w8a16
=
weight_bits
==
8
,
expert_map
=
e_map
,
global_num_experts
=
e
,
w1_scale
=
w1_scales
,
expert_map
=
e_map
,
w2_scale
=
w2_scales
,
w1_scale
=
w1_scales
,
w1_zp
=
w1_qzeros
if
has_zp
else
None
,
w2_scale
=
w2_scales
,
w2_zp
=
w2_qzeros
if
has_zp
else
None
,
w1_zp
=
w1_qzeros
if
has_zp
else
None
,
block_shape
=
[
0
,
group_size
])
w2_zp
=
w2_qzeros
if
has_zp
else
None
,
torch_output
=
torch_moe
(
a
,
w1_ref
,
w2_ref
,
score
,
topk
,
e_map
)
block_shape
=
[
0
,
group_size
])
torch_output
=
torch_moe
(
a
,
w1_ref
,
w2_ref
,
score
,
topk
,
e_map
)
torch
.
testing
.
assert_close
(
triton_output
,
torch_output
,
atol
=
2e-2
,
rtol
=
0
)
torch
.
testing
.
assert_close
(
triton_output
,
torch_output
,
atol
=
2e-2
,
rtol
=
0
)
...
@@ -221,9 +232,16 @@ def test_mixtral_moe(dtype: torch.dtype, padding: bool, use_rocm_aiter: bool,
...
@@ -221,9 +232,16 @@ def test_mixtral_moe(dtype: torch.dtype, padding: bool, use_rocm_aiter: bool,
"""Make sure our Mixtral MoE implementation agrees with the one from
"""Make sure our Mixtral MoE implementation agrees with the one from
huggingface."""
huggingface."""
# clear the cache before every test
from
vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe
import
(
is_rocm_aiter_moe_enabled
)
is_rocm_aiter_moe_enabled
.
cache_clear
()
if
use_rocm_aiter
:
if
use_rocm_aiter
:
monkeypatch
.
setenv
(
"VLLM_ROCM_USE_AITER"
,
"1"
)
monkeypatch
.
setenv
(
"VLLM_ROCM_USE_AITER"
,
"1"
)
if
dtype
==
torch
.
float32
:
pytest
.
skip
(
"AITER ROCm test skip for float32"
)
# Instantiate our and huggingface's MoE blocks
# Instantiate our and huggingface's MoE blocks
config
=
MixtralConfig
()
config
=
MixtralConfig
()
hf_moe
=
MixtralSparseMoeBlock
(
config
).
to
(
dtype
).
to
(
"cuda"
)
hf_moe
=
MixtralSparseMoeBlock
(
config
).
to
(
dtype
).
to
(
"cuda"
)
...
@@ -285,18 +303,64 @@ def test_mixtral_moe(dtype: torch.dtype, padding: bool, use_rocm_aiter: bool,
...
@@ -285,18 +303,64 @@ def test_mixtral_moe(dtype: torch.dtype, padding: bool, use_rocm_aiter: bool,
atol
=
mixtral_moe_tol
[
dtype
])
atol
=
mixtral_moe_tol
[
dtype
])
@
pytest
.
mark
.
parametrize
(
"m"
,
[
1
,
33
,
123
])
def
marlin_moe_generate_valid_test_cases
():
@
pytest
.
mark
.
parametrize
(
"n"
,
[
128
,
1024
])
import
itertools
@
pytest
.
mark
.
parametrize
(
"k"
,
[
256
,
2048
])
m_list
=
[
1
,
123
,
666
]
@
pytest
.
mark
.
parametrize
(
"e"
,
[
4
,
12
])
n_list
=
[
128
,
1024
]
@
pytest
.
mark
.
parametrize
(
"topk"
,
[
2
,
3
])
k_list
=
[
256
,
2048
]
@
pytest
.
mark
.
parametrize
(
"ep_size"
,
[
1
,
4
])
e_list
=
[
4
,
12
]
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
float16
,
torch
.
bfloat16
])
topk_list
=
[
2
,
3
]
@
pytest
.
mark
.
parametrize
(
"group_size"
,
[
-
1
,
32
,
128
])
ep_size_list
=
[
1
,
4
]
@
pytest
.
mark
.
parametrize
(
"act_order"
,
[
True
,
False
])
dtype_list
=
[
torch
.
half
,
torch
.
bfloat16
]
@
pytest
.
mark
.
parametrize
(
"num_bits"
,
[
4
,
8
])
group_size_list
=
[
-
1
,
16
,
32
,
128
]
@
pytest
.
mark
.
parametrize
(
"has_zp"
,
[
True
,
False
])
act_order_list
=
[
True
,
False
]
@
pytest
.
mark
.
parametrize
(
"is_k_full"
,
[
True
,
False
])
quant_type_list
=
[
scalar_types
.
float4_e2m1f
,
scalar_types
.
float8_e4m3fn
,
scalar_types
.
uint4
,
scalar_types
.
uint4b8
,
scalar_types
.
uint8b128
,
]
is_k_full_list
=
[
True
,
False
]
all_combinations
=
itertools
.
product
(
m_list
,
n_list
,
k_list
,
e_list
,
topk_list
,
ep_size_list
,
dtype_list
,
group_size_list
,
act_order_list
,
quant_type_list
,
is_k_full_list
)
def is_invalid(m, n, k, e, topk, ep_size, dtype, group_size, act_order,
               quant_type, is_k_full):
    """Return True when this parameter combination should be kept.

    NOTE: the name is inverted relative to the return value. The caller
    appends a case only when this function returns True, and every
    rejection path below returns False.
    """
    fp4 = scalar_types.float4_e2m1f
    fp8 = scalar_types.float8_e4m3fn

    # fp8 is only exercised with group_size -1 or 128.
    if quant_type == fp8 and group_size not in [-1, 128]:
        return False

    # fp4 goes with group_size 16, and group_size 16 goes only with fp4.
    if (quant_type == fp4) != (group_size == 16):
        return False

    # Filter act_order
    if act_order:
        if group_size in (-1, k, n):
            return False
        if quant_type not in [scalar_types.uint4b8]:
            return False
        return True

    # Without act_order, only the is_k_full=True variant is kept.
    return bool(is_k_full)
cases
=
[]
for
case
in
all_combinations
:
if
is_invalid
(
*
case
):
cases
.
append
(
case
)
return
cases
@
pytest
.
mark
.
flaky
(
reruns
=
2
)
@
pytest
.
mark
.
parametrize
((
"m, n, k, e, topk, ep_size, dtype, group_size,"
"act_order, quant_type, is_k_full"
),
marlin_moe_generate_valid_test_cases
())
@
pytest
.
mark
.
skipif
(
current_platform
.
is_rocm
(),
reason
=
"Skip for rocm"
)
@
pytest
.
mark
.
skipif
(
current_platform
.
is_rocm
(),
reason
=
"Skip for rocm"
)
def
test_fused_marlin_moe
(
def
test_fused_marlin_moe
(
m
:
int
,
m
:
int
,
...
@@ -308,14 +372,22 @@ def test_fused_marlin_moe(
...
@@ -308,14 +372,22 @@ def test_fused_marlin_moe(
dtype
:
torch
.
dtype
,
dtype
:
torch
.
dtype
,
group_size
:
int
,
group_size
:
int
,
act_order
:
bool
,
act_order
:
bool
,
num_bits
:
int
,
quant_type
:
ScalarType
,
has_zp
:
bool
,
is_k_full
:
bool
,
is_k_full
:
bool
,
):
):
current_platform
.
seed_everything
(
7
)
torch
.
cuda
.
manual_seed
(
0
)
has_zp
=
quant_type
in
[
scalar_types
.
uint4
,
scalar_types
.
uint8
]
if
quant_type
==
scalar_types
.
float8_e4m3fn
:
if
group_size
not
in
[
-
1
,
128
]:
return
if
act_order
:
return
# Filter act_order
# Filter act_order
if
act_order
:
if
act_order
:
if
quant_type
==
scalar_types
.
float8_e4m3fn
:
return
if
group_size
==
-
1
:
if
group_size
==
-
1
:
return
return
if
group_size
in
(
k
,
n
):
if
group_size
in
(
k
,
n
):
...
@@ -326,17 +398,14 @@ def test_fused_marlin_moe(
...
@@ -326,17 +398,14 @@ def test_fused_marlin_moe(
if
not
is_k_full
:
if
not
is_k_full
:
return
return
if
has_zp
:
if
quant_type
==
scalar_types
.
float4_e2m1f
and
group_size
!=
16
:
# we don't build kernel for int8 with zero
return
if
num_bits
==
8
:
if
quant_type
!=
scalar_types
.
float4_e2m1f
and
group_size
==
16
:
return
return
quant_type
=
scalar_types
.
uint4
if
num_bits
==
4
else
scalar_types
.
uint8
else
:
quant_type
=
scalar_types
.
uint4b8
\
if
num_bits
==
4
else
scalar_types
.
uint8b128
a
=
torch
.
randn
((
m
,
k
),
device
=
"cuda"
,
dtype
=
dtype
)
/
10
a
=
torch
.
randn
((
m
,
k
),
device
=
"cuda"
,
dtype
=
dtype
)
/
10
w1
=
torch
.
randn
((
e
,
2
*
n
,
k
),
device
=
"cuda"
,
dtype
=
dtype
)
/
1
0
w1
=
torch
.
randn
((
e
,
2
*
n
,
k
),
device
=
"cuda"
,
dtype
=
dtype
)
/
2
0
w2
=
torch
.
randn
((
e
,
k
,
n
),
device
=
"cuda"
,
dtype
=
dtype
)
/
1
0
w2
=
torch
.
randn
((
e
,
k
,
n
),
device
=
"cuda"
,
dtype
=
dtype
)
/
2
0
if
ep_size
>
1
:
if
ep_size
>
1
:
local_e
=
e
//
ep_size
local_e
=
e
//
ep_size
...
@@ -351,12 +420,27 @@ def test_fused_marlin_moe(
...
@@ -351,12 +420,27 @@ def test_fused_marlin_moe(
w_ref1_l
=
[]
w_ref1_l
=
[]
qweight1_l
=
[]
qweight1_l
=
[]
scales1_l
=
[]
scales1_l
=
[]
global_scale1_l
=
[]
zeros1_l
=
[]
zeros1_l
=
[]
g_idx1_l
=
[]
g_idx1_l
=
[]
sort_indices1_l
=
[]
sort_indices1_l
=
[]
for
i
in
range
(
w1
.
shape
[
0
]):
for
i
in
range
(
w1
.
shape
[
0
]):
if
has_zp
:
if
quant_type
==
scalar_types
.
float4_e2m1f
:
w_ref1
,
qweight1
,
scales1
,
global_scale1
=
\
rand_marlin_weight_fp4_like
(
w1
[
i
],
group_size
)
w_ref1_l
.
append
(
w_ref1
.
T
)
qweight1_l
.
append
(
qweight1
)
scales1_l
.
append
(
scales1
)
global_scale1_l
.
append
(
global_scale1
)
elif
quant_type
==
scalar_types
.
float8_e4m3fn
:
w_ref1
,
qweight1
,
scales1
=
marlin_quant_fp8_torch
(
w1
[
i
],
group_size
)
w_ref1_l
.
append
(
w_ref1
.
T
)
qweight1_l
.
append
(
qweight1
)
scales1_l
.
append
(
scales1
)
elif
has_zp
:
w_ref1
,
qweight1
,
scales1
,
zeros1
=
awq_marlin_quantize
(
w_ref1
,
qweight1
,
scales1
,
zeros1
=
awq_marlin_quantize
(
w1
[
i
].
transpose
(
1
,
0
),
quant_type
,
group_size
)
w1
[
i
].
transpose
(
1
,
0
),
quant_type
,
group_size
)
...
@@ -366,9 +450,9 @@ def test_fused_marlin_moe(
...
@@ -366,9 +450,9 @@ def test_fused_marlin_moe(
zeros1_l
.
append
(
zeros1
)
zeros1_l
.
append
(
zeros1
)
else
:
else
:
test_perm
=
torch
.
randperm
(
k
)
test_perm
=
torch
.
randperm
(
k
)
quant_res
=
marlin_quantize
(
w1
[
i
].
transpose
(
1
,
0
),
quant_type
,
w_ref1
,
qweight1
,
scales1
,
g_idx1
,
sort_indices1
,
_
=
\
group_size
,
act_order
,
test_perm
)
marlin_quantize
(
w1
[
i
].
transpose
(
1
,
0
),
quant_type
,
w_ref1
,
qweight1
,
scales1
,
g_idx1
,
sort_indices1
,
_
=
quant_res
group_size
,
act_order
,
test_perm
)
w_ref1_l
.
append
(
w_ref1
.
T
)
w_ref1_l
.
append
(
w_ref1
.
T
)
qweight1_l
.
append
(
qweight1
)
qweight1_l
.
append
(
qweight1
)
...
@@ -379,6 +463,7 @@ def test_fused_marlin_moe(
...
@@ -379,6 +463,7 @@ def test_fused_marlin_moe(
w_ref1
=
stack_and_dev
(
w_ref1_l
)
w_ref1
=
stack_and_dev
(
w_ref1_l
)
qweight1
=
stack_and_dev
(
qweight1_l
).
contiguous
()
qweight1
=
stack_and_dev
(
qweight1_l
).
contiguous
()
scales1
=
stack_and_dev
(
scales1_l
)
scales1
=
stack_and_dev
(
scales1_l
)
global_scale1
=
stack_and_dev
(
global_scale1_l
)
if
global_scale1_l
else
None
g_idx1
=
stack_and_dev
(
g_idx1_l
)
if
g_idx1_l
else
None
g_idx1
=
stack_and_dev
(
g_idx1_l
)
if
g_idx1_l
else
None
zeros1
=
stack_and_dev
(
zeros1_l
)
if
zeros1_l
else
None
zeros1
=
stack_and_dev
(
zeros1_l
)
if
zeros1_l
else
None
sort_indices1
=
stack_and_dev
(
sort_indices1_l
)
if
sort_indices1_l
else
None
sort_indices1
=
stack_and_dev
(
sort_indices1_l
)
if
sort_indices1_l
else
None
...
@@ -386,12 +471,27 @@ def test_fused_marlin_moe(
...
@@ -386,12 +471,27 @@ def test_fused_marlin_moe(
w_ref2_l
=
[]
w_ref2_l
=
[]
qweight2_l
=
[]
qweight2_l
=
[]
scales2_l
=
[]
scales2_l
=
[]
global_scale2_l
=
[]
zeros2_l
=
[]
zeros2_l
=
[]
g_idx2_l
=
[]
g_idx2_l
=
[]
sort_indices2_l
=
[]
sort_indices2_l
=
[]
for
i
in
range
(
w2
.
shape
[
0
]):
for
i
in
range
(
w2
.
shape
[
0
]):
if
has_zp
:
if
quant_type
==
scalar_types
.
float4_e2m1f
:
w_ref2
,
qweight2
,
scales2
,
global_scale2
=
\
rand_marlin_weight_fp4_like
(
w2
[
i
],
group_size
)
w_ref2_l
.
append
(
w_ref2
.
T
)
qweight2_l
.
append
(
qweight2
)
scales2_l
.
append
(
scales2
)
global_scale2_l
.
append
(
global_scale2
)
elif
quant_type
==
scalar_types
.
float8_e4m3fn
:
w_ref2
,
qweight2
,
scales2
=
marlin_quant_fp8_torch
(
w2
[
i
],
group_size
)
w_ref2_l
.
append
(
w_ref2
.
T
)
qweight2_l
.
append
(
qweight2
)
scales2_l
.
append
(
scales2
)
elif
has_zp
:
w_ref2
,
qweight2
,
scales2
,
zeros2
=
awq_marlin_quantize
(
w_ref2
,
qweight2
,
scales2
,
zeros2
=
awq_marlin_quantize
(
w2
[
i
].
transpose
(
1
,
0
),
quant_type
,
group_size
)
w2
[
i
].
transpose
(
1
,
0
),
quant_type
,
group_size
)
...
@@ -401,9 +501,9 @@ def test_fused_marlin_moe(
...
@@ -401,9 +501,9 @@ def test_fused_marlin_moe(
zeros2_l
.
append
(
zeros2
)
zeros2_l
.
append
(
zeros2
)
else
:
else
:
test_perm
=
torch
.
randperm
(
n
)
test_perm
=
torch
.
randperm
(
n
)
quant_res
=
marlin_quantize
(
w2
[
i
].
transpose
(
1
,
0
),
quant_type
,
w_ref2
,
qweight2
,
scales2
,
g_idx2
,
sort_indices2
,
_
=
\
group_size
,
act_order
,
test_perm
)
marlin_quantize
(
w2
[
i
].
transpose
(
1
,
0
),
quant_type
,
w_ref2
,
qweight2
,
scales2
,
g_idx2
,
sort_indices2
,
_
=
quant_res
group_size
,
act_order
,
test_perm
)
w_ref2_l
.
append
(
w_ref2
.
T
)
w_ref2_l
.
append
(
w_ref2
.
T
)
qweight2_l
.
append
(
qweight2
)
qweight2_l
.
append
(
qweight2
)
...
@@ -414,15 +514,17 @@ def test_fused_marlin_moe(
...
@@ -414,15 +514,17 @@ def test_fused_marlin_moe(
w_ref2
=
stack_and_dev
(
w_ref2_l
)
w_ref2
=
stack_and_dev
(
w_ref2_l
)
qweight2
=
stack_and_dev
(
qweight2_l
).
contiguous
()
qweight2
=
stack_and_dev
(
qweight2_l
).
contiguous
()
scales2
=
stack_and_dev
(
scales2_l
)
scales2
=
stack_and_dev
(
scales2_l
)
global_scale2
=
stack_and_dev
(
global_scale2_l
)
if
global_scale2_l
else
None
g_idx2
=
stack_and_dev
(
g_idx2_l
)
if
g_idx2_l
else
None
g_idx2
=
stack_and_dev
(
g_idx2_l
)
if
g_idx2_l
else
None
zeros2
=
stack_and_dev
(
zeros2_l
)
if
zeros2_l
else
None
zeros2
=
stack_and_dev
(
zeros2_l
)
if
zeros2_l
else
None
sort_indices2
=
stack_and_dev
(
sort_indices2_l
)
if
sort_indices2_l
else
None
sort_indices2
=
stack_and_dev
(
sort_indices2_l
)
if
sort_indices2_l
else
None
score
=
torch
.
randn
((
m
,
e
),
device
=
"cuda"
,
dtype
=
dtype
)
score
=
torch
.
randn
((
m
,
e
),
device
=
"cuda"
,
dtype
=
dtype
)
topk_weights
,
topk_ids
=
fused_topk
(
a
,
score
,
topk
,
False
)
topk_weights
,
topk_ids
,
_
=
fused_topk
(
a
,
score
,
topk
,
False
)
torch_output
=
torch_moe
(
a
,
w_ref1
,
w_ref2
,
score
,
topk
,
e_map
)
with
set_current_vllm_config
(
vllm_config
):
torch_output
=
torch_moe
(
a
,
w_ref1
,
w_ref2
,
score
,
topk
,
e_map
)
marlin_output
=
torch
.
ops
.
vllm
.
fused_marlin_moe
(
marlin_output
=
torch
.
ops
.
vllm
.
fused_marlin_moe
(
a
,
a
,
...
@@ -435,108 +537,18 @@ def test_fused_marlin_moe(
...
@@ -435,108 +537,18 @@ def test_fused_marlin_moe(
topk_ids
,
topk_ids
,
global_num_experts
=
e
,
global_num_experts
=
e
,
expert_map
=
e_map
,
expert_map
=
e_map
,
global_scale1
=
global_scale1
,
global_scale2
=
global_scale2
,
g_idx1
=
g_idx1
,
g_idx1
=
g_idx1
,
g_idx2
=
g_idx2
,
g_idx2
=
g_idx2
,
sort_indices1
=
sort_indices1
,
sort_indices1
=
sort_indices1
,
sort_indices2
=
sort_indices2
,
sort_indices2
=
sort_indices2
,
w1_zeros
=
zeros1
,
w1_zeros
=
zeros1
,
w2_zeros
=
zeros2
,
w2_zeros
=
zeros2
,
num_bits
=
num_bits
,
quant_type_id
=
quant_type
.
id
,
is_k_full
=
is_k_full
)
is_k_full
=
is_k_full
)
torch
.
testing
.
assert_close
(
marlin_output
,
torch_output
,
atol
=
2e-2
,
rtol
=
0
)
torch
.
testing
.
assert_close
(
marlin_output
,
torch_output
,
atol
=
5e-2
,
rtol
=
0
)
@pytest.mark.skip("This test is here for the sake of debugging, "
                  "don't run it in automated tests.")
@pytest.mark.parametrize("m", [1, 33, 123])
@pytest.mark.parametrize("n", [128, 1024])
@pytest.mark.parametrize("k", [256, 2048])
@pytest.mark.parametrize("e", [4, 12])
@pytest.mark.parametrize("topk", [2, 3])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("group_size", [-1, 32, 128])
@pytest.mark.parametrize("act_order", [True, False])
@pytest.mark.parametrize("num_bits", [4, 8])
@pytest.mark.parametrize("has_zp", [True, False])
@pytest.mark.parametrize("is_k_full", [True, False])
def test_single_marlin_moe_multiply(m: int, n: int, k: int, e: int, topk: int,
                                    dtype: torch.dtype, group_size: int,
                                    act_order: bool, num_bits: int,
                                    has_zp: bool, is_k_full: bool):
    """Debug-only comparison of ``single_marlin_moe`` with a torch reference.

    Each expert's weight is quantized (``awq_marlin_quantize`` when zero
    points are requested, ``marlin_quantize`` otherwise), the per-expert
    results are stacked, and the op's output is compared to
    ``torch_moe_single`` run on the dequantized reference weights.
    """
    # Filter act_order
    if act_order:
        if group_size == -1 or group_size in (k, n) or has_zp:
            return
    elif not is_k_full:
        return

    four_bit = num_bits == 4
    if has_zp:
        quant_type = scalar_types.uint4 if four_bit else scalar_types.uint8
    else:
        quant_type = (scalar_types.uint4b8
                      if four_bit else scalar_types.uint8b128)

    a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
    w = torch.randn((e, n, k), device="cuda", dtype=dtype) / 10

    ref_weights, packed_weights, weight_scales = [], [], []
    zero_points, group_indices, sort_perms = [], [], []

    for expert_weight in w:
        transposed = expert_weight.transpose(1, 0)
        if has_zp:
            w_ref, qweight, scales, zeros = awq_marlin_quantize(
                transposed, quant_type, group_size)
            zero_points.append(zeros)
        else:
            test_perm = torch.randperm(k)
            w_ref, qweight, scales, g_idx, sort_indices, _ = marlin_quantize(
                transposed, quant_type, group_size, act_order, test_perm)
            group_indices.append(g_idx)
            sort_perms.append(sort_indices)

        ref_weights.append(w_ref.T)
        packed_weights.append(qweight)
        weight_scales.append(scales)

    def _stack_or_none(items):
        # Lists that were never filled on this path become None.
        return stack_and_dev(items) if items else None

    w_ref = stack_and_dev(ref_weights)
    qweight = stack_and_dev(packed_weights).contiguous()
    scales = stack_and_dev(weight_scales)
    g_idx = _stack_or_none(group_indices)
    zeros = _stack_or_none(zero_points)
    sort_indices = _stack_or_none(sort_perms)

    score = torch.randn((m, e), device="cuda", dtype=dtype)
    marlin_output = torch.ops.vllm.single_marlin_moe(
        a,
        qweight,
        scales,
        score,
        topk,
        renormalize=False,
        g_idx=g_idx,
        sort_indices=sort_indices,
        w_zeros=zeros,
        num_bits=num_bits,
        is_k_full=is_k_full,
    )

    torch_output = torch_moe_single(a, w_ref, score, topk)

    torch.testing.assert_close(marlin_output, torch_output, atol=2e-2, rtol=0)
def
test_moe_align_block_size_opcheck
():
def
test_moe_align_block_size_opcheck
():
...
...
tests/kernels/moe/test_moe_permute_unpermute.py
0 → 100644
View file @
7a985548
# SPDX-License-Identifier: Apache-2.0
"""Tests for the MOE permute/unpermute kernel
Run `pytest tests/kernels/moe/test_moe_permute_unpermute.py`.
"""
from
typing
import
Optional
import
numpy
as
np
import
pytest
import
torch
from
vllm.model_executor.layers.fused_moe.fused_moe
import
fused_topk
from
vllm.model_executor.layers.fused_moe.layer
import
determine_expert_map
from
vllm.model_executor.layers.fused_moe.moe_permute_unpermute
import
(
moe_permute
,
moe_unpermute
)
from
vllm.platforms
import
current_platform
# Values used for the "n_expert" parametrization of the test below.
NUM_EXPERTS
=
[
16
,
64
]
# Values used for the "topk" parametrization of the test below.
TOP_KS
=
[
2
,
4
,
6
,
8
]
# Values used for the "ep_size" parametrization of the test below.
EP_SIZE
=
[
1
,
4
,
16
]
# Runs once at import time; presumably seeds the random number generators
# (the test calls it again before generating its inputs).
current_platform
.
seed_everything
(
0
)
def torch_permute(hidden_states: torch.Tensor,
                  topk_ids: torch.Tensor,
                  token_expert_indices: torch.Tensor,
                  topk: int,
                  n_expert: int,
                  n_local_expert: int,
                  start_expert: int,
                  expert_map: Optional[torch.Tensor] = None,
                  align_block_size: Optional[int] = None,
                  fill_invalid_expert: int = -1) -> list[torch.Tensor]:
    """Reference (pure torch) implementation of the MoE permute step.

    Rows of ``hidden_states`` are gathered so that all rows routed to the
    same local expert are contiguous, ordered by expert id.

    Args:
        hidden_states: 2-D tensor; only ``shape[0]`` (n_token) and
            ``shape[1]`` (n_hidden) are read. All tensors created here are
            placed on its device.
        topk_ids: expert id chosen for each (token, k) slot.
        token_expert_indices: flat source-row id for each (token, k) slot;
            ``id % n_token`` is used as the row index into ``hidden_states``.
        topk: number of expert slots per token.
        n_expert: global expert count (used for the non-local id shift and
            for sizing the aligned output).
        n_local_expert: number of experts handled locally.
        start_expert: global id of the first local expert.
        expert_map: optional lookup where ``-1`` marks a non-local expert.
            When given, local ids are shifted down by ``start_expert`` and
            non-local ids are shifted up by ``n_expert``.
        align_block_size: when set, every expert's row range is padded up to
            a multiple of this value.
        fill_invalid_expert: value written to ``m_indices`` for rows that
            belong to no local expert.

    Returns:
        A five-element list:
        ``[permuted_hidden_states, expert_first_token_offset,
        src_row_id2dst_row_id_map, m_indices, valid_row_idx]``.
        The offsets are int64 with ``n_local_expert + 1`` entries, the row
        map is int32 with shape ``(n_token, topk)``, ``m_indices`` is int32
        with one entry per permuted row, and ``valid_row_idx`` is a plain
        Python list of the permuted row indices that hold real tokens. In
        the aligned case the padding rows of ``permuted_hidden_states`` are
        uninitialized.
    """
    # Follow the input's device instead of assuming "cuda".
    device = hidden_states.device
    n_token, n_hidden = hidden_states.shape[0], hidden_states.shape[1]
    if expert_map is not None:
        is_local_expert = (expert_map[topk_ids] != -1)
        not_local_expert = (expert_map[topk_ids] == -1)
        topk_ids = is_local_expert * (
            topk_ids - start_expert) + not_local_expert * (topk_ids + n_expert)

    sorted_topk_ids, sorted_indices = torch.sort(topk_ids.flatten(),
                                                 stable=True)
    dst_row_id2src_row_id_map = token_expert_indices.flatten()[sorted_indices]

    # Copy the sorted ids to the host once; the scans below then work on
    # Python ints rather than reading one tensor element per iteration.
    sorted_ids = sorted_topk_ids.tolist()

    # offsets[i] is the first permuted row of local expert i, and
    # offsets[-1] is the number of rows routed to local experts.
    offsets = [0]
    idx = 0
    for expert in range(n_local_expert):
        while idx < len(sorted_ids) and sorted_ids[idx] == expert:
            idx += 1
        offsets.append(idx)
    expert_first_token_offset = torch.tensor(offsets,
                                             dtype=torch.int64,
                                             device=device)

    _, src2dst_idx = torch.sort(dst_row_id2src_row_id_map)
    valid_row_idx = []

    if align_block_size is None:
        permuted_hidden_states = hidden_states[dst_row_id2src_row_id_map %
                                               n_token, ...]
        permuted_row_size = permuted_hidden_states.shape[0]
        m_indices = torch.full((permuted_row_size, ),
                               fill_invalid_expert,
                               dtype=torch.int32,
                               device=device)
        for expert in range(n_local_expert):
            m_indices[offsets[expert]:offsets[expert + 1]] = expert
        src_row_id2dst_row_id_map = torch.arange(
            0, n_token * topk, device=device,
            dtype=torch.int32)[src2dst_idx].reshape((n_token, topk))
        valid_row_idx += list(range(offsets[-1]))
        return [
            permuted_hidden_states, expert_first_token_offset,
            src_row_id2dst_row_id_map, m_indices, valid_row_idx
        ]

    permuted_row_size = (topk * n_token + n_expert *
                         (align_block_size - 1) + align_block_size -
                         1) // align_block_size * align_block_size
    permuted_hidden_states = torch.empty((permuted_row_size, n_hidden),
                                         device=device,
                                         dtype=hidden_states.dtype)
    m_indices = torch.full((permuted_row_size, ),
                           fill_invalid_expert,
                           dtype=torch.int32,
                           device=device)

    # get align_permuted_hidden_states,
    # valid row_idx and align_expert_first_token_offset
    aligned_offsets = [0]
    for expert in range(n_local_expert):
        first_row = offsets[expert]
        n_token_in_expert = offsets[expert + 1] - first_row
        aligned_first = aligned_offsets[-1]
        aligned_last = aligned_first + (
            n_token_in_expert + align_block_size -
            1) // align_block_size * align_block_size
        aligned_offsets.append(aligned_last)

        src_rows = dst_row_id2src_row_id_map[first_row:first_row +
                                             n_token_in_expert] % n_token
        # store token in current expert with align_first_token_offset
        permuted_hidden_states[aligned_first:aligned_first +
                               n_token_in_expert, ...] = hidden_states[
                                   src_rows, ...]
        # set current expert m_indices (padding rows included)
        m_indices[aligned_first:aligned_last] = expert
        valid_row_idx.extend(
            range(aligned_first, aligned_first + n_token_in_expert))

    align_expert_first_token_offset = torch.tensor(
        aligned_offsets,
        dtype=expert_first_token_offset.dtype,
        device=device)

    # get align_src_row_id2dst_row_id
    aligned_end = aligned_offsets[-1]
    dst_rows = []
    for pos in range(n_token * topk):
        eid = sorted_ids[pos]
        if eid >= n_local_expert:
            # token not in a local expert: point just past the last
            # aligned expert range
            dst_rows.append(aligned_end)
        else:
            dst_rows.append(aligned_offsets[eid] + pos - offsets[eid])
    align_src_row_id2dst_row_id = torch.tensor(
        dst_rows, dtype=torch.int32,
        device=device)[src2dst_idx].reshape((n_token, topk))

    return [
        permuted_hidden_states, align_expert_first_token_offset,
        align_src_row_id2dst_row_id, m_indices, valid_row_idx
    ]
def torch_unpermute(permuted_hidden_states: torch.Tensor,
                    topk_weights: torch.Tensor, topk_ids: torch.Tensor,
                    token_expert_indices: torch.Tensor,
                    src_row_id2dst_row_id_map: torch.Tensor,
                    valid_row_idx: torch.Tensor, topk: int,
                    n_expert: int) -> torch.Tensor:
    """Reference (pure torch) implementation of the MoE unpermute step.

    Each token's output is the weighted sum over its top-k slots of the
    permuted rows those slots map to.

    NOTE: ``permuted_hidden_states`` is modified in place: every row not
    listed in ``valid_row_idx`` is set to zero before the gather, so such
    rows contribute nothing to the sum.

    Args:
        permuted_hidden_states: rows in permuted order; mutated as described
            above.
        topk_weights: per-(token, k) weights, broadcast over the hidden
            dimension.
        topk_ids: unused here.
        token_expert_indices: flat source-row id for each (token, k) slot,
            used to index the flattened row map.
        src_row_id2dst_row_id_map: maps a flat source-row id to its row in
            ``permuted_hidden_states``.
        valid_row_idx: indices of the permuted rows that hold real tokens.
        topk: unused here.
        n_expert: unused here.

    Returns:
        Tensor with one row per token, in the dtype of
        ``permuted_hidden_states``.
    """
    # ignore invalid row
    # The mask lives on the input's device rather than a hard-coded "cuda".
    mask = torch.zeros(permuted_hidden_states.shape[0],
                       dtype=torch.bool,
                       device=permuted_hidden_states.device)
    mask[valid_row_idx] = True
    permuted_hidden_states[~mask] = 0

    # Destination row of every (token, k) slot, shaped like
    # token_expert_indices.
    idx = src_row_id2dst_row_id_map.flatten()[
        token_expert_indices.flatten()].reshape(token_expert_indices.shape)
    output = permuted_hidden_states[idx, ...] * topk_weights[..., None]
    # Reduce over the top-k axis and cast back to the input dtype.
    output = output.sum(dim=1).to(permuted_hidden_states.dtype)
    return output
@pytest.mark.parametrize("n_token", [1, 33, 64, 222, 1024, 2048, 3000, 5000])
@pytest.mark.parametrize("n_hidden", [2048, 4096, 7168])
@pytest.mark.parametrize("n_expert", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("ep_size", EP_SIZE)
@pytest.mark.parametrize("align_block_size", [None, 128])
def test_moe_permute_unpermute(n_token: int, n_hidden: int, topk: int,
                               n_expert: int, ep_size: int, dtype: torch.dtype,
                               align_block_size: Optional[int]):
    """Check moe_permute / moe_unpermute against the torch references.

    The permute outputs (expert first-token offsets, src->dst row map,
    m_indices) must match ``torch_permute`` exactly, the permuted hidden
    states must match on the valid rows, and the un-permuted result must
    match ``torch_unpermute`` within an absolute tolerance of 2e-2.
    """
    fill_invalid_expert = 0

    # Expert-parallel setup: pick a rank and, when there is more than one
    # rank, derive the local expert count and the expert map for that rank.
    ep_rank = np.random.randint(0, ep_size)
    if ep_size == 1:
        n_local_expert, expert_map = n_expert, None
    else:
        n_local_expert, expert_map = determine_expert_map(
            ep_size, ep_rank, n_expert)
        expert_map = expert_map.cuda()
    start_expert = n_local_expert * ep_rank

    # Inputs are generated after seeding so they are reproducible.
    current_platform.seed_everything(0)
    hidden_states = torch.randn((n_token, n_hidden), device="cuda").to(dtype)
    gating_output = torch.randn((n_token, n_expert), device="cuda").to(dtype)
    topk_weights, topk_ids, token_expert_indices = fused_topk(
        hidden_states, gating_output, topk, False)

    # Reference permutation.
    (ref_permuted, ref_first_token_offset, ref_src2dst, ref_m_indices,
     valid_row_idx) = torch_permute(hidden_states,
                                    topk_ids,
                                    token_expert_indices,
                                    topk,
                                    n_expert,
                                    n_local_expert,
                                    start_expert,
                                    expert_map=expert_map,
                                    align_block_size=align_block_size,
                                    fill_invalid_expert=fill_invalid_expert)

    # Permutation under test.
    (permuted, out_first_token_offset, out_src2dst,
     out_m_indices) = moe_permute(hidden_states, topk_weights, topk_ids,
                                  token_expert_indices, topk, n_expert,
                                  n_local_expert, expert_map, align_block_size,
                                  fill_invalid_expert)

    exact = {"atol": 0, "rtol": 0}
    # In order: expert_first_token_offset, src_row_id2dst_row_id_map,
    # m_indices.
    exact_pairs = (
        (ref_first_token_offset, out_first_token_offset),
        (ref_src2dst, out_src2dst),
        (ref_m_indices, out_m_indices),
    )
    for expected, actual in exact_pairs:
        torch.testing.assert_close(expected, actual, **exact)

    # Permuted hidden states are compared on the valid rows only.
    torch.testing.assert_close(ref_permuted[valid_row_idx],
                               permuted[valid_row_idx], **exact)

    # Add a random tensor to simulate the group gemm output.
    permuted = 0.5 * permuted + torch.randn_like(permuted)

    unpermuted = moe_unpermute(permuted, topk_weights, topk_ids, out_src2dst,
                               out_first_token_offset, topk, n_expert,
                               n_local_expert)
    expected_unpermuted = torch_unpermute(permuted, topk_weights, topk_ids,
                                          token_expert_indices, out_src2dst,
                                          valid_row_idx, topk, n_local_expert)

    # Check the un-permuted hidden states.
    torch.testing.assert_close(unpermuted,
                               expected_unpermuted,
                               atol=2e-2,
                               rtol=0)
Prev
1
…
13
14
15
16
17
18
19
20
21
…
25
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment