Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
711aa9d5
Commit
711aa9d5
authored
Jul 30, 2025
by
zhuwenwen
Browse files
Merge tag 'v0.10.0' into v0.10.0-dev
parents
751c492c
6d8d0a24
Changes
519
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
2419 additions
and
177 deletions
+2419
-177
tests/entrypoints/openai/test_tokenization.py
tests/entrypoints/openai/test_tokenization.py
+104
-0
tests/entrypoints/openai/test_transcription_validation.py
tests/entrypoints/openai/test_transcription_validation.py
+24
-7
tests/entrypoints/openai/test_translation_validation.py
tests/entrypoints/openai/test_translation_validation.py
+3
-4
tests/entrypoints/openai/test_vision.py
tests/entrypoints/openai/test_vision.py
+2
-2
tests/entrypoints/openai/tool_parsers/test_hunyuan_a13b_tool_parser.py
...ints/openai/tool_parsers/test_hunyuan_a13b_tool_parser.py
+153
-0
tests/entrypoints/test_chat_utils.py
tests/entrypoints/test_chat_utils.py
+520
-2
tests/kernels/attention/test_flashinfer_trtllm_decode_attention.py
...nels/attention/test_flashinfer_trtllm_decode_attention.py
+140
-0
tests/kernels/attention/test_rocm_attention_selector.py
tests/kernels/attention/test_rocm_attention_selector.py
+26
-10
tests/kernels/attention/untest_flashinfer.py
tests/kernels/attention/untest_flashinfer.py
+31
-18
tests/kernels/core/test_layernorm.py
tests/kernels/core/test_layernorm.py
+20
-8
tests/kernels/mamba/test_mamba_mixer2.py
tests/kernels/mamba/test_mamba_mixer2.py
+5
-4
tests/kernels/mamba/test_mamba_ssm_ssd.py
tests/kernels/mamba/test_mamba_ssm_ssd.py
+22
-8
tests/kernels/mamba/untest_causal_conv1d.py
tests/kernels/mamba/untest_causal_conv1d.py
+44
-114
tests/kernels/moe/modular_kernel_tools/__init__.py
tests/kernels/moe/modular_kernel_tools/__init__.py
+0
-0
tests/kernels/moe/modular_kernel_tools/cli_args.py
tests/kernels/moe/modular_kernel_tools/cli_args.py
+159
-0
tests/kernels/moe/modular_kernel_tools/common.py
tests/kernels/moe/modular_kernel_tools/common.py
+641
-0
tests/kernels/moe/modular_kernel_tools/make_feature_matrix.py
...s/kernels/moe/modular_kernel_tools/make_feature_matrix.py
+173
-0
tests/kernels/moe/modular_kernel_tools/mk_objects.py
tests/kernels/moe/modular_kernel_tools/mk_objects.py
+87
-0
tests/kernels/moe/modular_kernel_tools/parallel_utils.py
tests/kernels/moe/modular_kernel_tools/parallel_utils.py
+138
-0
tests/kernels/moe/modular_kernel_tools/profile_modular_kernel.py
...ernels/moe/modular_kernel_tools/profile_modular_kernel.py
+127
-0
No files found.
Too many changes to show.
To preserve performance only
519 of 519+
files are displayed.
Plain diff
Email patch
tests/entrypoints/openai/test_tokenization.py
View file @
711aa9d5
...
@@ -34,6 +34,7 @@ def server(zephyr_lora_added_tokens_files: str): # noqa: F811
...
@@ -34,6 +34,7 @@ def server(zephyr_lora_added_tokens_files: str): # noqa: F811
f
"zephyr-lora2=
{
zephyr_lora_added_tokens_files
}
"
,
f
"zephyr-lora2=
{
zephyr_lora_added_tokens_files
}
"
,
"--max-lora-rank"
,
"--max-lora-rank"
,
"64"
,
"64"
,
"--enable-tokenizer-info-endpoint"
,
]
]
with
RemoteOpenAIServer
(
MODEL_NAME
,
args
)
as
remote_server
:
with
RemoteOpenAIServer
(
MODEL_NAME
,
args
)
as
remote_server
:
...
@@ -285,3 +286,106 @@ async def test_detokenize(
...
@@ -285,3 +286,106 @@ async def test_detokenize(
response
.
raise_for_status
()
response
.
raise_for_status
()
assert
response
.
json
()
==
{
"prompt"
:
prompt
}
assert
response
.
json
()
==
{
"prompt"
:
prompt
}
@
pytest
.
mark
.
asyncio
@
pytest
.
mark
.
parametrize
(
"model_name,tokenizer_name"
,
[(
MODEL_NAME
,
MODEL_NAME
),
(
"zephyr-lora2"
,
"zephyr-lora2"
)],
indirect
=
[
"tokenizer_name"
],
)
async
def
test_tokenizer_info_basic
(
server
:
RemoteOpenAIServer
,
model_name
:
str
,
tokenizer_name
:
str
,
):
"""Test basic tokenizer info endpoint functionality."""
response
=
requests
.
get
(
server
.
url_for
(
"tokenizer_info"
))
response
.
raise_for_status
()
result
=
response
.
json
()
assert
"tokenizer_class"
in
result
assert
isinstance
(
result
[
"tokenizer_class"
],
str
)
assert
result
[
"tokenizer_class"
]
@
pytest
.
mark
.
asyncio
async
def
test_tokenizer_info_schema
(
server
:
RemoteOpenAIServer
):
"""Test that the response matches expected schema types."""
response
=
requests
.
get
(
server
.
url_for
(
"tokenizer_info"
))
response
.
raise_for_status
()
result
=
response
.
json
()
field_types
=
{
"add_bos_token"
:
bool
,
"add_prefix_space"
:
bool
,
"clean_up_tokenization_spaces"
:
bool
,
"split_special_tokens"
:
bool
,
"bos_token"
:
str
,
"eos_token"
:
str
,
"pad_token"
:
str
,
"unk_token"
:
str
,
"chat_template"
:
str
,
"errors"
:
str
,
"model_max_length"
:
int
,
"additional_special_tokens"
:
list
,
"added_tokens_decoder"
:
dict
,
}
for
field
,
expected_type
in
field_types
.
items
():
if
field
in
result
and
result
[
field
]
is
not
None
:
assert
isinstance
(
result
[
field
],
expected_type
),
(
f
"
{
field
}
should be
{
expected_type
.
__name__
}
"
)
@
pytest
.
mark
.
asyncio
async
def
test_tokenizer_info_added_tokens_structure
(
server
:
RemoteOpenAIServer
,
):
"""Test added_tokens_decoder structure if present."""
response
=
requests
.
get
(
server
.
url_for
(
"tokenizer_info"
))
response
.
raise_for_status
()
result
=
response
.
json
()
added_tokens
=
result
.
get
(
"added_tokens_decoder"
)
if
added_tokens
:
for
token_id
,
token_info
in
added_tokens
.
items
():
assert
isinstance
(
token_id
,
str
),
"Token IDs should be strings"
assert
isinstance
(
token_info
,
dict
),
"Token info should be a dict"
assert
"content"
in
token_info
,
"Token info should have content"
assert
"special"
in
token_info
,
(
"Token info should have special flag"
)
assert
isinstance
(
token_info
[
"special"
],
bool
),
(
"Special flag should be boolean"
)
@
pytest
.
mark
.
asyncio
async
def
test_tokenizer_info_consistency_with_tokenize
(
server
:
RemoteOpenAIServer
,
):
"""Test that tokenizer info is consistent with tokenization endpoint."""
info_response
=
requests
.
get
(
server
.
url_for
(
"tokenizer_info"
))
info_response
.
raise_for_status
()
info
=
info_response
.
json
()
tokenize_response
=
requests
.
post
(
server
.
url_for
(
"tokenize"
),
json
=
{
"model"
:
MODEL_NAME
,
"prompt"
:
"Hello world!"
},
)
tokenize_response
.
raise_for_status
()
tokenize_result
=
tokenize_response
.
json
()
info_max_len
=
info
.
get
(
"model_max_length"
)
tokenize_max_len
=
tokenize_result
.
get
(
"max_model_len"
)
if
info_max_len
and
tokenize_max_len
:
assert
info_max_len
>=
tokenize_max_len
,
(
"Info max length should be >= tokenize max length"
)
@
pytest
.
mark
.
asyncio
async
def
test_tokenizer_info_chat_template
(
server
:
RemoteOpenAIServer
):
"""Test chat template is properly included."""
response
=
requests
.
get
(
server
.
url_for
(
"tokenizer_info"
))
response
.
raise_for_status
()
result
=
response
.
json
()
chat_template
=
result
.
get
(
"chat_template"
)
if
chat_template
:
assert
isinstance
(
chat_template
,
str
),
(
"Chat template should be a string"
)
assert
chat_template
.
strip
(),
"Chat template should not be empty"
\ No newline at end of file
tests/entrypoints/openai/test_transcription_validation.py
View file @
711aa9d5
...
@@ -17,6 +17,11 @@ from vllm.assets.audio import AudioAsset
...
@@ -17,6 +17,11 @@ from vllm.assets.audio import AudioAsset
from
...utils
import
RemoteOpenAIServer
from
...utils
import
RemoteOpenAIServer
MISTRAL_FORMAT_ARGS
=
[
"--tokenizer_mode"
,
"mistral"
,
"--config_format"
,
"mistral"
,
"--load_format"
,
"mistral"
]
@
pytest
.
fixture
@
pytest
.
fixture
def
mary_had_lamb
():
def
mary_had_lamb
():
...
@@ -33,9 +38,15 @@ def winning_call():
...
@@ -33,9 +38,15 @@ def winning_call():
@
pytest
.
mark
.
asyncio
@
pytest
.
mark
.
asyncio
async
def
test_basic_audio
(
mary_had_lamb
):
@
pytest
.
mark
.
parametrize
(
model_name
=
"openai/whisper-large-v3-turbo"
"model_name"
,
[
"openai/whisper-large-v3-turbo"
,
"mistralai/Voxtral-Mini-3B-2507"
])
async
def
test_basic_audio
(
mary_had_lamb
,
model_name
):
server_args
=
[
"--enforce-eager"
]
server_args
=
[
"--enforce-eager"
]
if
model_name
.
startswith
(
"mistralai"
):
server_args
+=
MISTRAL_FORMAT_ARGS
# Based on https://github.com/openai/openai-cookbook/blob/main/examples/Whisper_prompting_guide.ipynb.
# Based on https://github.com/openai/openai-cookbook/blob/main/examples/Whisper_prompting_guide.ipynb.
with
RemoteOpenAIServer
(
model_name
,
server_args
)
as
remote_server
:
with
RemoteOpenAIServer
(
model_name
,
server_args
)
as
remote_server
:
client
=
remote_server
.
get_async_client
()
client
=
remote_server
.
get_async_client
()
...
@@ -65,10 +76,13 @@ async def test_bad_requests(mary_had_lamb):
...
@@ -65,10 +76,13 @@ async def test_bad_requests(mary_had_lamb):
@
pytest
.
mark
.
asyncio
@
pytest
.
mark
.
asyncio
async
def
test_long_audio_request
(
mary_had_lamb
):
@
pytest
.
mark
.
parametrize
(
"model_name"
,
[
"openai/whisper-large-v3-turbo"
])
model_name
=
"openai/whisper-large-v3-turbo"
async
def
test_long_audio_request
(
mary_had_lamb
,
model_name
):
server_args
=
[
"--enforce-eager"
]
server_args
=
[
"--enforce-eager"
]
if
model_name
.
startswith
(
"openai"
):
return
mary_had_lamb
.
seek
(
0
)
mary_had_lamb
.
seek
(
0
)
audio
,
sr
=
librosa
.
load
(
mary_had_lamb
)
audio
,
sr
=
librosa
.
load
(
mary_had_lamb
)
# Add small silence after each audio for repeatability in the split process
# Add small silence after each audio for repeatability in the split process
...
@@ -87,7 +101,8 @@ async def test_long_audio_request(mary_had_lamb):
...
@@ -87,7 +101,8 @@ async def test_long_audio_request(mary_had_lamb):
response_format
=
"text"
,
response_format
=
"text"
,
temperature
=
0.0
)
temperature
=
0.0
)
out
=
json
.
loads
(
transcription
)[
'text'
]
out
=
json
.
loads
(
transcription
)[
'text'
]
assert
out
.
count
(
"Mary had a little lamb"
)
==
10
counts
=
out
.
count
(
"Mary had a little lamb"
)
assert
counts
==
10
,
counts
@
pytest
.
mark
.
asyncio
@
pytest
.
mark
.
asyncio
...
@@ -154,7 +169,8 @@ async def test_streaming_response(winning_call):
...
@@ -154,7 +169,8 @@ async def test_streaming_response(winning_call):
file
=
winning_call
,
file
=
winning_call
,
language
=
"en"
,
language
=
"en"
,
temperature
=
0.0
,
temperature
=
0.0
,
extra_body
=
dict
(
stream
=
True
))
extra_body
=
dict
(
stream
=
True
),
timeout
=
30
)
# Reconstruct from chunks and validate
# Reconstruct from chunks and validate
async
for
chunk
in
res
:
async
for
chunk
in
res
:
# just a chunk
# just a chunk
...
@@ -184,7 +200,8 @@ async def test_stream_options(winning_call):
...
@@ -184,7 +200,8 @@ async def test_stream_options(winning_call):
temperature
=
0.0
,
temperature
=
0.0
,
extra_body
=
dict
(
stream
=
True
,
extra_body
=
dict
(
stream
=
True
,
stream_include_usage
=
True
,
stream_include_usage
=
True
,
stream_continuous_usage_stats
=
True
))
stream_continuous_usage_stats
=
True
),
timeout
=
30
)
final
=
False
final
=
False
continuous
=
True
continuous
=
True
async
for
chunk
in
res
:
async
for
chunk
in
res
:
...
...
tests/entrypoints/openai/test_translation_validation.py
View file @
711aa9d5
...
@@ -39,8 +39,8 @@ async def test_basic_audio(foscolo):
...
@@ -39,8 +39,8 @@ async def test_basic_audio(foscolo):
# TODO remove once language detection is implemented
# TODO remove once language detection is implemented
extra_body
=
dict
(
language
=
"it"
),
extra_body
=
dict
(
language
=
"it"
),
temperature
=
0.0
)
temperature
=
0.0
)
out
=
json
.
loads
(
translation
)[
'text'
].
strip
()
out
=
json
.
loads
(
translation
)[
'text'
].
strip
()
.
lower
()
assert
"
Nor will I ever touch the sacred
"
in
out
assert
"
greek sea
"
in
out
@
pytest
.
mark
.
asyncio
@
pytest
.
mark
.
asyncio
...
@@ -168,5 +168,4 @@ async def test_long_audio_request(foscolo):
...
@@ -168,5 +168,4 @@ async def test_long_audio_request(foscolo):
response_format
=
"text"
,
response_format
=
"text"
,
temperature
=
0.0
)
temperature
=
0.0
)
out
=
json
.
loads
(
translation
)[
'text'
].
strip
().
lower
()
out
=
json
.
loads
(
translation
)[
'text'
].
strip
().
lower
()
# TODO investigate higher model uncertainty in for longer translations.
assert
out
.
count
(
"greek sea"
)
==
2
assert
out
.
count
(
"nor will i ever"
)
==
2
tests/entrypoints/openai/test_vision.py
View file @
711aa9d5
...
@@ -45,11 +45,11 @@ EXPECTED_MM_BEAM_SEARCH_RES = [
...
@@ -45,11 +45,11 @@ EXPECTED_MM_BEAM_SEARCH_RES = [
],
],
[
[
"The image shows a Venn diagram with three over"
,
"The image shows a Venn diagram with three over"
,
"Th
is
image shows a Venn diagram with three
over
"
,
"Th
e
image shows a Venn diagram with three
intersect
"
,
],
],
[
[
"This image displays a gradient of colors ranging from"
,
"This image displays a gradient of colors ranging from"
,
"Th
is
image displays a gradient of colors
t
ran
sition
ing from"
,
"Th
e
image displays a gradient of colors ran
g
ing from"
,
],
],
]
]
...
...
tests/entrypoints/openai/tool_parsers/test_hunyuan_a13b_tool_parser.py
0 → 100644
View file @
711aa9d5
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# ruff: noqa: E501
import
json
from
unittest.mock
import
MagicMock
import
pytest
from
tests.entrypoints.openai.tool_parsers.utils
import
(
run_tool_extraction
,
run_tool_extraction_streaming
)
from
vllm.entrypoints.openai.protocol
import
FunctionCall
,
ToolCall
from
vllm.entrypoints.openai.tool_parsers
import
ToolParser
,
ToolParserManager
def
make_tool_call
(
name
,
arguments
):
return
ToolCall
(
type
=
"function"
,
function
=
FunctionCall
(
name
=
name
,
arguments
=
json
.
dumps
(
arguments
)))
# TODO: add reason prefix and suffix.
@
pytest
.
mark
.
parametrize
(
"model_output,expected_tool_calls,expected_content"
,
[
# No tool call
(
"How can I help you today?"
,
[],
"How can I help you today?"
),
# Single tool call, no content
(
"<tool_calls>[{
\"
name
\"
:
\"
get_weather
\"
,
\"
arguments
\"
: {
\"
city
\"
:
\"
San Francisco
\"
,
\"
metric
\"
:
\"
celsius
\"
}}]</tool_calls>"
,
#noqa: E501
[
make_tool_call
(
"get_weather"
,
{
"city"
:
"San Francisco"
,
"metric"
:
"celsius"
})
],
None
),
# Multiple tool calls
(
"<tool_calls>[{
\"
name
\"
:
\"
get_weather
\"
,
\"
arguments
\"
: {
\"
city
\"
:
\"
San Francisco
\"
,
\"
metric
\"
:
\"
celsius
\"
}}, {
\"
name
\"
:
\"
register_user
\"
,
\"
arguments
\"
: {
\"
name
\"
:
\"
John Doe
\"
,
\"
age
\"
: 37,
\"
address
\"
: {
\"
city
\"
:
\"
San Francisco
\"
,
\"
state
\"
:
\"
CA
\"
},
\"
role
\"
: null,
\"
passed_test
\"
: true,
\"
aliases
\"
: [
\"
John
\"
,
\"
Johnny
\"
]}}]</tool_calls>"
,
#noqa: E501
[
make_tool_call
(
"get_weather"
,
{
"city"
:
"San Francisco"
,
"metric"
:
"celsius"
}),
make_tool_call
(
"register_user"
,
{
"name"
:
"John Doe"
,
"age"
:
37
,
"address"
:
{
"city"
:
"San Francisco"
,
"state"
:
"CA"
},
"role"
:
None
,
"passed_test"
:
True
,
"aliases"
:
[
"John"
,
"Johnny"
]
})
],
None
),
# Content before tool call
(
"I will call the tool now. <tool_calls>[{
\"
name
\"
:
\"
get_weather
\"
,
\"
arguments
\"
: {
\"
city
\"
:
\"
Boston
\"
}}]</tool_calls>"
,
#noqa: E501
[
make_tool_call
(
"get_weather"
,
{
"city"
:
"Boston"
})],
"I will call the tool now. "
),
# Content after tool call (should be stripped)
(
"<tool_calls>[{
\"
name
\"
:
\"
get_weather
\"
,
\"
arguments
\"
: {
\"
city
\"
:
\"
Seattle
\"
}}]</tool_calls>
\n
Thank you!"
,
#noqa: E501
[
make_tool_call
(
"get_weather"
,
{
"city"
:
"Seattle"
})],
None
),
(
"<tool_calls>[{
\"
name
\"
:
\"
complex_tool
\"
,
\"
arguments
\"
: {
\"
level1
\"
: {
\"
level2
\"
: {
\"
level3
\"
: {
\"
value
\"
: 123}}}}}]</tool_calls>"
,
[
make_tool_call
(
"complex_tool"
,
{
"level1"
:
{
"level2"
:
{
"level3"
:
{
"value"
:
123
}
}
}})
],
None
,
),
])
def
test_hunyuan_a13b_tool_parser_extract
(
model_output
,
expected_tool_calls
,
expected_content
):
mock_tokenizer
=
MagicMock
()
tool_parser
:
ToolParser
=
ToolParserManager
.
get_tool_parser
(
"hunyuan_a13b"
)(
mock_tokenizer
)
content
,
tool_calls
=
run_tool_extraction
(
tool_parser
,
model_output
,
streaming
=
False
)
# align the random id.
for
idx
in
range
(
len
(
tool_calls
)):
tool_calls
[
idx
].
id
=
expected_tool_calls
[
idx
].
id
assert
tool_calls
==
expected_tool_calls
assert
content
==
expected_content
# Streaming test: simulate incremental output
@
pytest
.
mark
.
parametrize
(
"model_deltas,expected_tool_calls"
,
[
([
"<tool_calls>[{
\"
name
\"
:
\"
get_weather
\"
, "
,
"
\"
arguments
\"
: {
\"
city
\"
:
\"
San Francisco
\"
, "
,
"
\"
metric
\"
:
\"
celsius
\"
}}]"
,
"</tool_calls>"
],
[
make_tool_call
(
"get_weather"
,
{
"city"
:
"San Francisco"
,
"metric"
:
"celsius"
})
]),
([
"<tool_calls>[{
\"
name
\"
:"
,
"
\"
get_weather
\"
,"
,
"
\"
arguments
\"
:"
,
" {
\"
city
\"
:
\"
Boston
\"
}"
,
"}]"
,
"</tool_calls>"
],
[
make_tool_call
(
"get_weather"
,
{
"city"
:
"Boston"
})]),
([
""
,
"<tool_calls>[{
\"
name
\"
:"
,
"
\"
get_weather
\"
,"
,
"
\"
arguments
\"
:"
,
" {
\"
city
\"
:
\"
Boston
\"
}"
,
"}]"
,
"</tool_calls>"
,
"
\n
</answer>"
],
[
make_tool_call
(
"get_weather"
,
{
"city"
:
"Boston"
})]),
pytest
.
param
([
"<tool_calls>[{
\"
name
\"
:
\"
complex_tool
\"
,"
,
"
\"
arguments
\"
: "
,
" {
\"
level1
\"
: {
\"
level2
\"
: "
,
"{
\"
level3
\"
: {
\"
value
\"
: 123}}}}}"
,
"]</tool_calls>"
],
[
make_tool_call
(
"complex_tool"
,
{
"level1"
:
{
"level2"
:
{
"level3"
:
{
"value"
:
123
}
}
}})
],
marks
=
pytest
.
mark
.
xfail
(
reason
=
"stream parsing not support nested json yet."
)),
])
def
test_hunyuan_a13b_tool_parser_streaming
(
model_deltas
,
expected_tool_calls
):
mock_tokenizer
=
MagicMock
()
tool_parser
:
ToolParser
=
ToolParserManager
.
get_tool_parser
(
"hunyuan_a13b"
)(
mock_tokenizer
)
reconstructor
=
run_tool_extraction_streaming
(
tool_parser
,
model_deltas
,
assert_one_tool_per_delta
=
False
)
# align the random id.
for
idx
in
range
(
len
(
reconstructor
.
tool_calls
)):
reconstructor
.
tool_calls
[
idx
].
id
=
expected_tool_calls
[
idx
].
id
assert
reconstructor
.
tool_calls
==
expected_tool_calls
tests/entrypoints/test_chat_utils.py
View file @
711aa9d5
...
@@ -2,12 +2,20 @@
...
@@ -2,12 +2,20 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
warnings
import
warnings
from
typing
import
Optional
from
collections.abc
import
Mapping
from
typing
import
Literal
,
Optional
import
pytest
import
pytest
import
os
import
os
from
mistral_common.tokens.tokenizers.base
import
(
SpecialTokenPolicy
,
SpecialTokens
)
from
mistral_common.tokens.tokenizers.tekken
import
(
SpecialTokenInfo
,
Tekkenizer
)
from
vllm.assets.audio
import
AudioAsset
from
vllm.assets.image
import
ImageAsset
from
vllm.assets.image
import
ImageAsset
from
vllm.assets.video
import
VideoAsset
from
vllm.config
import
ModelConfig
from
vllm.config
import
ModelConfig
from
vllm.entrypoints.chat_utils
import
(
_try_extract_ast
,
load_chat_template
,
from
vllm.entrypoints.chat_utils
import
(
_try_extract_ast
,
load_chat_template
,
parse_chat_messages
,
parse_chat_messages
,
...
@@ -16,9 +24,12 @@ from vllm.entrypoints.chat_utils import (_try_extract_ast, load_chat_template,
...
@@ -16,9 +24,12 @@ from vllm.entrypoints.chat_utils import (_try_extract_ast, load_chat_template,
resolve_hf_chat_template
)
resolve_hf_chat_template
)
from
vllm.entrypoints.llm
import
apply_hf_chat_template
from
vllm.entrypoints.llm
import
apply_hf_chat_template
from
vllm.multimodal
import
MultiModalDataDict
from
vllm.multimodal
import
MultiModalDataDict
from
vllm.multimodal.utils
import
encode_image_base64
from
vllm.multimodal.utils
import
(
encode_audio_base64
,
encode_image_base64
,
encode_video_base64
)
from
vllm.transformers_utils.tokenizer_group
import
TokenizerGroup
from
vllm.transformers_utils.tokenizer_group
import
TokenizerGroup
from
..utils
import
models_path_prefix
from
..utils
import
models_path_prefix
from
vllm.transformers_utils.tokenizers.mistral
import
MistralTokenizer
from
..models.registry
import
HF_EXAMPLE_MODELS
from
..models.registry
import
HF_EXAMPLE_MODELS
from
..utils
import
VLLM_PATH
from
..utils
import
VLLM_PATH
...
@@ -30,11 +41,13 @@ ULTRAVOX_MODEL_ID = os.path.join(models_path_prefix, "fixie-ai/ultravox-v0_5-lla
...
@@ -30,11 +41,13 @@ ULTRAVOX_MODEL_ID = os.path.join(models_path_prefix, "fixie-ai/ultravox-v0_5-lla
QWEN2AUDIO_MODEL_ID
=
os
.
path
.
join
(
models_path_prefix
,
"Qwen/Qwen2-Audio-7B-Instruct"
)
QWEN2AUDIO_MODEL_ID
=
os
.
path
.
join
(
models_path_prefix
,
"Qwen/Qwen2-Audio-7B-Instruct"
)
QWEN2VL_MODEL_ID
=
os
.
path
.
join
(
models_path_prefix
,
"Qwen/Qwen2-VL-2B-Instruct"
)
QWEN2VL_MODEL_ID
=
os
.
path
.
join
(
models_path_prefix
,
"Qwen/Qwen2-VL-2B-Instruct"
)
QWEN25VL_MODEL_ID
=
os
.
path
.
join
(
models_path_prefix
,
"Qwen/Qwen2.5-VL-3B-Instruct"
)
QWEN25VL_MODEL_ID
=
os
.
path
.
join
(
models_path_prefix
,
"Qwen/Qwen2.5-VL-3B-Instruct"
)
QWEN25OMNI_MODEL_ID
=
os
.
path
.
join
(
models_path_prefix
,
"Qwen/Qwen2.5-Omni-7B"
)
MLLAMA_MODEL_ID
=
os
.
path
.
join
(
models_path_prefix
,
"meta-llama/Llama-3.2-11B-Vision-Instruct"
)
MLLAMA_MODEL_ID
=
os
.
path
.
join
(
models_path_prefix
,
"meta-llama/Llama-3.2-11B-Vision-Instruct"
)
LLAMA_GUARD_MODEL_ID
=
os
.
path
.
join
(
models_path_prefix
,
"meta-llama/Llama-Guard-3-1B"
)
LLAMA_GUARD_MODEL_ID
=
os
.
path
.
join
(
models_path_prefix
,
"meta-llama/Llama-Guard-3-1B"
)
HERMES_MODEL_ID
=
os
.
path
.
join
(
models_path_prefix
,
"NousResearch/Hermes-3-Llama-3.1-8B"
)
HERMES_MODEL_ID
=
os
.
path
.
join
(
models_path_prefix
,
"NousResearch/Hermes-3-Llama-3.1-8B"
)
MISTRAL_MODEL_ID
=
os
.
path
.
join
(
models_path_prefix
,
"mistralai/Mistral-Small-3.1-24B-Instruct-2503"
)
MISTRAL_MODEL_ID
=
os
.
path
.
join
(
models_path_prefix
,
"mistralai/Mistral-Small-3.1-24B-Instruct-2503"
)
@
pytest
.
fixture
(
scope
=
"function"
)
@
pytest
.
fixture
(
scope
=
"function"
)
def
phi3v_model_config
():
def
phi3v_model_config
():
return
ModelConfig
(
PHI3V_MODEL_ID
,
return
ModelConfig
(
PHI3V_MODEL_ID
,
...
@@ -49,6 +62,21 @@ def phi3v_model_config():
...
@@ -49,6 +62,21 @@ def phi3v_model_config():
})
})
@
pytest
.
fixture
(
scope
=
"function"
)
def
phi3v_model_config_mm_interleaved
():
return
ModelConfig
(
PHI3V_MODEL_ID
,
task
=
"generate"
,
tokenizer
=
PHI3V_MODEL_ID
,
tokenizer_mode
=
"auto"
,
trust_remote_code
=
True
,
dtype
=
"auto"
,
seed
=
0
,
interleave_mm_strings
=
True
,
limit_mm_per_prompt
=
{
"image"
:
2
,
})
@
pytest
.
fixture
(
scope
=
"module"
)
@
pytest
.
fixture
(
scope
=
"module"
)
def
phi3v_tokenizer
():
def
phi3v_tokenizer
():
return
TokenizerGroup
(
return
TokenizerGroup
(
...
@@ -59,6 +87,32 @@ def phi3v_tokenizer():
...
@@ -59,6 +87,32 @@ def phi3v_tokenizer():
)
)
@
pytest
.
fixture
(
scope
=
"function"
)
def
qwen25omni_model_config_mm_interleaved
():
return
ModelConfig
(
QWEN25OMNI_MODEL_ID
,
task
=
"generate"
,
tokenizer
=
QWEN25OMNI_MODEL_ID
,
tokenizer_mode
=
"auto"
,
dtype
=
"auto"
,
seed
=
0
,
interleave_mm_strings
=
True
,
limit_mm_per_prompt
=
{
"image"
:
2
,
"audio"
:
1
,
"video"
:
1
,
})
@
pytest
.
fixture
(
scope
=
"module"
)
def
qwen25omni_tokenizer
():
return
TokenizerGroup
(
tokenizer_id
=
QWEN25OMNI_MODEL_ID
,
enable_lora
=
False
,
max_num_seqs
=
5
,
max_input_length
=
None
,
)
@
pytest
.
fixture
(
scope
=
"module"
)
@
pytest
.
fixture
(
scope
=
"module"
)
def
mllama_model_config
():
def
mllama_model_config
():
return
ModelConfig
(
MLLAMA_MODEL_ID
,
return
ModelConfig
(
MLLAMA_MODEL_ID
,
...
@@ -114,6 +168,20 @@ def image_url():
...
@@ -114,6 +168,20 @@ def image_url():
return
f
"data:image/jpeg;base64,
{
base64
}
"
return
f
"data:image/jpeg;base64,
{
base64
}
"
@
pytest
.
fixture
(
scope
=
"module"
)
def
video_url
():
video
=
VideoAsset
(
'baby_reading'
,
1
)
base64
=
encode_video_base64
(
video
.
np_ndarrays
)
return
f
"data:video/jpeg;base64,
{
base64
}
"
@
pytest
.
fixture
(
scope
=
"module"
)
def
audio_url
():
audio
=
AudioAsset
(
'mary_had_lamb'
)
base64
=
encode_audio_base64
(
*
audio
.
audio_and_sample_rate
)
return
f
"data:audio/ogg;base64,
{
base64
}
"
def
_assert_mm_data_is_image_input
(
def
_assert_mm_data_is_image_input
(
mm_data
:
Optional
[
MultiModalDataDict
],
mm_data
:
Optional
[
MultiModalDataDict
],
image_count
:
int
,
image_count
:
int
,
...
@@ -127,6 +195,23 @@ def _assert_mm_data_is_image_input(
...
@@ -127,6 +195,23 @@ def _assert_mm_data_is_image_input(
assert
isinstance
(
image_data
,
list
)
and
len
(
image_data
)
==
image_count
assert
isinstance
(
image_data
,
list
)
and
len
(
image_data
)
==
image_count
ModalityType
=
Literal
[
"image"
,
"video"
,
"audio"
]
MultiModalDataCounts
=
Mapping
[
ModalityType
,
int
]
def
_assert_mm_data_inputs
(
mm_data
:
Optional
[
MultiModalDataDict
],
data_count
:
MultiModalDataCounts
,
)
->
None
:
assert
mm_data
is
not
None
assert
set
(
data_count
.
keys
())
==
(
set
(
mm_data
.
keys
()))
for
modality
,
n
in
data_count
.
items
():
modality_data
=
mm_data
.
get
(
modality
)
assert
modality_data
is
not
None
assert
isinstance
(
modality_data
,
list
)
and
len
(
modality_data
)
==
n
def
test_parse_chat_messages_single_image
(
def
test_parse_chat_messages_single_image
(
phi3v_model_config
,
phi3v_model_config
,
phi3v_tokenizer
,
phi3v_tokenizer
,
...
@@ -638,6 +723,277 @@ def test_parse_chat_messages_multiple_images_uncommon_input(
...
@@ -638,6 +723,277 @@ def test_parse_chat_messages_multiple_images_uncommon_input(
_assert_mm_data_is_image_input
(
mm_data
,
2
)
_assert_mm_data_is_image_input
(
mm_data
,
2
)
def
test_parse_chat_messages_multiple_images_interleave
(
phi3v_model_config_mm_interleaved
,
phi3v_tokenizer
,
image_url
,
):
conversation
,
mm_data
=
parse_chat_messages
(
[{
"role"
:
"user"
,
"content"
:
[{
"type"
:
"text"
,
"text"
:
"I need you to compare this image"
},
{
"type"
:
"image_url"
,
"image_url"
:
{
"url"
:
image_url
}
},
{
"type"
:
"text"
,
"text"
:
"and this one"
},
{
"type"
:
"image_url"
,
"image_url"
:
{
"url"
:
image_url
}
},
{
"type"
:
"text"
,
"text"
:
"Do they have differences?"
}]
}],
phi3v_model_config_mm_interleaved
,
phi3v_tokenizer
,
content_format
=
"string"
,
)
assert
conversation
==
[{
"role"
:
"user"
,
"content"
:
"I need you to compare this image
\n
<|image_1|>
\n
and this one
\n
<|image_2|>
\n
"
# noqa: E501
"Do they have differences?"
}]
_assert_mm_data_is_image_input
(
mm_data
,
2
)
@
pytest
.
mark
.
asyncio
async
def
test_parse_chat_messages_multiple_images_interleave_async
(
phi3v_model_config_mm_interleaved
,
phi3v_tokenizer
,
image_url
,
):
conversation
,
mm_data
=
parse_chat_messages_futures
(
[{
"role"
:
"user"
,
"content"
:
[{
"type"
:
"text"
,
"text"
:
"I need you to compare this image"
},
{
"type"
:
"image_url"
,
"image_url"
:
{
"url"
:
image_url
}
},
{
"type"
:
"text"
,
"text"
:
"and this one"
},
{
"type"
:
"image_url"
,
"image_url"
:
{
"url"
:
image_url
}
},
{
"type"
:
"text"
,
"text"
:
"Do they have differences?"
}]
}],
phi3v_model_config_mm_interleaved
,
phi3v_tokenizer
,
content_format
=
"string"
,
)
assert
conversation
==
[{
"role"
:
"user"
,
"content"
:
"I need you to compare this image
\n
<|image_1|>
\n
and this one
\n
<|image_2|>
\n
"
# noqa: E501
"Do they have differences?"
}]
_assert_mm_data_is_image_input
(
await
mm_data
,
2
)
def
test_parse_chat_messages_multiple_images_multiple_messages_interleave
(
phi3v_model_config_mm_interleaved
,
phi3v_tokenizer
,
image_url
,
):
conversation
,
mm_data
=
parse_chat_messages
(
[{
"role"
:
"user"
,
"content"
:
[
{
"type"
:
"text"
,
"text"
:
"What's on this image?"
},
{
"type"
:
"image_url"
,
"image_url"
:
{
"url"
:
image_url
}
},
{
"type"
:
"text"
,
"text"
:
"Be accurate."
},
]
},
{
"role"
:
"assistant"
,
"content"
:
"Some stuff."
},
{
"role"
:
"user"
,
"content"
:
[{
"type"
:
"text"
,
"text"
:
"What's on this image?"
},
{
"type"
:
"image_url"
,
"image_url"
:
{
"url"
:
image_url
}
}]
}],
phi3v_model_config_mm_interleaved
,
phi3v_tokenizer
,
content_format
=
"string"
,
)
assert
conversation
==
[{
"role"
:
"user"
,
"content"
:
"What's on this image?
\n
<|image_1|>
\n
Be accurate."
},
{
"role"
:
"assistant"
,
"content"
:
"Some stuff."
},
{
"role"
:
"user"
,
"content"
:
"What's on this image?
\n
<|image_2|>"
}]
_assert_mm_data_is_image_input
(
mm_data
,
2
)
def
test_parse_chat_messages_multiple_modals_multiple_messages_interleave
(
qwen25omni_model_config_mm_interleaved
,
qwen25omni_tokenizer
,
image_url
,
video_url
,
audio_url
):
conversation
,
mm_data
=
parse_chat_messages
(
[{
"role"
:
"user"
,
"content"
:
[
{
"type"
:
"text"
,
"text"
:
"What's on this image?"
},
{
"type"
:
"image_url"
,
"image_url"
:
{
"url"
:
image_url
}
},
{
"type"
:
"text"
,
"text"
:
"Now listen to this audio"
},
{
"type"
:
"audio_url"
,
"audio_url"
:
{
"url"
:
audio_url
}
},
]
},
{
"role"
:
"assistant"
,
"content"
:
"Some stuff."
},
{
"role"
:
"user"
,
"content"
:
[{
"type"
:
"text"
,
"text"
:
"What's on this image?"
},
{
"type"
:
"image_url"
,
"image_url"
:
{
"url"
:
image_url
}
},
{
"type"
:
"text"
,
"text"
:
"And what's in the video?"
},
{
"type"
:
"video_url"
,
"video_url"
:
{
"url"
:
video_url
}
}]
}],
qwen25omni_model_config_mm_interleaved
,
qwen25omni_tokenizer
,
content_format
=
"string"
,
)
assert
conversation
==
[{
"role"
:
"user"
,
"content"
:
"What's on this image?
\n
<|vision_start|><|IMAGE|><|vision_end|>
\n
"
"Now listen to this audio
\n
Audio 1: <|audio_bos|><|AUDIO|><|audio_eos|>"
},
{
"role"
:
"assistant"
,
"content"
:
"Some stuff."
},
{
"role"
:
"user"
,
"content"
:
"What's on this image?
\n
<|vision_start|><|IMAGE|><|vision_end|>
\n
"
"And what's in the video?
\n
<|vision_start|><|VIDEO|><|vision_end|>"
}]
_assert_mm_data_inputs
(
mm_data
,
{
"image"
:
2
,
"video"
:
1
,
"audio"
:
1
})
def
test_parse_chat_messages_multiple_images_interleave_with_placeholders
(
phi3v_model_config_mm_interleaved
,
phi3v_tokenizer
,
image_url
,
):
with
pytest
.
raises
(
ValueError
,
match
=
r
"Found more '<|image_1|>' placeholders in input prompt "
"than actual multimodal data items."
):
parse_chat_messages
(
[{
"role"
:
"user"
,
"content"
:
[
{
"type"
:
"image_url"
,
"image_url"
:
{
"url"
:
image_url
}
},
{
"type"
:
"image_url"
,
"image_url"
:
{
"url"
:
image_url
}
},
{
"type"
:
"text"
,
"text"
:
"I need you to compare this image
\n
<|image_1|>
\n
and this one
\n
<|image_2|>
\n
"
# noqa: E501
"Do they have differences?"
},
]
}],
phi3v_model_config_mm_interleaved
,
phi3v_tokenizer
,
content_format
=
"string"
,
)
### Mllama currently wraps images / texts as interleaved dictionaries
### Mllama currently wraps images / texts as interleaved dictionaries
def
test_mllama_single_image
(
def
test_mllama_single_image
(
mllama_model_config
,
mllama_model_config
,
...
@@ -1027,3 +1383,165 @@ def test_resolve_content_format_examples(template_path, expected_format):
...
@@ -1027,3 +1383,165 @@ def test_resolve_content_format_examples(template_path, expected_format):
)
)
assert
resolved_format
==
expected_format
assert
resolved_format
==
expected_format
def
test_parse_chat_messages_include_thinking_chunk
(
mistral_model_config
,
mistral_tokenizer
):
messages
=
[{
"role"
:
"system"
,
"content"
:
[{
"type"
:
"text"
,
"text"
:
"You are a helpful assistant."
},
{
"type"
:
"thinking"
,
"closed"
:
True
,
"thinking"
:
"Only return the answer when you are confident."
}]
},
{
"role"
:
"user"
,
"content"
:
"What is 2+2?"
},
{
"role"
:
"assistant"
,
"content"
:
[{
"type"
:
"text"
,
"text"
:
"Let me think about it."
},
{
"type"
:
"thinking"
,
"closed"
:
True
,
"thinking"
:
"2+2 = 4"
},
{
"type"
:
"text"
,
"text"
:
"The answer is 4."
,
}],
}]
conversation_with_thinking
,
_
=
parse_chat_messages
(
messages
,
mistral_model_config
,
mistral_tokenizer
,
content_format
=
"openai"
,
)
expected_conversation
=
[{
"role"
:
"system"
,
"content"
:
[{
"type"
:
"text"
,
"text"
:
"You are a helpful assistant."
},
{
"type"
:
"text"
,
"text"
:
"Only return the answer when you are confident."
}],
},
{
"role"
:
"user"
,
"content"
:
[{
"type"
:
"text"
,
"text"
:
"What is 2+2?"
}],
},
{
"role"
:
"assistant"
,
"content"
:
[
{
"type"
:
"text"
,
"text"
:
"Let me think about it."
},
{
"type"
:
"text"
,
"text"
:
"2+2 = 4"
},
{
"type"
:
"text"
,
"text"
:
"The answer is 4."
},
]
}]
assert
conversation_with_thinking
==
expected_conversation
def
test_apply_mistral_chat_template_thinking_chunk
():
# Moved import here to avoid yapf and isort conflicts
from
vllm.entrypoints.chat_utils
import
apply_mistral_chat_template
messages
=
[{
"role"
:
"system"
,
"content"
:
[{
"type"
:
"text"
,
"text"
:
"You are a helpful assistant."
},
{
"type"
:
"thinking"
,
"closed"
:
True
,
"thinking"
:
"Only return the answer when you are confident."
}]
},
{
"role"
:
"user"
,
"content"
:
"What is 2+2?"
},
{
"role"
:
"assistant"
,
"content"
:
[{
"type"
:
"text"
,
"text"
:
"Let me think about it."
},
{
"type"
:
"thinking"
,
"closed"
:
True
,
"thinking"
:
"2+2 = 4"
},
{
"type"
:
"text"
,
"text"
:
"The answer is 4."
,
}],
},
{
"role"
:
"user"
,
"content"
:
"Thanks, what is 3+3?"
}]
# TODO(Julien): upon model release change to a tokenizer already configured.
# =================================================================
mistral_tokenizer
=
MistralTokenizer
.
from_pretrained
(
"mistralai/Devstral-Small-2507"
)
assert
isinstance
(
mistral_tokenizer
.
tokenizer
,
Tekkenizer
)
# Add think special tokens to the tokenizer
mistral_tokenizer
.
tokenizer
.
_all_special_tokens
[
35
]
=
SpecialTokenInfo
(
rank
=
35
,
is_control
=
True
,
token_str
=
SpecialTokens
.
begin_think
.
value
)
mistral_tokenizer
.
tokenizer
.
_all_special_tokens
[
36
]
=
SpecialTokenInfo
(
rank
=
36
,
is_control
=
True
,
token_str
=
SpecialTokens
.
end_think
.
value
)
mistral_tokenizer
.
tokenizer
.
_special_tokens_reverse_vocab
=
{
k
:
v
for
k
,
v
in
mistral_tokenizer
.
tokenizer
.
_special_tokens_reverse_vocab
.
items
()
if
v
not
in
{
35
,
36
}
}
mistral_tokenizer
.
tokenizer
.
_special_tokens_reverse_vocab
[
SpecialTokens
.
begin_think
.
value
]
=
35
mistral_tokenizer
.
tokenizer
.
_special_tokens_reverse_vocab
[
SpecialTokens
.
end_think
.
value
]
=
36
mistral_tokenizer
.
instruct
.
BEGIN_THINK
=
35
mistral_tokenizer
.
instruct
.
END_THINK
=
36
# =================================================================
tokens_ids
=
apply_mistral_chat_template
(
mistral_tokenizer
,
messages
,
chat_template
=
None
,
tools
=
None
)
string_tokens
=
mistral_tokenizer
.
mistral
.
decode
(
tokens_ids
,
special_token_policy
=
SpecialTokenPolicy
.
KEEP
)
expected_tokens
=
(
r
"<s>[SYSTEM_PROMPT]You are a helpful assistant.[THINK]Only return the"
r
" answer when you are confident.[/THINK][/SYSTEM_PROMPT]"
r
"[INST]What is 2+2?[/INST]"
r
"Let me think about it.[THINK]2+2 = 4[/THINK]The answer is 4.</s>"
r
"[INST]Thanks, what is 3+3?[/INST]"
)
assert
string_tokens
==
expected_tokens
tests/kernels/attention/test_flashinfer_trtllm_decode_attention.py
0 → 100644
View file @
711aa9d5
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
typing
import
Optional
import
flashinfer
import
pytest
import
torch
from
vllm.platforms
import
current_platform
if
not
current_platform
.
is_device_capability
(
100
):
pytest
.
skip
(
"This TRTLLM kernel requires NVIDIA Blackwell."
,
allow_module_level
=
True
)
FLOAT32_BYTES
=
torch
.
finfo
(
torch
.
float
).
bits
//
8
# KV Cache Layout for TRT-LLM
# kv_cache_shape = (num_blocks, 2, num_kv_heads, page_size, head_dim)
NUM_HEADS
=
[(
64
,
8
),
(
16
,
16
),
(
40
,
8
),
(
32
,
8
)]
HEAD_SIZES
=
[
128
]
BLOCK_SIZES
=
[
16
,
32
]
DTYPES
=
[
torch
.
float16
,
torch
.
bfloat16
]
NUM_BLOCKS
=
32768
# Large enough to test overflow in index calculation.
SOFT_CAPS
=
[
None
,
30.0
,
50.0
]
def
to_float8
(
x
,
dtype
=
torch
.
float8_e4m3fn
):
finfo
=
torch
.
finfo
(
dtype
)
min_val
,
max_val
=
x
.
aminmax
()
amax
=
torch
.
maximum
(
min_val
.
abs
(),
max_val
.
abs
()).
clamp
(
min
=
1e-12
)
scale
=
finfo
.
max
/
amax
*
0.1
x_scl_sat
=
(
x
*
scale
).
clamp
(
min
=
finfo
.
min
,
max
=
finfo
.
max
)
return
x_scl_sat
.
to
(
dtype
),
scale
.
float
().
reciprocal
()
@
pytest
.
mark
.
parametrize
(
"kv_lens"
,
[[
1328
,
18
,
463
],
[
1
,
54
,
293
,
70
]])
@
pytest
.
mark
.
parametrize
(
"num_heads"
,
NUM_HEADS
)
@
pytest
.
mark
.
parametrize
(
"head_size"
,
HEAD_SIZES
)
@
pytest
.
mark
.
parametrize
(
"block_size"
,
BLOCK_SIZES
)
@
pytest
.
mark
.
parametrize
(
"kv_layout"
,
[
"HND"
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPES
)
@
pytest
.
mark
.
parametrize
(
"soft_cap"
,
SOFT_CAPS
)
@
torch
.
inference_mode
def
test_flashinfer_trtllm_decode_with_baseline
(
kv_lens
:
list
[
int
],
num_heads
:
tuple
[
int
,
int
],
head_size
:
int
,
dtype
:
torch
.
dtype
,
block_size
:
int
,
soft_cap
:
Optional
[
float
],
kv_layout
:
str
,
)
->
None
:
torch
.
set_default_device
(
"cuda"
)
current_platform
.
seed_everything
(
0
)
num_seqs
=
len
(
kv_lens
)
num_query_heads
=
num_heads
[
0
]
num_kv_heads
=
num_heads
[
1
]
assert
num_query_heads
%
num_kv_heads
==
0
max_kv_len
=
max
(
kv_lens
)
scale
=
head_size
**-
0.5
query
=
torch
.
randn
(
num_seqs
,
num_query_heads
,
head_size
,
dtype
=
dtype
)
kv_cache_shape
=
None
if
kv_layout
==
"NHD"
:
kv_cache_shape
=
(
NUM_BLOCKS
,
2
,
block_size
,
num_kv_heads
,
head_size
)
elif
kv_layout
==
"HND"
:
kv_cache_shape
=
(
NUM_BLOCKS
,
2
,
num_kv_heads
,
block_size
,
head_size
)
else
:
raise
ValueError
(
f
"Invalid kv_layout:
{
kv_layout
}
"
)
key_value_cache
=
torch
.
randn
(
kv_cache_shape
,
dtype
=
dtype
)
max_num_blocks_per_seq
=
(
max_kv_len
+
block_size
-
1
)
//
block_size
block_tables
=
torch
.
randint
(
0
,
NUM_BLOCKS
,
(
num_seqs
,
max_num_blocks_per_seq
),
dtype
=
torch
.
int32
)
k_scale
=
v_scale
=
1.0
kv_indptr
=
[
0
]
kv_indices
=
[]
kv_last_page_lens
=
[]
for
i
in
range
(
num_seqs
):
seq_len
=
kv_lens
[
i
]
assert
seq_len
>
0
num_blocks
=
(
seq_len
+
block_size
-
1
)
//
block_size
kv_indices
.
extend
(
block_tables
[
i
,
:
num_blocks
])
kv_indptr
.
append
(
kv_indptr
[
-
1
]
+
num_blocks
)
kv_last_page_len
=
seq_len
%
block_size
if
kv_last_page_len
==
0
:
kv_last_page_len
=
block_size
kv_last_page_lens
.
append
(
kv_last_page_len
)
kv_indptr
=
torch
.
tensor
(
kv_indptr
,
dtype
=
torch
.
int32
)
kv_indices
=
torch
.
tensor
(
kv_indices
,
dtype
=
torch
.
int32
)
kv_last_page_lens
=
torch
.
tensor
(
kv_last_page_lens
,
dtype
=
torch
.
int32
)
workspace_buffer
=
torch
.
empty
(
128
*
1024
*
1024
,
dtype
=
torch
.
int8
)
wrapper
=
flashinfer
.
\
BatchDecodeWithPagedKVCacheWrapper
(
workspace_buffer
,
kv_layout
,
use_tensor_cores
=
(
(
num_query_heads
//
num_kv_heads
)
>
4
)
)
wrapper
.
plan
(
kv_indptr
,
kv_indices
,
kv_last_page_lens
,
num_query_heads
,
num_kv_heads
,
head_size
,
block_size
,
"NONE"
,
q_data_type
=
dtype
,
kv_data_type
=
dtype
,
logits_soft_cap
=
soft_cap
)
output
=
wrapper
.
run
(
query
,
key_value_cache
,
scale
)
# TRTLLM Decode
max_kv_len
=
max
(
kv_lens
)
kv_lens_tensor
=
torch
.
tensor
(
kv_lens
,
dtype
=
torch
.
int
,
device
=
query
.
device
)
output_trtllm
=
flashinfer
.
decode
.
trtllm_batch_decode_with_kv_cache
(
query
.
contiguous
(),
key_value_cache
,
workspace_buffer
,
num_query_heads
,
num_kv_heads
,
scale
,
block_tables
,
kv_lens_tensor
,
block_size
,
max_kv_len
,
"auto"
,
k_scale
,
v_scale
,
)
torch
.
testing
.
assert_close
(
output
,
output_trtllm
,
atol
=
1e-2
,
rtol
=
1e-2
),
\
f
"
{
torch
.
max
(
torch
.
abs
(
output
-
output_trtllm
))
}
"
tests/kernels/attention/test_rocm_attention_selector.py
View file @
711aa9d5
...
@@ -33,8 +33,12 @@ def test_selector(monkeypatch: pytest.MonkeyPatch):
...
@@ -33,8 +33,12 @@ def test_selector(monkeypatch: pytest.MonkeyPatch):
# change the attention backend to triton MLA
# change the attention backend to triton MLA
m
.
setenv
(
STR_BACKEND_ENV_VAR
,
"TRITON_MLA"
)
m
.
setenv
(
STR_BACKEND_ENV_VAR
,
"TRITON_MLA"
)
backend
=
get_attn_backend
(
576
,
torch
.
bfloat16
,
"auto"
,
16
,
False
,
backend
=
get_attn_backend
(
576
,
False
,
True
)
torch
.
bfloat16
,
"auto"
,
16
,
False
,
use_mla
=
True
)
assert
(
backend
.
get_name
()
==
"TRITON_MLA"
assert
(
backend
.
get_name
()
==
"TRITON_MLA"
or
backend
.
get_name
()
==
"TRITON_MLA_VLLM_V1"
)
or
backend
.
get_name
()
==
"TRITON_MLA_VLLM_V1"
)
...
@@ -42,15 +46,23 @@ def test_selector(monkeypatch: pytest.MonkeyPatch):
...
@@ -42,15 +46,23 @@ def test_selector(monkeypatch: pytest.MonkeyPatch):
# If use_mla is true
# If use_mla is true
# The selected backend is triton MLA
# The selected backend is triton MLA
m
.
setenv
(
STR_BACKEND_ENV_VAR
,
None
)
m
.
setenv
(
STR_BACKEND_ENV_VAR
,
None
)
backend
=
get_attn_backend
(
576
,
torch
.
bfloat16
,
"auto"
,
16
,
False
,
backend
=
get_attn_backend
(
576
,
False
,
True
)
torch
.
bfloat16
,
"auto"
,
16
,
False
,
use_mla
=
True
)
assert
(
backend
.
get_name
()
==
"TRITON_MLA"
assert
(
backend
.
get_name
()
==
"TRITON_MLA"
or
backend
.
get_name
()
==
"TRITON_MLA_VLLM_V1"
)
or
backend
.
get_name
()
==
"TRITON_MLA_VLLM_V1"
)
#
#
change the attention backend to AITER MLA
# change the attention backend to AITER MLA
# m.setenv(STR_BACKEND_ENV_VAR, "ROCM_AITER_MLA")
# m.setenv(STR_BACKEND_ENV_VAR, "ROCM_AITER_MLA")
# backend = get_attn_backend(576, torch.bfloat16, "auto", 1, False,
# backend = get_attn_backend(576,
# False, True)
# torch.bfloat16,
# "auto",
# 1,
# False,
# use_mla=True)
# assert (backend.get_name() == "ROCM_AITER_MLA"
# assert (backend.get_name() == "ROCM_AITER_MLA"
# or backend.get_name() == "ROCM_AITER_MLA_VLLM_V1")
# or backend.get_name() == "ROCM_AITER_MLA_VLLM_V1")
...
@@ -60,7 +72,11 @@ def test_selector(monkeypatch: pytest.MonkeyPatch):
...
@@ -60,7 +72,11 @@ def test_selector(monkeypatch: pytest.MonkeyPatch):
# # The selected backend is ROCM_AITER_MLA
# # The selected backend is ROCM_AITER_MLA
# m.setenv(STR_BACKEND_ENV_VAR, None)
# m.setenv(STR_BACKEND_ENV_VAR, None)
# m.setenv("VLLM_ROCM_USE_AITER", "1")
# m.setenv("VLLM_ROCM_USE_AITER", "1")
# backend = get_attn_backend(576, torch.bfloat16, "auto", 1, False,
# backend = get_attn_backend(576,
# False, True)
# torch.bfloat16,
# "auto",
# 1,
# False,
# use_mla=True)
# assert (backend.get_name() == "ROCM_AITER_MLA"
# assert (backend.get_name() == "ROCM_AITER_MLA"
# or backend.get_name() == "ROCM_AITER_MLA_VLLM_V1")
# or backend.get_name() == "ROCM_AITER_MLA_VLLM_V1")
\ No newline at end of file
tests/kernels/attention/untest_flashinfer.py
View file @
711aa9d5
...
@@ -77,6 +77,7 @@ def ref_paged_attn(
...
@@ -77,6 +77,7 @@ def ref_paged_attn(
@
pytest
.
mark
.
parametrize
(
"block_size"
,
BLOCK_SIZES
)
@
pytest
.
mark
.
parametrize
(
"block_size"
,
BLOCK_SIZES
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPES
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPES
)
@
pytest
.
mark
.
parametrize
(
"soft_cap"
,
[
None
,
30.0
,
50.0
])
@
pytest
.
mark
.
parametrize
(
"soft_cap"
,
[
None
,
30.0
,
50.0
])
@
pytest
.
mark
.
parametrize
(
"sliding_window"
,
[
None
,
64
])
@
torch
.
inference_mode
@
torch
.
inference_mode
def
test_flashinfer_decode_with_paged_kv
(
def
test_flashinfer_decode_with_paged_kv
(
kv_lens
:
list
[
int
],
kv_lens
:
list
[
int
],
...
@@ -85,6 +86,7 @@ def test_flashinfer_decode_with_paged_kv(
...
@@ -85,6 +86,7 @@ def test_flashinfer_decode_with_paged_kv(
dtype
:
torch
.
dtype
,
dtype
:
torch
.
dtype
,
block_size
:
int
,
block_size
:
int
,
soft_cap
:
Optional
[
float
],
soft_cap
:
Optional
[
float
],
sliding_window
:
Optional
[
int
],
)
->
None
:
)
->
None
:
torch
.
set_default_device
(
"cuda"
)
torch
.
set_default_device
(
"cuda"
)
current_platform
.
seed_everything
(
0
)
current_platform
.
seed_everything
(
0
)
...
@@ -136,17 +138,20 @@ def test_flashinfer_decode_with_paged_kv(
...
@@ -136,17 +138,20 @@ def test_flashinfer_decode_with_paged_kv(
use_tensor_cores
=
(
use_tensor_cores
=
(
(
num_query_heads
//
num_kv_heads
)
>
4
)
(
num_query_heads
//
num_kv_heads
)
>
4
)
)
)
wrapper
.
plan
(
kv_indptr
,
wrapper
.
plan
(
kv_indices
,
kv_indptr
,
kv_last_page_lens
,
kv_indices
,
num_query_heads
,
kv_last_page_lens
,
num_kv_heads
,
num_query_heads
,
head_size
,
num_kv_heads
,
block_size
,
head_size
,
"NONE"
,
block_size
,
q_data_type
=
dtype
,
"NONE"
,
kv_data_type
=
dtype
,
window_left
=
sliding_window
-
1
if
sliding_window
is
not
None
else
-
1
,
logits_soft_cap
=
soft_cap
)
q_data_type
=
dtype
,
kv_data_type
=
dtype
,
logits_soft_cap
=
soft_cap
,
)
output
=
wrapper
.
run
(
query
,
key_value_cache
)
output
=
wrapper
.
run
(
query
,
key_value_cache
)
...
@@ -157,7 +162,8 @@ def test_flashinfer_decode_with_paged_kv(
...
@@ -157,7 +162,8 @@ def test_flashinfer_decode_with_paged_kv(
kv_lens
=
kv_lens
,
kv_lens
=
kv_lens
,
block_tables
=
block_tables
,
block_tables
=
block_tables
,
scale
=
scale
,
scale
=
scale
,
soft_cap
=
soft_cap
)
soft_cap
=
soft_cap
,
sliding_window
=
sliding_window
)
torch
.
testing
.
assert_close
(
output
,
ref_output
,
atol
=
1e-2
,
rtol
=
1e-2
),
\
torch
.
testing
.
assert_close
(
output
,
ref_output
,
atol
=
1e-2
,
rtol
=
1e-2
),
\
f
"
{
torch
.
max
(
torch
.
abs
(
output
-
ref_output
))
}
"
f
"
{
torch
.
max
(
torch
.
abs
(
output
-
ref_output
))
}
"
...
@@ -168,12 +174,17 @@ def test_flashinfer_decode_with_paged_kv(
...
@@ -168,12 +174,17 @@ def test_flashinfer_decode_with_paged_kv(
@
pytest
.
mark
.
parametrize
(
"block_size"
,
BLOCK_SIZES
)
@
pytest
.
mark
.
parametrize
(
"block_size"
,
BLOCK_SIZES
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPES
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPES
)
@
pytest
.
mark
.
parametrize
(
"soft_cap"
,
[
None
,
30.0
,
50.0
])
@
pytest
.
mark
.
parametrize
(
"soft_cap"
,
[
None
,
30.0
,
50.0
])
@
pytest
.
mark
.
parametrize
(
"sliding_window"
,
[
None
,
64
])
@
torch
.
inference_mode
@
torch
.
inference_mode
def
test_flashinfer_prefill_with_paged_kv
(
seq_lens
:
list
[
tuple
[
int
,
int
]],
def
test_flashinfer_prefill_with_paged_kv
(
num_heads
:
tuple
[
int
,
int
],
seq_lens
:
list
[
tuple
[
int
,
int
]],
head_size
:
int
,
dtype
:
torch
.
dtype
,
num_heads
:
tuple
[
int
,
int
],
block_size
:
int
,
head_size
:
int
,
soft_cap
:
Optional
[
float
])
->
None
:
dtype
:
torch
.
dtype
,
block_size
:
int
,
soft_cap
:
Optional
[
float
],
sliding_window
:
Optional
[
int
],
)
->
None
:
torch
.
set_default_device
(
"cuda"
)
torch
.
set_default_device
(
"cuda"
)
current_platform
.
seed_everything
(
0
)
current_platform
.
seed_everything
(
0
)
num_seqs
=
len
(
seq_lens
)
num_seqs
=
len
(
seq_lens
)
...
@@ -242,6 +253,7 @@ def test_flashinfer_prefill_with_paged_kv(seq_lens: list[tuple[int, int]],
...
@@ -242,6 +253,7 @@ def test_flashinfer_prefill_with_paged_kv(seq_lens: list[tuple[int, int]],
num_kv_heads
,
num_kv_heads
,
head_size
,
head_size
,
block_size
,
block_size
,
window_left
=
sliding_window
-
1
if
sliding_window
is
not
None
else
-
1
,
q_data_type
=
dtype
,
q_data_type
=
dtype
,
kv_data_type
=
dtype
,
kv_data_type
=
dtype
,
logits_soft_cap
=
soft_cap
,
logits_soft_cap
=
soft_cap
,
...
@@ -259,7 +271,8 @@ def test_flashinfer_prefill_with_paged_kv(seq_lens: list[tuple[int, int]],
...
@@ -259,7 +271,8 @@ def test_flashinfer_prefill_with_paged_kv(seq_lens: list[tuple[int, int]],
kv_lens
=
kv_lens
,
kv_lens
=
kv_lens
,
block_tables
=
block_tables
,
block_tables
=
block_tables
,
scale
=
scale
,
scale
=
scale
,
soft_cap
=
soft_cap
)
soft_cap
=
soft_cap
,
sliding_window
=
sliding_window
)
torch
.
testing
.
assert_close
(
output
,
ref_output
,
atol
=
5e-2
,
rtol
=
1e-2
),
\
torch
.
testing
.
assert_close
(
output
,
ref_output
,
atol
=
5e-2
,
rtol
=
1e-2
),
\
f
"
{
torch
.
max
(
torch
.
abs
(
output
-
ref_output
))
}
"
f
"
{
torch
.
max
(
torch
.
abs
(
output
-
ref_output
))
}
"
...
...
tests/kernels/core/test_layernorm.py
View file @
711aa9d5
...
@@ -26,6 +26,7 @@ CUDA_DEVICES = [
...
@@ -26,6 +26,7 @@ CUDA_DEVICES = [
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPES
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPES
)
@
pytest
.
mark
.
parametrize
(
"seed"
,
SEEDS
)
@
pytest
.
mark
.
parametrize
(
"seed"
,
SEEDS
)
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
pytest
.
mark
.
parametrize
(
"strided_input"
,
[
False
,
True
])
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
def
test_rms_norm
(
def
test_rms_norm
(
num_tokens
:
int
,
num_tokens
:
int
,
...
@@ -34,13 +35,17 @@ def test_rms_norm(
...
@@ -34,13 +35,17 @@ def test_rms_norm(
dtype
:
torch
.
dtype
,
dtype
:
torch
.
dtype
,
seed
:
int
,
seed
:
int
,
device
:
str
,
device
:
str
,
strided_input
:
bool
,
)
->
None
:
)
->
None
:
current_platform
.
seed_everything
(
seed
)
current_platform
.
seed_everything
(
seed
)
torch
.
set_default_device
(
device
)
torch
.
set_default_device
(
device
)
layer
=
RMSNorm
(
hidden_size
).
to
(
dtype
=
dtype
)
layer
=
RMSNorm
(
hidden_size
).
to
(
dtype
=
dtype
)
layer
.
weight
.
data
.
normal_
(
mean
=
1.0
,
std
=
0.1
)
layer
.
weight
.
data
.
normal_
(
mean
=
1.0
,
std
=
0.1
)
scale
=
1
/
(
2
*
hidden_size
)
scale
=
1
/
(
2
*
hidden_size
)
x
=
torch
.
randn
(
num_tokens
,
hidden_size
,
dtype
=
dtype
)
last_dim
=
2
*
hidden_size
if
strided_input
else
hidden_size
x
=
torch
.
randn
(
num_tokens
,
last_dim
,
dtype
=
dtype
)
x
=
x
[...,
:
hidden_size
]
assert
x
.
is_contiguous
()
!=
strided_input
x
*=
scale
x
*=
scale
residual
=
torch
.
randn_like
(
x
)
*
scale
if
add_residual
else
None
residual
=
torch
.
randn_like
(
x
)
*
scale
if
add_residual
else
None
...
@@ -63,7 +68,7 @@ def test_rms_norm(
...
@@ -63,7 +68,7 @@ def test_rms_norm(
else
:
else
:
opcheck
(
torch
.
ops
.
_C
.
rms_norm
,
opcheck
(
torch
.
ops
.
_C
.
rms_norm
,
(
out
,
x
,
layer
.
weight
.
data
,
layer
.
variance_epsilon
))
(
out
,
x
,
layer
.
weight
.
data
,
layer
.
variance_epsilon
))
# @pytest.mark.parametrize("num_tokens", NUM_TOKENS)
# @pytest.mark.parametrize("num_tokens", NUM_TOKENS)
# @pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
# @pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
...
@@ -72,6 +77,7 @@ def test_rms_norm(
...
@@ -72,6 +77,7 @@ def test_rms_norm(
# @pytest.mark.parametrize("quant_scale", [1.0, 0.01, 10.0])
# @pytest.mark.parametrize("quant_scale", [1.0, 0.01, 10.0])
# @pytest.mark.parametrize("seed", SEEDS)
# @pytest.mark.parametrize("seed", SEEDS)
# @pytest.mark.parametrize("device", CUDA_DEVICES)
# @pytest.mark.parametrize("device", CUDA_DEVICES)
# @pytest.mark.parametrize("strided_input", [False, True])
# def test_fused_rms_norm_quant(
# def test_fused_rms_norm_quant(
# num_tokens: int,
# num_tokens: int,
# hidden_size: int,
# hidden_size: int,
...
@@ -80,13 +86,18 @@ def test_rms_norm(
...
@@ -80,13 +86,18 @@ def test_rms_norm(
# quant_scale: float,
# quant_scale: float,
# seed: int,
# seed: int,
# device: str,
# device: str,
# strided_input: bool,
# ) -> None:
# ) -> None:
# current_platform.seed_everything(seed)
# current_platform.seed_everything(seed)
# torch.set_default_device(device)
# torch.set_default_device(device)
# weight = torch.empty(hidden_size, dtype=dtype).normal_(mean=1.0, std=0.1)
# weight = torch.empty(hidden_size, dtype=dtype).normal_(mean=1.0, std=0.1)
# scale = 1 / (2 * hidden_size)
# scale = 1 / (2 * hidden_size)
# x = torch.randn(num_tokens, hidden_size, dtype=dtype)
# last_dim = 2 * hidden_size if strided_input else hidden_size
# x_base = torch.randn(num_tokens, last_dim, dtype=dtype)
# x = x_base[..., :hidden_size]
# assert x.is_contiguous() != strided_input
# x *= scale
# x *= scale
# if add_residual:
# if add_residual:
# residual = torch.randn_like(x) * scale
# residual = torch.randn_like(x) * scale
...
@@ -106,9 +117,11 @@ def test_rms_norm(
...
@@ -106,9 +117,11 @@ def test_rms_norm(
# # Unfused kernel is in-place so it goes second
# # Unfused kernel is in-place so it goes second
# # Also use a separate clone of x to avoid modifying the input
# # Also use a separate clone of x to avoid modifying the input
# x_unfused = x.clone()
# x_unfused_base = x_base.clone()
# x_unfused = x_unfused_base[..., :hidden_size]
# assert x_unfused.is_contiguous() != strided_input
# torch.ops._C.fused_add_rms_norm(x_unfused, residual, weight, 1e-6)
# torch.ops._C.fused_add_rms_norm(x_unfused, residual, weight, 1e-6)
# torch.ops._C.static_scaled_fp8_quant(out_quant, x_unfused,
# torch.ops._C.static_scaled_fp8_quant(out_quant, x_unfused
.contiguous()
,
# quant_scale_t)
# quant_scale_t)
# torch.cuda.synchronize()
# torch.cuda.synchronize()
...
@@ -116,7 +129,6 @@ def test_rms_norm(
...
@@ -116,7 +129,6 @@ def test_rms_norm(
# residual,
# residual,
# atol=1e-2,
# atol=1e-2,
# rtol=1e-2)
# rtol=1e-2)
# opcheck(
# opcheck(
# torch.ops._C.fused_add_rms_norm_static_fp8_quant,
# torch.ops._C.fused_add_rms_norm_static_fp8_quant,
# (out_quant_fused, x, residual_fused, weight, quant_scale_t, 1e-6))
# (out_quant_fused, x, residual_fused, weight, quant_scale_t, 1e-6))
...
@@ -131,7 +143,7 @@ def test_rms_norm(
...
@@ -131,7 +143,7 @@ def test_rms_norm(
# opcheck(torch.ops._C.rms_norm_static_fp8_quant,
# opcheck(torch.ops._C.rms_norm_static_fp8_quant,
# (out_quant_fused, x, weight, quant_scale_t, 1e-6))
# (out_quant_fused, x, weight, quant_scale_t, 1e-6))
# torch.testing.assert_close(out_quant
_fused
.to(dtype=torch.float32),
# torch.testing.assert_close(out_quant.to(dtype=torch.float32),
# out_quant.to(dtype=torch.float32),
# out_quant
_fused
.to(dtype=torch.float32),
# atol=1e-3,
# atol=1e-3,
# rtol=1e-3)
# rtol=1e-3)
tests/kernels/mamba/test_mamba_mixer2.py
View file @
711aa9d5
...
@@ -119,7 +119,8 @@ def mixer2_gated_norm_tensor_parallel(
...
@@ -119,7 +119,8 @@ def mixer2_gated_norm_tensor_parallel(
gate_states
[...,
local_rank
*
N
:(
local_rank
+
1
)
*
N
],
gate_states
[...,
local_rank
*
N
:(
local_rank
+
1
)
*
N
],
)
)
ref_output
=
mixer_single_gpu
(
hidden_states
,
gate_states
)
ref_output
=
mixer_single_gpu
(
hidden_states
,
gate_states
)
torch
.
allclose
(
output
,
torch
.
testing
.
assert_close
(
output
,
ref_output
[...,
local_rank
*
N
:(
local_rank
+
1
)
*
N
],
ref_output
[...,
atol
=
1e-3
,
local_rank
*
N
:(
local_rank
+
1
)
*
N
],
rtol
=
1e-3
)
atol
=
5e-3
,
rtol
=
1e-3
)
tests/kernels/mamba/test_mamba_ssm_ssd.py
View file @
711aa9d5
...
@@ -6,11 +6,11 @@ import torch
...
@@ -6,11 +6,11 @@ import torch
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
from
einops
import
rearrange
,
repeat
from
einops
import
rearrange
,
repeat
from
vllm.model_executor.layers.mamba.mamba2_metadata
import
(
_query_start_loc_to_chunk_indices_offsets
)
from
vllm.model_executor.layers.mamba.ops.ssd_combined
import
(
from
vllm.model_executor.layers.mamba.ops.ssd_combined
import
(
mamba_chunk_scan_combined
)
mamba_chunk_scan_combined
)
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.v1.attention.backends.mamba_attn
import
(
_query_start_loc_to_chunk_indices_offsets
)
# Added by the IBM Team, 2024
# Added by the IBM Team, 2024
...
@@ -193,6 +193,13 @@ def test_mamba_chunk_scan_single_example(d_head, n_heads, seq_len_chunk_size,
...
@@ -193,6 +193,13 @@ def test_mamba_chunk_scan_single_example(d_head, n_heads, seq_len_chunk_size,
# this tests the kernels on a single example (no batching)
# this tests the kernels on a single example (no batching)
# TODO: the bfloat16 case requires higher thresholds. To be investigated
if
itype
==
torch
.
bfloat16
:
atol
,
rtol
=
5e-2
,
5e-2
else
:
atol
,
rtol
=
8e-3
,
5e-3
# set seed
# set seed
batch_size
=
1
# batch_size
batch_size
=
1
# batch_size
# ssd_minimal_discrete requires chunk_size divide seqlen
# ssd_minimal_discrete requires chunk_size divide seqlen
...
@@ -216,14 +223,14 @@ def test_mamba_chunk_scan_single_example(d_head, n_heads, seq_len_chunk_size,
...
@@ -216,14 +223,14 @@ def test_mamba_chunk_scan_single_example(d_head, n_heads, seq_len_chunk_size,
return_final_states
=
True
)
return_final_states
=
True
)
# just test the last in sequence
# just test the last in sequence
torch
.
all
close
(
Y
[:,
-
1
],
Y_min
[:,
-
1
],
atol
=
1e-3
,
rtol
=
1e-3
)
torch
.
testing
.
assert_
close
(
Y
[:,
-
1
],
Y_min
[:,
-
1
],
atol
=
atol
,
rtol
=
rtol
)
# just test the last head
# just test the last head
# NOTE, in the kernel we always cast states to fp32
# NOTE, in the kernel we always cast states to fp32
torch
.
all
close
(
final_state
[:,
-
1
],
torch
.
testing
.
assert_
close
(
final_state
[:,
-
1
],
final_state_min
[:,
-
1
].
to
(
torch
.
float32
),
final_state_min
[:,
-
1
].
to
(
torch
.
float32
),
atol
=
1e-3
,
atol
=
atol
,
rtol
=
1e-3
)
rtol
=
rtol
)
@
pytest
.
mark
.
parametrize
(
"itype"
,
[
torch
.
float32
,
torch
.
float16
])
@
pytest
.
mark
.
parametrize
(
"itype"
,
[
torch
.
float32
,
torch
.
float16
])
...
@@ -263,6 +270,13 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases,
...
@@ -263,6 +270,13 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases,
seqlen
,
chunk_size
,
num_examples
,
cases
=
seq_len_chunk_size_cases
seqlen
,
chunk_size
,
num_examples
,
cases
=
seq_len_chunk_size_cases
# TODO: the irregular chunk size cases have some issues and require higher
# tolerance. This is to be invesigated
if
chunk_size
not
in
{
8
,
256
}:
atol
,
rtol
=
5e-1
,
5e-1
else
:
atol
,
rtol
=
5e-3
,
5e-3
# hold state during the cutting process so we know if an
# hold state during the cutting process so we know if an
# example has been exhausted and needs to cycle
# example has been exhausted and needs to cycle
last_taken
:
dict
=
{}
# map: eg -> pointer to last taken sample
last_taken
:
dict
=
{}
# map: eg -> pointer to last taken sample
...
@@ -300,7 +314,7 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases,
...
@@ -300,7 +314,7 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases,
# just test one dim and dstate
# just test one dim and dstate
Y_eg
=
Y
[
0
,
cu_seqlens
[
i
]:
cu_seqlens
[
i
+
1
],
0
,
0
]
Y_eg
=
Y
[
0
,
cu_seqlens
[
i
]:
cu_seqlens
[
i
+
1
],
0
,
0
]
Y_min_eg
=
Y_min
[
i
][:,
0
,
0
]
Y_min_eg
=
Y_min
[
i
][:,
0
,
0
]
torch
.
all
close
(
Y_eg
,
Y_min_eg
,
atol
=
1e-3
,
rtol
=
1e-3
)
torch
.
testing
.
assert_
close
(
Y_eg
,
Y_min_eg
,
atol
=
atol
,
rtol
=
rtol
)
# update states
# update states
states
=
new_states
states
=
new_states
...
...
tests/kernels/mamba/untest_causal_conv1d.py
View file @
711aa9d5
...
@@ -6,9 +6,8 @@ from typing import Optional
...
@@ -6,9 +6,8 @@ from typing import Optional
import
pytest
import
pytest
import
torch
import
torch
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
from
einops
import
rearrange
from
tests.kernels.utils
import
opcheck
from
vllm
import
_custom_ops
as
ops
# noqa: F401
from
vllm.attention.backends.utils
import
PAD_SLOT_ID
from
vllm.attention.backends.utils
import
PAD_SLOT_ID
from
vllm.model_executor.layers.mamba.ops.causal_conv1d
import
(
from
vllm.model_executor.layers.mamba.ops.causal_conv1d
import
(
causal_conv1d_fn
,
causal_conv1d_update
)
causal_conv1d_fn
,
causal_conv1d_update
)
...
@@ -144,79 +143,6 @@ def causal_conv1d_opcheck_fn(x: torch.Tensor,
...
@@ -144,79 +143,6 @@ def causal_conv1d_opcheck_fn(x: torch.Tensor,
x
=
x
.
contiguous
()
x
=
x
.
contiguous
()
bias
=
bias
.
contiguous
()
if
bias
is
not
None
else
None
bias
=
bias
.
contiguous
()
if
bias
is
not
None
else
None
opcheck
(
torch
.
ops
.
_C
.
causal_conv1d_fwd
,
(
x
,
weight
,
bias
,
conv_states
,
cu_seq_len
,
cache_indices
,
has_initial_state
,
activation
in
[
"silu"
,
"swish"
],
pad_slot_id
))
@
pytest
.
mark
.
parametrize
(
"itype"
,
[
torch
.
bfloat16
,
torch
.
float
])
@
pytest
.
mark
.
parametrize
(
"silu_activation"
,
[
True
])
@
pytest
.
mark
.
parametrize
(
"has_bias"
,
[
True
])
@
pytest
.
mark
.
parametrize
(
"has_initial_state"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"width"
,
[
4
])
@
pytest
.
mark
.
parametrize
(
'seqlen'
,
[
1
,
8
,
16
,
32
,
64
,
128
,
256
,
512
,
784
,
1024
,
1025
,
2048
,
4096
])
@
pytest
.
mark
.
parametrize
(
'dim'
,
[
64
])
@
pytest
.
mark
.
parametrize
(
'batch'
,
[
1
])
def
test_causal_conv1d
(
batch
,
dim
,
seqlen
,
width
,
has_bias
,
silu_activation
,
has_initial_state
,
itype
):
device
=
"cuda"
rtol
,
atol
=
(
3e-4
,
1e-3
)
if
itype
==
torch
.
float32
else
(
3e-3
,
5e-3
)
if
itype
==
torch
.
bfloat16
:
rtol
,
atol
=
1e-2
,
5e-2
# set seed
current_platform
.
seed_everything
(
0
)
x
=
torch
.
randn
(
batch
,
dim
,
seqlen
,
device
=
device
,
dtype
=
itype
).
contiguous
()
weight
=
torch
.
randn
(
dim
,
width
,
device
=
device
,
dtype
=
itype
)
bias
=
torch
.
randn
(
dim
,
device
=
device
,
dtype
=
itype
)
if
has_bias
else
None
if
has_initial_state
:
initial_states
=
torch
.
randn
(
batch
,
dim
,
width
-
1
,
device
=
device
,
dtype
=
itype
)
has_initial_state_tensor
=
torch
.
ones
(
batch
,
dtype
=
torch
.
bool
,
device
=
x
.
device
)
else
:
initial_states
=
None
has_initial_state_tensor
=
None
x_ref
=
x
.
clone
()
weight_ref
=
weight
.
clone
()
bias_ref
=
bias
.
clone
()
if
bias
is
not
None
else
None
initial_states_ref
=
initial_states
.
clone
(
)
if
initial_states
is
not
None
else
None
activation
=
None
if
not
silu_activation
else
"silu"
out
=
causal_conv1d_fn
(
x
,
weight
,
bias
,
activation
=
activation
,
conv_states
=
initial_states
,
has_initial_state
=
has_initial_state_tensor
)
out_ref
,
final_states_ref
=
causal_conv1d_ref
(
x_ref
,
weight_ref
,
bias_ref
,
initial_states
=
initial_states_ref
,
return_final_states
=
True
,
activation
=
activation
)
if
has_initial_state
:
assert
initial_states
is
not
None
and
final_states_ref
is
not
None
assert
torch
.
allclose
(
initial_states
,
final_states_ref
,
rtol
=
rtol
,
atol
=
atol
)
assert
torch
.
allclose
(
out
,
out_ref
,
rtol
=
rtol
,
atol
=
atol
)
causal_conv1d_opcheck_fn
(
x
,
weight
,
bias
,
activation
=
activation
,
conv_states
=
initial_states
,
has_initial_state
=
has_initial_state_tensor
)
@
pytest
.
mark
.
parametrize
(
"itype"
,
[
torch
.
bfloat16
])
@
pytest
.
mark
.
parametrize
(
"itype"
,
[
torch
.
bfloat16
])
@
pytest
.
mark
.
parametrize
(
"silu_activation"
,
[
False
,
True
])
@
pytest
.
mark
.
parametrize
(
"silu_activation"
,
[
False
,
True
])
...
@@ -255,22 +181,19 @@ def test_causal_conv1d_update(dim, width, seqlen, has_bias, silu_activation,
...
@@ -255,22 +181,19 @@ def test_causal_conv1d_update(dim, width, seqlen, has_bias, silu_activation,
assert
torch
.
equal
(
conv_state
,
conv_state_ref
)
assert
torch
.
equal
(
conv_state
,
conv_state_ref
)
assert
torch
.
allclose
(
out
,
out_ref
,
rtol
=
rtol
,
atol
=
atol
)
assert
torch
.
allclose
(
out
,
out_ref
,
rtol
=
rtol
,
atol
=
atol
)
opcheck
(
torch
.
ops
.
_C
.
causal_conv1d_update
,
(
x
,
conv_state
,
weight
,
bias
,
activation
in
[
"silu"
,
"swish"
],
None
,
None
,
PAD_SLOT_ID
))
@
pytest
.
mark
.
parametrize
(
"itype"
,
@
pytest
.
mark
.
parametrize
(
"itype"
,
[
torch
.
float32
,
torch
.
float16
,
torch
.
bfloat16
])
[
torch
.
float32
,
torch
.
float16
,
torch
.
bfloat16
])
@
pytest
.
mark
.
parametrize
(
"silu_activation"
,
[
False
,
True
])
@
pytest
.
mark
.
parametrize
(
"silu_activation"
,
[
False
,
True
])
@
pytest
.
mark
.
parametrize
(
"has_bias"
,
[
False
,
True
])
@
pytest
.
mark
.
parametrize
(
"has_bias"
,
[
False
,
True
])
@
pytest
.
mark
.
parametrize
(
"seqlen"
,
[
1
,
4
,
5
])
@
pytest
.
mark
.
parametrize
(
"seqlen"
,
[
1
,
3
])
@
pytest
.
mark
.
parametrize
(
"width"
,
[
2
,
3
,
4
])
@
pytest
.
mark
.
parametrize
(
"width"
,
[
3
,
4
])
@
pytest
.
mark
.
parametrize
(
"dim"
,
[
2048
,
2048
+
16
,
4096
])
@
pytest
.
mark
.
parametrize
(
"dim"
,
[
2048
+
16
,
4096
])
# tests correctness in case subset of the sequences are padded
# tests correctness in case subset of the sequences are padded
@
pytest
.
mark
.
parametrize
(
"with_padding"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"with_padding"
,
[
True
,
False
])
def
test_causal_conv1d_update_with_batch_gather
(
with_padding
,
dim
,
width
,
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
3
])
seqlen
,
has_bias
,
def
test_causal_conv1d_update_with_batch_gather
(
batch_size
,
with_padding
,
dim
,
width
,
seqlen
,
has_bias
,
silu_activation
,
itype
):
silu_activation
,
itype
):
device
=
"cuda"
device
=
"cuda"
rtol
,
atol
=
(
3e-4
,
1e-3
)
if
itype
==
torch
.
float32
else
(
3e-3
,
5e-3
)
rtol
,
atol
=
(
3e-4
,
1e-3
)
if
itype
==
torch
.
float32
else
(
3e-3
,
5e-3
)
...
@@ -280,12 +203,15 @@ def test_causal_conv1d_update_with_batch_gather(with_padding, dim, width,
...
@@ -280,12 +203,15 @@ def test_causal_conv1d_update_with_batch_gather(with_padding, dim, width,
# set seed
# set seed
current_platform
.
seed_everything
(
0
)
current_platform
.
seed_everything
(
0
)
batch_size
=
3
padding
=
5
if
with_padding
else
0
padding
=
5
if
with_padding
else
0
padded_batch_size
=
batch_size
+
padding
padded_batch_size
=
batch_size
+
padding
# total_entries = number of cache line
total_entries
=
10
*
batch_size
total_entries
=
10
*
batch_size
x
=
torch
.
randn
(
padded_batch_size
,
dim
,
1
,
device
=
device
,
dtype
=
itype
)
# x will be (batch, dim, seqlen) with contiguous along dim-axis
x
=
torch
.
randn
(
padded_batch_size
,
seqlen
,
dim
,
device
=
device
,
dtype
=
itype
).
transpose
(
1
,
2
)
x_ref
=
x
.
clone
()
x_ref
=
x
.
clone
()
conv_state_indices
=
torch
.
randperm
(
total_entries
)[:
batch_size
].
to
(
conv_state_indices
=
torch
.
randperm
(
total_entries
)[:
batch_size
].
to
(
...
@@ -300,17 +226,22 @@ def test_causal_conv1d_update_with_batch_gather(with_padding, dim, width,
...
@@ -300,17 +226,22 @@ def test_causal_conv1d_update_with_batch_gather(with_padding, dim, width,
[
PAD_SLOT_ID
]
*
padding
,
dtype
=
torch
.
int32
,
device
=
device
)
[
PAD_SLOT_ID
]
*
padding
,
dtype
=
torch
.
int32
,
device
=
device
)
],
],
dim
=
0
)
dim
=
0
)
# conv_state will be (cache_lines, dim, state_len)
# with contiguous along dim-axis
conv_state
=
torch
.
randn
(
total_entries
,
conv_state
=
torch
.
randn
(
total_entries
,
dim
,
width
-
1
,
width
-
1
,
dim
,
device
=
device
,
device
=
device
,
dtype
=
itype
)
dtype
=
itype
).
transpose
(
1
,
2
)
conv_state_for_padding_test
=
conv_state
.
clone
()
conv_state_for_padding_test
=
conv_state
.
clone
()
weight
=
torch
.
randn
(
dim
,
width
,
device
=
device
,
dtype
=
itype
)
weight
=
torch
.
randn
(
dim
,
width
,
device
=
device
,
dtype
=
itype
)
bias
=
torch
.
randn
(
dim
,
device
=
device
,
dtype
=
itype
)
if
has_bias
else
None
bias
=
torch
.
randn
(
dim
,
device
=
device
,
dtype
=
itype
)
if
has_bias
else
None
conv_state_ref
=
conv_state
[
conv_state_indices
,
:].
detach
().
clone
()
conv_state_ref
=
conv_state
[
conv_state_indices
,
:].
detach
().
clone
()
activation
=
None
if
not
silu_activation
else
"silu"
activation
=
None
if
not
silu_activation
else
"silu"
out
=
causal_conv1d_update
(
x
,
out
=
causal_conv1d_update
(
x
,
conv_state
,
conv_state
,
weight
,
weight
,
...
@@ -325,26 +256,21 @@ def test_causal_conv1d_update_with_batch_gather(with_padding, dim, width,
...
@@ -325,26 +256,21 @@ def test_causal_conv1d_update_with_batch_gather(with_padding, dim, width,
activation
=
activation
)
activation
=
activation
)
assert
torch
.
equal
(
conv_state
[
conv_state_indices
,
:],
conv_state_ref
)
assert
torch
.
equal
(
conv_state
[
conv_state_indices
,
:],
conv_state_ref
)
assert
torch
.
allclose
(
out
[:
batch_size
],
out_ref
,
rtol
=
rtol
,
atol
=
atol
)
assert
torch
.
equal
(
conv_state
[
unused_states_bool
],
assert
torch
.
equal
(
conv_state
[
unused_states_bool
],
conv_state_for_padding_test
[
unused_states_bool
])
conv_state_for_padding_test
[
unused_states_bool
])
assert
torch
.
allclose
(
out
[:
batch_size
],
out_ref
,
rtol
=
rtol
,
atol
=
atol
)
opcheck
(
torch
.
ops
.
_C
.
causal_conv1d_update
,
(
x
,
conv_state
,
weight
,
bias
,
activation
in
[
"silu"
,
"swish"
],
None
,
padded_state_indices
,
PAD_SLOT_ID
))
@
pytest
.
mark
.
parametrize
(
"itype"
,
[
torch
.
bfloat16
])
@
pytest
.
mark
.
parametrize
(
"itype"
,
[
torch
.
bfloat16
])
@
pytest
.
mark
.
parametrize
(
"silu_activation"
,
[
True
])
@
pytest
.
mark
.
parametrize
(
"silu_activation"
,
[
True
])
@
pytest
.
mark
.
parametrize
(
"has_bias"
,
[
True
])
@
pytest
.
mark
.
parametrize
(
"has_bias"
,
[
True
])
@
pytest
.
mark
.
parametrize
(
"width"
,
[
4
])
@
pytest
.
mark
.
parametrize
(
"width"
,
[
4
])
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
'seqlen'
,
[
8
,
30
,
249
,
2049
,
4096
])
'seqlen'
,
[
8
,
16
,
32
,
64
,
128
,
256
,
512
,
784
,
1024
,
2048
,
2049
,
4096
])
@
pytest
.
mark
.
parametrize
(
'dim'
,
[
64
,
4096
])
@
pytest
.
mark
.
parametrize
(
'dim'
,
[
64
,
4096
])
# tests correctness in case subset of the sequences are padded
@
pytest
.
mark
.
parametrize
(
'with_padding'
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
'with_padding'
,
[
True
,
False
])
def
test_causal_conv1d_varlen
(
with_padding
,
dim
,
seqlen
,
width
,
has_bias
,
@
pytest
.
mark
.
parametrize
(
'batch'
,
[
4
,
10
])
silu_activation
,
itype
):
def
test_causal_conv1d_varlen
(
batch
,
with_padding
,
dim
,
seqlen
,
width
,
has_bias
,
silu_activation
,
itype
):
device
=
"cuda"
device
=
"cuda"
torch
.
cuda
.
empty_cache
()
torch
.
cuda
.
empty_cache
()
rtol
,
atol
=
(
3e-4
,
1e-3
)
if
itype
==
torch
.
float32
else
(
3e-3
,
5e-3
)
rtol
,
atol
=
(
3e-4
,
1e-3
)
if
itype
==
torch
.
float32
else
(
3e-3
,
5e-3
)
...
@@ -353,14 +279,13 @@ def test_causal_conv1d_varlen(with_padding, dim, seqlen, width, has_bias,
...
@@ -353,14 +279,13 @@ def test_causal_conv1d_varlen(with_padding, dim, seqlen, width, has_bias,
# set seed
# set seed
current_platform
.
seed_everything
(
0
)
current_platform
.
seed_everything
(
0
)
seqlens
=
[]
seqlens
=
[]
batch_size
=
4
batch_size
=
batch
if
seqlen
<
10
:
batch_size
=
1
padding
=
3
if
with_padding
else
0
padding
=
3
if
with_padding
else
0
padded_batch_size
=
batch_size
+
padding
padded_batch_size
=
batch_size
+
padding
nsplits
=
padded_batch_size
-
1
nsplits
=
padded_batch_size
-
1
eos_pos
=
torch
.
randperm
(
seqlen
-
1
)[:
nsplits
].
sort
().
values
eos_pos
=
torch
.
randperm
(
seqlen
-
1
)[:
nsplits
].
sort
().
values
seqlens
.
append
(
seqlens
.
append
(
torch
.
diff
(
torch
.
diff
(
torch
.
cat
(
torch
.
cat
(
...
@@ -373,19 +298,22 @@ def test_causal_conv1d_varlen(with_padding, dim, seqlen, width, has_bias,
...
@@ -373,19 +298,22 @@ def test_causal_conv1d_varlen(with_padding, dim, seqlen, width, has_bias,
cumsum
=
torch
.
cumsum
(
torch
.
tensor
(
seqlens
[
0
]),
dim
=
0
).
to
(
torch
.
int32
)
cumsum
=
torch
.
cumsum
(
torch
.
tensor
(
seqlens
[
0
]),
dim
=
0
).
to
(
torch
.
int32
)
cumsum
=
torch
.
concat
([
torch
.
tensor
([
0
],
dtype
=
torch
.
int32
),
cumsum
],
cumsum
=
torch
.
concat
([
torch
.
tensor
([
0
],
dtype
=
torch
.
int32
),
cumsum
],
dim
=
0
)
dim
=
0
)
x
=
torch
.
randn
(
1
,
4096
+
dim
+
64
,
seqlen
,
device
=
device
,
x
=
rearrange
(
dtype
=
itype
)[:,
4096
:
4096
+
dim
,
:]
torch
.
randn
(
1
,
seqlen
,
4096
+
dim
+
64
,
device
=
device
,
dtype
=
itype
),
"b s d -> b d s"
)[:,
4096
:
4096
+
dim
,
:]
weight
=
torch
.
randn
(
dim
,
width
,
device
=
device
,
dtype
=
itype
)
weight
=
torch
.
randn
(
dim
,
width
,
device
=
device
,
dtype
=
itype
)
bias
=
torch
.
randn
(
dim
,
device
=
device
,
dtype
=
itype
)
if
has_bias
else
None
bias
=
torch
.
randn
(
dim
,
device
=
device
,
dtype
=
itype
)
if
has_bias
else
None
x_ref
=
x
.
clone
()
x_ref
=
x
.
clone
()
weight_ref
=
weight
.
clone
()
weight_ref
=
weight
.
clone
()
bias_ref
=
bias
.
clone
()
if
bias
is
not
None
else
None
bias_ref
=
bias
.
clone
()
if
bias
is
not
None
else
None
activation
=
None
if
not
silu_activation
else
"silu"
activation
=
None
if
not
silu_activation
else
"silu"
final_states
=
torch
.
randn
(
total_entries
,
final_states
=
torch
.
randn
(
total_entries
,
dim
,
width
-
1
,
width
-
1
,
dim
,
device
=
x
.
device
,
device
=
x
.
device
,
dtype
=
x
.
dtype
)
dtype
=
x
.
dtype
)
.
transpose
(
1
,
2
)
final_states_ref
=
final_states
.
clone
()
final_states_ref
=
final_states
.
clone
()
has_initial_states
=
torch
.
randint
(
0
,
has_initial_states
=
torch
.
randint
(
0
,
2
,
(
cumsum
.
shape
[
0
]
-
1
,
),
2
,
(
cumsum
.
shape
[
0
]
-
1
,
),
...
@@ -400,10 +328,16 @@ def test_causal_conv1d_varlen(with_padding, dim, seqlen, width, has_bias,
...
@@ -400,10 +328,16 @@ def test_causal_conv1d_varlen(with_padding, dim, seqlen, width, has_bias,
[
PAD_SLOT_ID
]
*
padding
,
dtype
=
torch
.
int32
,
device
=
device
),
[
PAD_SLOT_ID
]
*
padding
,
dtype
=
torch
.
int32
,
device
=
device
),
],
],
dim
=-
1
)
dim
=-
1
)
out
=
causal_conv1d_fn
(
x
.
squeeze
(
0
),
weight
,
bias
=
bias
,
conv_states
=
final_states
,
query_start_loc
=
cumsum
.
cuda
(),
cache_indices
=
padded_state_indices
,
has_initial_state
=
has_initial_states
,
activation
=
activation
,
pad_slot_id
=
PAD_SLOT_ID
)
out
=
causal_conv1d_fn
(
x
.
squeeze
(
0
),
weight
,
bias
,
cumsum
.
cuda
(),
padded_state_indices
,
has_initial_states
,
final_states
,
activation
,
PAD_SLOT_ID
)
out_ref
=
[]
out_ref
=
[]
out_ref_b
=
[]
out_ref_b
=
[]
...
@@ -426,13 +360,9 @@ def test_causal_conv1d_varlen(with_padding, dim, seqlen, width, has_bias,
...
@@ -426,13 +360,9 @@ def test_causal_conv1d_varlen(with_padding, dim, seqlen, width, has_bias,
out_ref
.
append
(
torch
.
cat
([
t
[
0
]
for
t
in
out_ref_b
],
dim
=
2
))
out_ref
.
append
(
torch
.
cat
([
t
[
0
]
for
t
in
out_ref_b
],
dim
=
2
))
out_ref_tensor
=
torch
.
cat
(
out_ref
,
dim
=
0
)
out_ref_tensor
=
torch
.
cat
(
out_ref
,
dim
=
0
)
unpadded_out
=
out
[:,
:
out_ref_tensor
.
shape
[
-
1
]]
assert
torch
.
allclose
(
unpadded_out
,
out_ref_tensor
,
rtol
=
rtol
,
atol
=
atol
)
assert
torch
.
allclose
(
final_states
[
state_indices
],
assert
torch
.
allclose
(
final_states
[
state_indices
],
final_states_ref
[
state_indices
],
final_states_ref
[
state_indices
],
rtol
=
rtol
,
rtol
=
rtol
,
atol
=
atol
)
atol
=
atol
)
unpadded_out
=
out
[:,
:
out_ref_tensor
.
shape
[
-
1
]]
causal_conv1d_opcheck_fn
(
x
.
squeeze
(
0
),
weight
,
bias
,
cumsum
.
cuda
(),
assert
torch
.
allclose
(
unpadded_out
,
out_ref_tensor
,
rtol
=
rtol
,
atol
=
atol
)
padded_state_indices
,
has_initial_states
,
final_states
,
activation
)
\ No newline at end of file
tests/
spec_decode
/__init__.py
→
tests/
kernels/moe/modular_kernel_tools
/__init__.py
View file @
711aa9d5
File moved
tests/kernels/moe/modular_kernel_tools/cli_args.py
0 → 100644
View file @
711aa9d5
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
argparse
import
torch
import
vllm.model_executor.layers.fused_moe.modular_kernel
as
mk
from
vllm.model_executor.layers.fused_moe.config
import
FusedMoEQuantConfig
from
.common
import
Config
from
.mk_objects
import
(
MK_ALL_PREPARE_FINALIZE_TYPES
,
MK_FUSED_EXPERT_TYPES
,
MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES
)
def
make_config_arg_parser
(
description
:
str
):
def
to_pf_class_type
(
s
:
str
)
->
mk
.
FusedMoEPrepareAndFinalize
:
for
pf
in
MK_ALL_PREPARE_FINALIZE_TYPES
:
if
pf
.
__name__
==
s
:
return
pf
raise
ValueError
(
f
"Cannot find a PrepareFinalize type that matches
{
s
}
"
)
def
to_experts_class_type
(
s
:
str
)
->
mk
.
FusedMoEPermuteExpertsUnpermute
:
for
fe
in
MK_FUSED_EXPERT_TYPES
:
if
fe
.
__name__
==
s
:
return
fe
raise
ValueError
(
f
"Cannot find a FusedExperts type that matches
{
s
}
"
)
def
to_quant_torch_dtype
(
s
:
str
)
->
torch
.
dtype
:
if
s
==
"torch.float8_e4m3fn"
:
return
torch
.
float8_e4m3fn
raise
ValueError
(
f
"Unsupported quant type
{
s
}
"
)
parser
=
argparse
.
ArgumentParser
(
description
=
description
)
parser
.
add_argument
(
"--world-size"
,
type
=
int
,
default
=
2
,
help
=
"Number of ranks that participate in all2all"
,
)
parser
.
add_argument
(
"--pf-type"
,
type
=
to_pf_class_type
,
required
=
True
,
help
=
(
"Choose a PrepareFinalize Type : "
f
"
{
[
x
.
__name__
for
x
in
MK_ALL_PREPARE_FINALIZE_TYPES
]
}
"
),
)
parser
.
add_argument
(
"--experts-type"
,
type
=
to_experts_class_type
,
required
=
True
,
help
=
(
f
"Choose a FusedExpert type : "
f
"
{
[
x
.
__name__
for
x
in
MK_FUSED_EXPERT_TYPES
]
}
"
),
)
parser
.
add_argument
(
"-m"
,
nargs
=
"+"
,
type
=
int
,
default
=
[
64
],
help
=
"num tokens per rank"
,
)
parser
.
add_argument
(
"-k"
,
type
=
int
,
default
=
7168
,
help
=
"hidden-size"
,
)
parser
.
add_argument
(
"-n"
,
type
=
int
,
default
=
1024
,
help
=
"N dimension of the first fused-moe matmul"
,
)
parser
.
add_argument
(
"--num-experts"
,
type
=
int
,
default
=
32
,
help
=
"Global num experts"
)
parser
.
add_argument
(
"--topk"
,
nargs
=
"+"
,
type
=
int
,
default
=
[
4
,
1
],
help
=
"num topk"
)
parser
.
add_argument
(
"--fused-moe-chunk-size"
,
type
=
int
,
help
=
"Fused moe chunk size used for the non-batched fused experts impl."
)
# Quant args
parser
.
add_argument
(
"--quant-dtype"
,
type
=
to_quant_torch_dtype
,
help
=
"Quant datatype"
)
parser
.
add_argument
(
"--per-token-quantized-activations"
,
action
=
'store_true'
,
help
=
(
"The input activations must be per-token "
"quantized"
))
parser
.
add_argument
(
"--per-channel-quantized-weights"
,
action
=
"store_true"
,
help
=
"The weights must be per-channel quantized."
)
parser
.
add_argument
(
"--block-shape"
,
nargs
=
"+"
,
type
=
int
,
help
=
"Quantization block shape"
)
# Torch trace profile generation args
parser
.
add_argument
(
"--torch-trace-dir-path"
,
type
=
str
,
default
=
None
,
help
=
"Get torch trace for single execution"
)
return
parser
def
_validate_args
(
args
:
argparse
.
Namespace
):
if
args
.
quant_dtype
is
not
None
:
assert
args
.
quant_dtype
==
torch
.
float8_e4m3fn
if
args
.
block_shape
is
not
None
:
assert
len
(
args
.
block_shape
)
==
2
,
(
f
"block shape must have 2 elements. got
{
args
.
block_shape
}
"
)
if
args
.
experts_type
in
MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES
:
assert
args
.
world_size
==
1
,
(
"Single GPU objects need world size set to 1"
)
if
args
.
torch_trace_dir_path
is
not
None
:
from
pathlib
import
Path
assert
Path
(
args
.
torch_trace_dir_path
).
is_dir
(),
(
f
"Please create
{
args
.
torch_trace_dir_path
}
"
)
def
make_config
(
args
:
argparse
.
Namespace
)
->
Config
:
_validate_args
(
args
)
quant_config
=
None
if
args
.
quant_dtype
is
not
None
:
quant_config
=
FusedMoEQuantConfig
(
quant_dtype
=
args
.
quant_dtype
,
per_act_token_quant
=
args
.
per_token_quantized_activations
,
per_out_ch_quant
=
args
.
per_channel_quantized_weights
,
block_shape
=
args
.
block_shape
)
return
Config
(
Ms
=
args
.
m
,
K
=
args
.
k
,
N
=
args
.
n
,
E
=
args
.
num_experts
,
topks
=
args
.
topk
,
dtype
=
torch
.
bfloat16
,
# hard-code
quant_config
=
quant_config
,
prepare_finalize_type
=
args
.
pf_type
,
fused_experts_type
=
args
.
experts_type
,
fused_moe_chunk_size
=
args
.
fused_moe_chunk_size
,
world_size
=
args
.
world_size
,
torch_trace_dir_path
=
args
.
torch_trace_dir_path
)
tests/kernels/moe/modular_kernel_tools/common.py
0 → 100644
View file @
711aa9d5
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
dataclasses
import
dataclass
from
typing
import
Any
,
Optional
,
Union
import
torch
import
vllm._custom_ops
as
ops
import
vllm.model_executor.layers.fused_moe.modular_kernel
as
mk
from
tests.kernels.utils
import
torch_experts
from
vllm.config
import
VllmConfig
from
vllm.distributed
import
get_dp_group
,
get_tensor_model_parallel_world_size
# Fused experts and PrepareFinalize imports
from
vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe
import
(
BatchedDeepGemmExperts
)
from
vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe
import
(
# noqa: E501
BatchedTritonOrDeepGemmExperts
)
from
vllm.model_executor.layers.fused_moe.config
import
(
FusedMoEConfig
,
FusedMoEParallelConfig
,
FusedMoEQuantConfig
)
from
vllm.model_executor.layers.fused_moe.cutlass_moe
import
CutlassExpertsFp8
from
vllm.model_executor.layers.fused_moe.deep_gemm_moe
import
DeepGemmExperts
from
vllm.model_executor.layers.fused_moe.fused_batched_moe
import
(
BatchedTritonExperts
,
NaiveBatchedExperts
)
from
vllm.model_executor.layers.fused_moe.fused_moe
import
fused_topk
from
vllm.model_executor.layers.fused_moe.layer
import
(
FusedMoEMethodBase
,
TritonExperts
)
from
vllm.model_executor.layers.fused_moe.prepare_finalize
import
(
MoEPrepareAndFinalizeNoEP
)
from
vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe
import
(
TritonOrDeepGemmExperts
)
from
vllm.utils
import
has_deep_ep
,
has_deep_gemm
,
has_pplx
from
.parallel_utils
import
ProcessGroupInfo
from
.utils
import
(
make_block_quant_fp8_weights
,
make_non_quant_weights
,
make_quant_fp8_weights
,
per_token_cast_to_fp8
)
if
has_pplx
():
from
vllm.model_executor.layers.fused_moe.pplx_prepare_finalize
import
(
PplxPrepareAndFinalize
)
if
has_deep_ep
():
from
vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize
import
(
# noqa: E501
DeepEPHTPrepareAndFinalize
)
from
vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize
import
(
# noqa: E501
DeepEPLLPrepareAndFinalize
)
def
_describe_tensor
(
t
:
Optional
[
torch
.
Tensor
],
name
:
str
)
->
str
:
if
t
is
None
:
return
f
"
{
name
}
: None"
else
:
return
f
"
{
name
}
:
{
t
.
shape
}
{
t
.
dtype
}
{
t
.
device
}
"
@
dataclass
class
Config
:
Ms
:
Union
[
list
[
int
],
int
]
K
:
int
N
:
int
E
:
int
topks
:
Union
[
list
[
int
],
int
]
dtype
:
torch
.
dtype
quant_config
:
Optional
[
FusedMoEQuantConfig
]
prepare_finalize_type
:
mk
.
FusedMoEPrepareAndFinalize
fused_experts_type
:
mk
.
FusedMoEPermuteExpertsUnpermute
fused_moe_chunk_size
:
Optional
[
int
]
world_size
:
int
torch_trace_dir_path
:
Optional
[
str
]
=
None
def
describe
(
self
)
->
str
:
s
=
""
s
+=
"== Config:
\n
"
s
+=
f
" world_size=
{
self
.
world_size
}
\n
"
s
+=
f
" PF=
{
self
.
prepare_finalize_type
.
__name__
}
\n
"
s
+=
f
" FE=
{
self
.
fused_experts_type
.
__name__
}
\n
"
s
+=
f
" topk=
{
self
.
topks
}
\n
"
s
+=
f
" dtype=
{
self
.
dtype
}
\n
"
s
+=
f
" fused_moe_chunk_size=
{
self
.
fused_moe_chunk_size
}
\n
"
s
+=
" Quant:
\n
"
s
+=
f
" fused_moe_chunk_size=
{
self
.
fused_moe_chunk_size
}
\n
"
if
self
.
quant_config
is
not
None
:
s
+=
f
" q_dtype=
{
self
.
quant_dtype
}
\n
"
s
+=
f
" q_block_shape=
{
self
.
quant_block_shape
}
\n
"
s
+=
f
" q_per_out_ch_quant=
{
self
.
is_per_out_ch_quant
}
\n
"
s
+=
f
" q_per_act_token=
{
self
.
is_per_act_token_quant
}
\n
"
else
:
s
+=
" quant=None
\n
"
return
s
@
property
def
M
(
self
)
->
int
:
assert
isinstance
(
self
.
Ms
,
int
)
return
self
.
Ms
@
property
def
quant_dtype
(
self
)
->
Optional
[
torch
.
dtype
]:
if
self
.
quant_config
is
None
:
return
None
return
self
.
quant_config
.
quant_dtype
@
property
def
is_per_act_token_quant
(
self
)
->
bool
:
if
self
.
quant_config
is
None
:
return
False
return
self
.
quant_config
.
per_act_token_quant
@
property
def
is_per_tensor_act_quant
(
self
)
->
bool
:
if
self
.
quant_config
is
None
:
return
False
return
(
not
self
.
is_per_act_token_quant
and
self
.
quant_block_shape
is
None
)
@
property
def
is_per_out_ch_quant
(
self
)
->
bool
:
if
self
.
quant_config
is
None
:
return
False
return
self
.
quant_config
.
per_out_ch_quant
@
property
def
quant_block_shape
(
self
)
->
Optional
[
list
[
int
]]:
if
self
.
quant_config
is
None
:
return
None
return
self
.
quant_config
.
block_shape
@
property
def
topk
(
self
)
->
int
:
assert
isinstance
(
self
.
topks
,
int
)
return
self
.
topks
@
property
def
topk_ids_dtype
(
self
)
->
Optional
[
torch
.
dtype
]:
topk_ids_dtype
=
None
if
self
.
prepare_finalize_type
==
PplxPrepareAndFinalize
:
topk_ids_dtype
=
torch
.
uint32
elif
self
.
prepare_finalize_type
in
[
DeepEPHTPrepareAndFinalize
,
DeepEPLLPrepareAndFinalize
]:
topk_ids_dtype
=
torch
.
int64
return
topk_ids_dtype
@
property
def
num_local_experts
(
self
)
->
int
:
return
self
.
E
//
self
.
world_size
def
make_env_data
(
self
)
->
tuple
[
VllmConfig
,
dict
[
Any
,
Any
]]:
"""
make env data for vllm launch.
"""
vllm_config
=
VllmConfig
()
vllm_config
.
parallel_config
.
data_parallel_size
=
self
.
world_size
vllm_config
.
parallel_config
.
enable_expert_parallel
=
True
env_dict
=
{
"VLLM_ALL2ALL_BACKEND"
:
self
.
all2all_backend
(),
"VLLM_USE_DEEP_GEMM"
:
str
(
int
(
self
.
needs_deep_gemm
())),
}
if
self
.
fused_moe_chunk_size
is
not
None
:
env_dict
.
update
(
{
"VLLM_FUSED_MOE_CHUNK_SIZE"
:
str
(
self
.
fused_moe_chunk_size
)})
return
vllm_config
,
env_dict
def
is_fp8_block_quantized
(
self
):
return
(
self
.
quant_dtype
==
torch
.
float8_e4m3fn
and
self
.
quant_block_shape
is
not
None
)
def
is_batched_prepare_finalize
(
self
):
return
self
.
prepare_finalize_type
in
[
PplxPrepareAndFinalize
,
DeepEPLLPrepareAndFinalize
]
def
is_batched_fused_experts
(
self
):
return
self
.
fused_experts_type
in
[
CutlassExpertsFp8
,
BatchedDeepGemmExperts
,
BatchedTritonExperts
,
NaiveBatchedExperts
,
BatchedTritonOrDeepGemmExperts
]
def
is_standard_fused_experts
(
self
):
return
self
.
fused_experts_type
in
[
CutlassExpertsFp8
,
DeepGemmExperts
,
TritonOrDeepGemmExperts
,
TritonExperts
]
def
is_fe_16bit_supported
(
self
):
return
self
.
fused_experts_type
in
[
BatchedTritonExperts
,
BatchedTritonOrDeepGemmExperts
,
NaiveBatchedExperts
,
TritonExperts
]
def
is_fe_fp8_supported
(
self
):
return
self
.
fused_experts_type
in
[
BatchedDeepGemmExperts
,
BatchedTritonExperts
,
BatchedTritonOrDeepGemmExperts
,
CutlassExpertsFp8
,
DeepGemmExperts
,
TritonExperts
,
TritonOrDeepGemmExperts
,
NaiveBatchedExperts
,
]
def
is_fe_block_fp8_supported
(
self
):
return
self
.
fused_experts_type
in
[
BatchedDeepGemmExperts
,
BatchedTritonOrDeepGemmExperts
,
DeepGemmExperts
,
TritonExperts
,
TritonOrDeepGemmExperts
,
BatchedTritonExperts
,
NaiveBatchedExperts
,
]
def
is_fe_supports_chunking
(
self
):
return
self
.
fused_experts_type
in
[
CutlassExpertsFp8
,
DeepGemmExperts
,
TritonOrDeepGemmExperts
,
TritonExperts
]
def
needs_deep_gemm
(
self
):
return
self
.
fused_experts_type
in
[
BatchedDeepGemmExperts
,
DeepGemmExperts
,
]
def
needs_pplx
(
self
):
return
self
.
prepare_finalize_type
in
[
PplxPrepareAndFinalize
]
def
needs_deep_ep
(
self
):
return
self
.
prepare_finalize_type
in
[
DeepEPHTPrepareAndFinalize
,
DeepEPLLPrepareAndFinalize
]
def
all2all_backend
(
self
):
if
self
.
needs_pplx
():
return
"pplx"
if
self
.
prepare_finalize_type
==
DeepEPHTPrepareAndFinalize
:
return
"deepep_high_throughput"
if
self
.
prepare_finalize_type
==
DeepEPLLPrepareAndFinalize
:
return
"deepep_low_latency"
return
"naive"
def
needs_all2all
(
self
):
return
self
.
prepare_finalize_type
in
[
PplxPrepareAndFinalize
,
DeepEPHTPrepareAndFinalize
,
DeepEPLLPrepareAndFinalize
]
def
is_valid
(
self
):
# Check prepare-finalize and fused-experts compatibility
if
self
.
is_batched_prepare_finalize
():
if
not
self
.
is_batched_fused_experts
():
return
False
else
:
if
not
self
.
is_standard_fused_experts
():
return
False
use_chunking
=
self
.
fused_moe_chunk_size
is
not
None
if
use_chunking
and
not
self
.
is_fe_supports_chunking
():
return
False
# Check quantization sanity
if
(
int
(
self
.
is_per_act_token_quant
)
+
int
(
self
.
is_per_tensor_act_quant
)
+
int
(
self
.
quant_block_shape
is
not
None
))
>
1
:
# invalid quant config
return
False
# check bf16 / fp16 support
is_16bit
=
(
self
.
dtype
.
itemsize
==
2
and
self
.
quant_dtype
is
None
)
if
is_16bit
and
not
self
.
is_fe_16bit_supported
():
return
False
# Check fp8 support
is_fp8
=
self
.
quant_dtype
==
torch
.
float8_e4m3fn
if
is_fp8
and
not
self
.
is_fe_fp8_supported
():
return
False
# Check fp8 block quanization support
is_block_quatized
=
self
.
quant_block_shape
is
not
None
if
is_block_quatized
and
not
is_fp8
:
return
False
if
is_block_quatized
and
not
self
.
is_fe_block_fp8_supported
():
return
False
# deep_gemm only works with block-quantized
if
self
.
needs_deep_gemm
()
and
not
is_block_quatized
:
return
False
# Check dependencies
if
self
.
needs_deep_ep
()
and
not
has_deep_ep
():
return
False
if
self
.
needs_deep_gemm
()
and
not
has_deep_gemm
():
return
False
if
self
.
needs_pplx
()
and
not
has_pplx
():
# noqa: SIM103
return
False
return
True
@
dataclass
class
WeightTensors
:
w1
:
torch
.
Tensor
w2
:
torch
.
Tensor
w1_scale
:
Optional
[
torch
.
Tensor
]
w2_scale
:
Optional
[
torch
.
Tensor
]
def
describe
(
self
):
s
=
""
s
+=
"== Weight Tensors:
\n
"
s
+=
f
' -
{
_describe_tensor
(
self
.
w1
,
"w1"
)
}
\n
'
s
+=
f
' -
{
_describe_tensor
(
self
.
w2
,
"w2"
)
}
\n
'
s
+=
f
' -
{
_describe_tensor
(
self
.
w1_scale
,
"w1_scale"
)
}
\n
'
s
+=
f
' -
{
_describe_tensor
(
self
.
w2_scale
,
"w2_scale"
)
}
\n
'
return
s
def
to_current_device
(
self
):
self
.
w1
=
self
.
w1
.
to
(
device
=
torch
.
cuda
.
current_device
())
self
.
w2
=
self
.
w2
.
to
(
device
=
torch
.
cuda
.
current_device
())
is_quantized
=
self
.
w1
.
dtype
==
torch
.
float8_e4m3fn
if
is_quantized
:
assert
self
.
w1_scale
is
not
None
assert
self
.
w2_scale
is
not
None
self
.
w1_scale
=
self
.
w1_scale
.
to
(
device
=
torch
.
cuda
.
current_device
())
self
.
w2_scale
=
self
.
w2_scale
.
to
(
device
=
torch
.
cuda
.
current_device
())
def
slice_weights
(
self
,
rank
:
int
,
num_local_experts
:
int
)
->
"WeightTensors"
:
s
=
rank
*
num_local_experts
e
=
s
+
num_local_experts
w1
=
self
.
w1
[
s
:
e
,
:,
:]
w2
=
self
.
w2
[
s
:
e
,
:,
:]
is_quantized
=
self
.
w1
.
dtype
==
torch
.
float8_e4m3fn
w1_scale
,
w2_scale
=
(
None
,
None
)
if
is_quantized
:
assert
self
.
w1_scale
is
not
None
assert
self
.
w2_scale
is
not
None
w1_scale
=
self
.
w1_scale
[
s
:
e
,
:,
:]
w2_scale
=
self
.
w2_scale
[
s
:
e
,
:,
:]
return
WeightTensors
(
w1
,
w2
,
w1_scale
,
w2_scale
)
@
staticmethod
def
make
(
config
:
Config
)
->
"WeightTensors"
:
if
config
.
quant_dtype
is
None
:
# just make normal dtype weights
w1
,
w2
=
make_non_quant_weights
(
e
=
config
.
E
,
n
=
config
.
N
,
k
=
config
.
K
,
dtype
=
config
.
dtype
)
return
WeightTensors
(
w1
=
w1
,
w2
=
w2
,
w1_scale
=
None
,
w2_scale
=
None
)
assert
config
.
quant_dtype
==
torch
.
float8_e4m3fn
if
not
config
.
is_fp8_block_quantized
():
w1
,
w2
,
w1_scale
,
w2_scale
=
make_quant_fp8_weights
(
e
=
config
.
E
,
n
=
config
.
N
,
k
=
config
.
K
,
per_out_channel_quant
=
config
.
is_per_out_ch_quant
,
)
return
WeightTensors
(
w1
=
w1
,
w2
=
w2
,
w1_scale
=
w1_scale
,
w2_scale
=
w2_scale
)
assert
config
.
quant_block_shape
is
not
None
w1
,
w2
,
w1_scale
,
w2_scale
=
make_block_quant_fp8_weights
(
e
=
config
.
E
,
n
=
config
.
N
,
k
=
config
.
K
,
block_size
=
config
.
quant_block_shape
,
)
return
WeightTensors
(
w1
=
w1
,
w2
=
w2
,
w1_scale
=
w1_scale
,
w2_scale
=
w2_scale
)
@
dataclass
class
RankTensors
:
hidden_states
:
torch
.
Tensor
hidden_states_scale
:
Optional
[
torch
.
Tensor
]
topk_weights
:
torch
.
Tensor
topk_ids
:
torch
.
Tensor
expert_map
:
Optional
[
torch
.
Tensor
]
quant_config
:
Optional
[
FusedMoEQuantConfig
]
def
describe
(
self
):
s
=
""
s
+=
"== Rank Tensors:
\n
"
s
+=
f
' -
{
_describe_tensor
(
self
.
hidden_states
,
"HS"
)
}
\n
'
s
+=
f
' -
{
_describe_tensor
(
self
.
hidden_states_scale
,
"HS_scale"
)
}
\n
'
s
+=
f
' -
{
_describe_tensor
(
self
.
topk_weights
,
"topk_weights"
)
}
\n
'
s
+=
f
' -
{
_describe_tensor
(
self
.
topk_ids
,
"topk_ids"
)
}
\n
'
s
+=
f
' -
{
_describe_tensor
(
self
.
expert_map
,
"expert_map"
)
}
\n
'
return
s
@
staticmethod
def
make_hidden_states
(
config
:
Config
)
->
tuple
[
torch
.
Tensor
,
Optional
[
torch
.
Tensor
]]:
"""
Return hidden_states
"""
m
,
k
,
dtype
=
(
config
.
M
,
config
.
K
,
config
.
dtype
)
a
=
(
torch
.
randn
(
(
m
,
k
),
device
=
torch
.
cuda
.
current_device
(),
dtype
=
dtype
)
/
15.0
)
if
config
.
quant_dtype
is
None
:
return
a
,
None
# We dequant and use that as hidden_states so the tests are stable.
# quantizing and dequantizing yield slightly different results
# depending on the hardware. Here we, quantize and dequantize
# first - so further quantize and dequantize will yield the same
# values.
if
config
.
is_per_tensor_act_quant
:
a_q
,
a_scales
=
ops
.
scaled_fp8_quant
(
a
,
use_per_token_if_dynamic
=
False
)
return
a_q
.
float
().
mul
(
a_scales
).
to
(
dtype
),
a_scales
if
config
.
is_per_act_token_quant
:
a_q
,
a_scales
=
ops
.
scaled_fp8_quant
(
a
,
use_per_token_if_dynamic
=
True
)
return
a_q
.
float
().
mul
(
a_scales
).
to
(
dtype
),
None
assert
config
.
quant_block_shape
is
not
None
block_k
=
config
.
quant_block_shape
[
1
]
a_q
,
a_scales
=
per_token_cast_to_fp8
(
a
,
block_size
=
block_k
)
return
a_q
.
float
().
view
(
(
-
1
,
block_k
)).
mul
(
a_scales
.
view
(
-
1
,
1
)).
view
(
m
,
k
).
to
(
dtype
),
None
@
staticmethod
def
make
(
config
:
Config
,
pgi
:
ProcessGroupInfo
):
dtype
=
config
.
dtype
topk
,
m
,
_
=
(
config
.
topk
,
config
.
M
,
config
.
K
)
hidden_states
,
hidden_states_scale
=
RankTensors
.
make_hidden_states
(
config
)
num_local_experts
,
global_num_experts
=
(
config
.
num_local_experts
,
config
.
E
)
score
=
torch
.
randn
((
m
,
global_num_experts
),
device
=
"cuda"
,
dtype
=
dtype
)
topk_weights
,
topk_ids
,
_
=
fused_topk
(
hidden_states
,
score
,
topk
,
False
)
topk_ids
=
topk_ids
.
to
(
config
.
topk_ids_dtype
)
# distribute topk_ids evenly
for
mi
in
range
(
m
):
topk_ids
[
mi
]
=
torch
.
randperm
(
config
.
E
)[:
topk
]
topk_ids
=
topk_ids
.
to
(
device
=
torch
.
cuda
.
current_device
())
expert_map
=
None
if
config
.
world_size
>
1
:
expert_map
=
torch
.
full
((
global_num_experts
,
),
fill_value
=-
1
,
dtype
=
torch
.
int32
)
s
=
pgi
.
rank
*
num_local_experts
e
=
s
+
num_local_experts
expert_map
[
s
:
e
]
=
torch
.
tensor
(
list
(
range
(
num_local_experts
)))
expert_map
=
expert_map
.
to
(
device
=
torch
.
cuda
.
current_device
(),
dtype
=
torch
.
int32
)
return
RankTensors
(
hidden_states
=
hidden_states
,
hidden_states_scale
=
hidden_states_scale
,
topk_weights
=
topk_weights
,
topk_ids
=
topk_ids
,
expert_map
=
expert_map
,
quant_config
=
config
.
quant_config
,
)
def
reference_moe_impl
(
config
:
Config
,
weights
:
WeightTensors
,
rank_tensors
:
RankTensors
)
->
torch
.
Tensor
:
return
torch_experts
(
a
=
rank_tensors
.
hidden_states
,
w1
=
weights
.
w1
,
w2
=
weights
.
w2
,
topk_weight
=
rank_tensors
.
topk_weights
,
topk_ids
=
rank_tensors
.
topk_ids
,
global_num_experts
=
config
.
E
,
expert_map
=
None
,
w1_scale
=
weights
.
w1_scale
,
w2_scale
=
weights
.
w2_scale
,
a1_scale
=
rank_tensors
.
hidden_states_scale
,
quant_dtype
=
config
.
quant_dtype
,
per_act_token_quant
=
config
.
is_per_act_token_quant
,
block_shape
=
config
.
quant_block_shape
,
apply_router_weights_on_input
=
config
.
topk
==
1
)
def
make_fused_experts
(
config
:
Config
,
moe
:
FusedMoEConfig
,
num_dispatchers
:
int
)
->
mk
.
FusedMoEPermuteExpertsUnpermute
:
use_fp8
=
config
.
quant_dtype
==
torch
.
float8_e4m3fn
batch_kwargs
=
{
"max_num_tokens"
:
moe
.
max_num_tokens
,
"num_dispatchers"
:
num_dispatchers
,
}
quant_kwargs
=
{
"use_fp8_w8a8"
:
use_fp8
,
"use_int8_w8a8"
:
False
,
"use_int8_w8a16"
:
False
,
"use_int4_w4a16"
:
False
,
"block_shape"
:
config
.
quant_block_shape
,
"per_act_token_quant"
:
config
.
is_per_act_token_quant
,
}
deepgemm_kwargs
=
{
"allow_deep_gemm"
:
has_deep_gemm
()}
if
config
.
fused_experts_type
==
BatchedDeepGemmExperts
:
kwargs
=
batch_kwargs
|
{
"block_shape"
:
config
.
quant_block_shape
,
"per_act_token_quant"
:
config
.
is_per_act_token_quant
,
}
print
(
f
"Making BatchedDeepGemmExperts
{
kwargs
}
..."
)
experts
=
BatchedDeepGemmExperts
(
**
kwargs
)
elif
config
.
fused_experts_type
==
BatchedTritonExperts
:
kwargs
=
batch_kwargs
|
quant_kwargs
print
(
f
"Making BatchedTritonExperts
{
kwargs
}
..."
)
experts
=
BatchedTritonExperts
(
**
kwargs
)
elif
config
.
fused_experts_type
==
BatchedTritonOrDeepGemmExperts
:
kwargs
=
batch_kwargs
|
quant_kwargs
|
deepgemm_kwargs
print
(
f
"Making BatchedTritonOrDeepGemmExperts
{
kwargs
}
..."
)
experts
=
BatchedTritonOrDeepGemmExperts
(
**
kwargs
)
elif
config
.
fused_experts_type
==
DeepGemmExperts
:
print
(
"Making DeepGemmExperts () ..."
)
experts
=
DeepGemmExperts
()
elif
config
.
fused_experts_type
==
TritonExperts
:
kwargs
=
quant_kwargs
print
(
f
"Making TritonExperts
{
kwargs
}
..."
)
experts
=
TritonExperts
(
**
kwargs
)
elif
config
.
fused_experts_type
==
TritonOrDeepGemmExperts
:
kwargs
=
quant_kwargs
|
deepgemm_kwargs
print
(
f
"Making TritonOrDeepGemmExperts
{
kwargs
}
..."
)
experts
=
TritonOrDeepGemmExperts
(
**
kwargs
)
elif
config
.
fused_experts_type
==
NaiveBatchedExperts
:
kwargs
=
batch_kwargs
|
quant_kwargs
print
(
f
"Making NaiveBatchedExperts
{
kwargs
}
..."
)
experts
=
NaiveBatchedExperts
(
**
kwargs
)
elif
config
.
fused_experts_type
==
CutlassExpertsFp8
:
use_batched_format
=
config
.
is_batched_prepare_finalize
()
num_experts
=
(
moe
.
num_local_experts
if
use_batched_format
else
moe
.
num_experts
)
kwargs
=
{
"max_experts_per_worker"
:
num_experts
,
"out_dtype"
:
moe
.
in_dtype
,
"per_act_token_quant"
:
config
.
is_per_act_token_quant
,
"per_out_ch_quant"
:
config
.
is_per_out_ch_quant
,
"block_shape"
:
config
.
quant_block_shape
,
"num_dispatchers"
:
num_dispatchers
,
"use_batched_format"
:
use_batched_format
}
print
(
f
"Making CutlassExpertsFp8
{
kwargs
}
..."
)
experts
=
CutlassExpertsFp8
(
**
kwargs
)
return
experts
def
make_modular_kernel
(
config
:
Config
,
vllm_config
:
VllmConfig
)
->
mk
.
FusedMoEModularKernel
:
def
next_power_of_2
(
x
):
import
math
if
x
==
0
:
return
1
return
2
**
math
.
ceil
(
math
.
log2
(
x
))
# make moe config
moe_parallel_config
:
FusedMoEParallelConfig
=
FusedMoEParallelConfig
.
make
(
tp_size_
=
get_tensor_model_parallel_world_size
(),
dp_size_
=
get_dp_group
().
world_size
,
vllm_parallel_config
=
vllm_config
.
parallel_config
,
)
moe
=
FusedMoEConfig
(
num_experts
=
config
.
E
,
experts_per_token
=
config
.
topk
,
hidden_dim
=
config
.
K
,
num_local_experts
=
config
.
num_local_experts
,
moe_parallel_config
=
moe_parallel_config
,
in_dtype
=
config
.
dtype
,
quant_config
=
config
.
quant_config
,
max_num_tokens
=
next_power_of_2
(
config
.
M
),
)
# make modular kernel
prepare_finalize
=
None
if
config
.
needs_all2all
():
prepare_finalize
=
FusedMoEMethodBase
.
maybe_make_prepare_finalize
(
moe
)
assert
prepare_finalize
is
not
None
else
:
prepare_finalize
=
MoEPrepareAndFinalizeNoEP
()
fused_experts
=
make_fused_experts
(
config
,
moe
,
prepare_finalize
.
num_dispatchers
())
modular_kernel
=
mk
.
FusedMoEModularKernel
(
prepare_finalize
=
prepare_finalize
,
fused_experts
=
fused_experts
)
return
modular_kernel
def
run_modular_kernel
(
pgi
:
ProcessGroupInfo
,
vllm_config
:
VllmConfig
,
config
:
Config
,
weights
:
WeightTensors
,
rank_tensors
:
RankTensors
,
)
->
torch
.
Tensor
:
assert
isinstance
(
config
.
Ms
,
int
)
assert
isinstance
(
config
.
topks
,
int
)
# weights for rank
rank_weights
=
weights
.
slice_weights
(
pgi
.
rank
,
config
.
num_local_experts
)
mk
=
make_modular_kernel
(
config
,
vllm_config
)
mk_kwargs
=
{
"hidden_states"
:
rank_tensors
.
hidden_states
.
clone
(
),
# impls might update the tensor in place
"w1"
:
rank_weights
.
w1
,
"w2"
:
rank_weights
.
w2
,
"topk_weights"
:
rank_tensors
.
topk_weights
,
"topk_ids"
:
rank_tensors
.
topk_ids
,
"expert_map"
:
rank_tensors
.
expert_map
,
"w1_scale"
:
rank_weights
.
w1_scale
,
"w2_scale"
:
rank_weights
.
w2_scale
,
"a1_scale"
:
rank_tensors
.
hidden_states_scale
,
"global_num_experts"
:
config
.
E
,
"apply_router_weight_on_input"
:
config
.
topk
==
1
,
}
out
=
mk
.
forward
(
**
mk_kwargs
)
return
out
tests/kernels/moe/modular_kernel_tools/make_feature_matrix.py
0 → 100644
View file @
711aa9d5
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
copy
from
enum
import
Enum
from
itertools
import
product
from
typing
import
Optional
import
torch
from
tqdm
import
tqdm
from
vllm.config
import
VllmConfig
,
set_current_vllm_config
from
vllm.model_executor.layers.fused_moe.config
import
FusedMoEQuantConfig
from
vllm.platforms
import
current_platform
from
.common
import
(
Config
,
RankTensors
,
WeightTensors
,
reference_moe_impl
,
run_modular_kernel
)
from
.mk_objects
import
(
MK_FUSED_EXPERT_TYPES
,
MK_MULTI_GPU_PREPARE_FINALIZE_TYPES
,
MK_QUANT_CONFIGS
)
from
.parallel_utils
import
ProcessGroupInfo
,
parallel_launch_with_config
class
Result
(
Enum
):
PASS
=
1
FAIL
=
2
SKIP
=
3
def
rank_worker
(
pgi
:
ProcessGroupInfo
,
vllm_config
:
VllmConfig
,
cpu_group
,
config
:
Config
,
weights
:
WeightTensors
,
):
current_platform
.
seed_everything
(
pgi
.
rank
)
# sanity check
from
vllm
import
envs
if
config
.
fused_moe_chunk_size
is
not
None
:
assert
(
config
.
fused_moe_chunk_size
==
envs
.
VLLM_FUSED_MOE_CHUNK_SIZE
)
# get weights to this device
weights
.
to_current_device
()
Ms
=
config
.
Ms
assert
isinstance
(
Ms
,
list
)
TOPKs
=
config
.
topks
assert
isinstance
(
TOPKs
,
list
)
for
m
,
topk
in
product
(
Ms
,
TOPKs
):
print
(
f
"Running m=
{
m
}
, topk=
{
topk
}
..."
)
# override m and topk
cfgx
=
copy
.
deepcopy
(
config
)
cfgx
.
Ms
=
m
cfgx
.
topks
=
topk
# inputs for rank
rank_tensors
=
RankTensors
.
make
(
cfgx
,
pgi
)
# modular kernel out
mk_out
=
run_modular_kernel
(
pgi
,
vllm_config
,
cfgx
,
weights
,
rank_tensors
)
with
set_current_vllm_config
(
vllm_config
):
ref_out
=
reference_moe_impl
(
cfgx
,
weights
,
rank_tensors
)
torch
.
testing
.
assert_close
(
ref_out
,
mk_out
,
atol
=
3e-2
,
rtol
=
3e-2
)
def
make_feature_matrix
(
csv_file_path
:
str
):
from
dataclasses
import
asdict
import
pandas
as
pd
def
add_to_results
(
config
:
Config
,
success
:
Result
,
results_df
:
Optional
[
pd
.
DataFrame
]
=
None
):
config_dict
=
asdict
(
config
)
config_dict
[
'prepare_finalize_type'
]
=
config_dict
[
'prepare_finalize_type'
].
__name__
config_dict
[
'fused_experts_type'
]
=
config_dict
[
'fused_experts_type'
].
__name__
config_dict
[
'per_tensor_act_quant'
]
=
config
.
is_per_tensor_act_quant
quant_config_dict
=
config_dict
[
'quant_config'
]
del
config_dict
[
'quant_config'
]
if
quant_config_dict
is
None
:
quant_config
=
FusedMoEQuantConfig
(
None
)
quant_config_dict
=
asdict
(
quant_config
)
config_dict
|=
quant_config_dict
result_dict
=
config_dict
|
{
'success'
:
success
.
name
}
result_df
=
pd
.
DataFrame
([
result_dict
])
if
results_df
is
None
:
results_df
=
result_df
else
:
results_df
=
pd
.
concat
([
results_df
,
result_df
],
ignore_index
=
True
)
return
results_df
Ms
=
[
64
]
Ks
=
[
7168
]
# hidden sizes
Ns
=
[
2048
]
TOPKs
=
[[
4
,
1
]]
Es
=
[
32
]
DTYPEs
=
[
torch
.
bfloat16
]
PF_TYPES
=
MK_MULTI_GPU_PREPARE_FINALIZE_TYPES
FE_TYPES
=
MK_FUSED_EXPERT_TYPES
Q_TYPES
=
MK_QUANT_CONFIGS
combinations
=
list
(
product
(
Ms
,
Ks
,
Ns
,
Es
,
TOPKs
,
DTYPEs
,
PF_TYPES
,
FE_TYPES
,
Q_TYPES
))
results_df
:
Optional
[
pd
.
DataFrame
]
=
None
for
m
,
k
,
n
,
e
,
topks
,
dtype
,
pf_type
,
experts_type
,
quant_config
in
tqdm
(
combinations
):
#noqa: E501
config
=
Config
(
Ms
=
[
m
],
K
=
k
,
N
=
n
,
E
=
e
,
topks
=
topks
,
dtype
=
dtype
,
prepare_finalize_type
=
pf_type
,
fused_experts_type
=
experts_type
,
quant_config
=
quant_config
,
world_size
=
2
,
fused_moe_chunk_size
=
None
)
success
=
None
if
config
.
is_valid
():
print
(
f
"Running config :
{
config
.
describe
()
}
..."
)
try
:
weights
:
WeightTensors
=
WeightTensors
.
make
(
config
)
vllm_config
,
env_dict
=
config
.
make_env_data
()
parallel_launch_with_config
(
config
.
world_size
,
rank_worker
,
vllm_config
,
env_dict
,
config
,
weights
)
success
=
Result
.
PASS
except
Exception
as
_
:
success
=
Result
.
FAIL
else
:
success
=
Result
.
SKIP
results_df
=
add_to_results
(
config
,
success
,
results_df
)
if
results_df
is
not
None
:
results_df
.
to_csv
(
f
"
{
csv_file_path
}
"
)
if
__name__
==
'__main__'
:
import
argparse
from
pathlib
import
Path
parser
=
argparse
.
ArgumentParser
(
description
=
(
"Make ModularKernel feature matrix
\n
"
"Example : python3 -m tests.kernels.moe.modular_kernel_tools.make_feature_matrix "
#noqa: E501
"-f ./feature_matrices/feature_matrix.csv"
))
parser
.
add_argument
(
"-f"
,
"--feature-matrix-csv-file-path"
,
type
=
str
,
required
=
True
,
help
=
"File name to Generate a .csv file"
)
args
=
parser
.
parse_args
()
csv_path
=
args
.
feature_matrix_csv_file_path
assert
csv_path
.
endswith
(
'csv'
),
f
"Need a file path ending with .csv, got
{
csv_path
}
"
assert
Path
(
csv_path
).
parent
.
is_dir
(
),
f
"Cannot find parent directory for
{
Path
(
csv_path
).
parent
}
"
make_feature_matrix
(
args
.
feature_matrix_csv_file_path
)
tests/kernels/moe/modular_kernel_tools/mk_objects.py
0 → 100644
View file @
711aa9d5
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
torch
# Fused experts and PrepareFinalize imports
from
vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe
import
(
BatchedDeepGemmExperts
)
from
vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe
import
(
# noqa: E501
BatchedTritonOrDeepGemmExperts
)
from
vllm.model_executor.layers.fused_moe.config
import
FusedMoEQuantConfig
from
vllm.model_executor.layers.fused_moe.cutlass_moe
import
CutlassExpertsFp8
from
vllm.model_executor.layers.fused_moe.deep_gemm_moe
import
DeepGemmExperts
from
vllm.model_executor.layers.fused_moe.fused_batched_moe
import
(
BatchedTritonExperts
,
NaiveBatchedExperts
)
from
vllm.model_executor.layers.fused_moe.layer
import
TritonExperts
from
vllm.model_executor.layers.fused_moe.prepare_finalize
import
(
MoEPrepareAndFinalizeNoEP
)
from
vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe
import
(
TritonOrDeepGemmExperts
)
from
vllm.utils
import
has_deep_ep
,
has_pplx
if
has_deep_ep
():
from
vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize
import
(
# noqa: E501
DeepEPHTPrepareAndFinalize
)
from
vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize
import
(
# noqa: E501
DeepEPLLPrepareAndFinalize
)
if
has_pplx
():
from
vllm.model_executor.layers.fused_moe.pplx_prepare_finalize
import
(
PplxPrepareAndFinalize
)
MK_MULTI_GPU_PREPARE_FINALIZE_TYPES
=
[]
if
has_pplx
():
MK_MULTI_GPU_PREPARE_FINALIZE_TYPES
+=
[
PplxPrepareAndFinalize
]
if
has_deep_ep
():
MK_MULTI_GPU_PREPARE_FINALIZE_TYPES
+=
[
DeepEPHTPrepareAndFinalize
,
DeepEPLLPrepareAndFinalize
]
MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES
=
[
MoEPrepareAndFinalizeNoEP
]
MK_ALL_PREPARE_FINALIZE_TYPES
=
(
MK_MULTI_GPU_PREPARE_FINALIZE_TYPES
+
MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES
)
MK_FUSED_EXPERT_TYPES
=
[
BatchedDeepGemmExperts
,
BatchedTritonExperts
,
NaiveBatchedExperts
,
BatchedTritonOrDeepGemmExperts
,
CutlassExpertsFp8
,
DeepGemmExperts
,
TritonOrDeepGemmExperts
,
TritonExperts
,
]
MK_QUANT_CONFIGS
=
[
None
,
# per-channel / per-column weights and per-tensor activations
FusedMoEQuantConfig
(
quant_dtype
=
torch
.
float8_e4m3fn
,
per_out_ch_quant
=
True
,
per_act_token_quant
=
False
,
block_shape
=
None
),
# per-channel / per-column weights and per-token activations
FusedMoEQuantConfig
(
quant_dtype
=
torch
.
float8_e4m3fn
,
per_out_ch_quant
=
True
,
per_act_token_quant
=
True
,
block_shape
=
None
),
# per-tensor weights and per-tensor activations
FusedMoEQuantConfig
(
quant_dtype
=
torch
.
float8_e4m3fn
,
per_out_ch_quant
=
False
,
per_act_token_quant
=
False
,
block_shape
=
None
),
# per-tensor weights and per-token activations
FusedMoEQuantConfig
(
quant_dtype
=
torch
.
float8_e4m3fn
,
per_out_ch_quant
=
False
,
per_act_token_quant
=
True
,
block_shape
=
None
),
# block-quantized weights and 128 block per-token activations
FusedMoEQuantConfig
(
quant_dtype
=
torch
.
float8_e4m3fn
,
per_out_ch_quant
=
False
,
per_act_token_quant
=
False
,
block_shape
=
[
128
,
128
]),
# TODO (varun) : Should we test the following combinations ?
# block-quantized weights and per-token activations
# block-quantized weights and per-tensor activations
]
tests/kernels/moe/modular_kernel_tools/parallel_utils.py
0 → 100644
View file @
711aa9d5
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
dataclasses
import
os
import
traceback
from
typing
import
Any
,
Callable
,
Optional
import
torch
from
torch.multiprocessing
import
(
spawn
)
# pyright: ignore[reportPrivateImportUsage]
from
typing_extensions
import
Concatenate
,
ParamSpec
from
vllm.config
import
VllmConfig
,
set_current_vllm_config
from
vllm.distributed
import
(
init_distributed_environment
,
initialize_model_parallel
)
from
vllm.utils
import
get_open_port
## Parallel Processes Utils
P
=
ParamSpec
(
"P"
)
@
dataclasses
.
dataclass
class
ProcessGroupInfo
:
world_size
:
int
world_local_size
:
int
rank
:
int
node_rank
:
int
local_rank
:
int
device
:
torch
.
device
def
_set_vllm_config
(
vllm_config
:
VllmConfig
,
world_size
:
int
,
rank
:
int
,
local_rank
:
int
):
import
tempfile
temp_file
=
tempfile
.
mkstemp
()[
1
]
set_current_vllm_config
(
vllm_config
)
with
set_current_vllm_config
(
vllm_config
):
init_distributed_environment
(
world_size
=
world_size
,
rank
=
rank
,
distributed_init_method
=
f
"file://
{
temp_file
}
"
,
local_rank
=
local_rank
,
backend
=
"nccl"
,
)
initialize_model_parallel
(
tensor_model_parallel_size
=
vllm_config
.
parallel_config
.
tensor_parallel_size
,
pipeline_model_parallel_size
=
vllm_config
.
parallel_config
.
pipeline_parallel_size
,
)
cpu_group
=
torch
.
distributed
.
new_group
(
list
(
range
(
world_size
)),
backend
=
"gloo"
)
return
cpu_group
def
_worker_parallel_launch
(
local_rank
:
int
,
world_size
:
int
,
world_local_size
:
int
,
node_rank
:
int
,
init_method
:
str
,
worker
:
Callable
[
Concatenate
[
ProcessGroupInfo
,
Optional
[
VllmConfig
],
Any
,
P
],
None
],
vllm_config
:
Optional
[
VllmConfig
],
env_dict
:
Optional
[
dict
],
*
args
:
P
.
args
,
**
kwargs
:
P
.
kwargs
,
)
->
None
:
rank
=
node_rank
*
world_local_size
+
local_rank
torch
.
cuda
.
set_device
(
local_rank
)
device
=
torch
.
device
(
"cuda"
,
local_rank
)
torch
.
distributed
.
init_process_group
(
backend
=
"cpu:gloo,cuda:nccl"
,
init_method
=
init_method
,
rank
=
rank
,
world_size
=
world_size
,
device_id
=
device
,
)
barrier
=
torch
.
tensor
([
rank
],
device
=
device
)
torch
.
distributed
.
all_reduce
(
barrier
)
if
env_dict
is
not
None
:
os
.
environ
.
update
(
env_dict
)
cpu_group
=
None
if
vllm_config
is
not
None
:
cpu_group
=
_set_vllm_config
(
vllm_config
,
world_size
,
rank
,
local_rank
)
try
:
worker
(
ProcessGroupInfo
(
world_size
=
world_size
,
world_local_size
=
world_local_size
,
rank
=
rank
,
node_rank
=
node_rank
,
local_rank
=
local_rank
,
device
=
device
,
),
vllm_config
,
cpu_group
,
*
args
,
**
kwargs
,
)
except
Exception
as
ex
:
print
(
ex
)
traceback
.
print_exc
()
raise
finally
:
torch
.
distributed
.
destroy_process_group
()
def
parallel_launch_with_config
(
world_size
:
int
,
worker
:
Callable
[
Concatenate
[
ProcessGroupInfo
,
VllmConfig
,
Any
,
P
],
None
],
vllm_config
:
VllmConfig
,
env_dict
:
dict
[
Any
,
Any
],
*
args
:
P
.
args
,
**
kwargs
:
P
.
kwargs
,
)
->
None
:
assert
not
kwargs
spawn
(
_worker_parallel_launch
,
args
=
(
world_size
,
world_size
,
0
,
f
"tcp://
{
os
.
getenv
(
'LOCALHOST'
,
'localhost'
)
}
:
{
get_open_port
()
}
"
,
worker
,
vllm_config
,
env_dict
,
)
+
args
,
nprocs
=
world_size
,
join
=
True
,
)
tests/kernels/moe/modular_kernel_tools/profile_modular_kernel.py
0 → 100644
View file @
711aa9d5
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
copy
from
itertools
import
product
from
typing
import
Any
,
Callable
import
torch
from
vllm.config
import
VllmConfig
from
vllm.platforms
import
current_platform
from
.common
import
Config
,
RankTensors
,
WeightTensors
,
make_modular_kernel
from
.parallel_utils
import
ProcessGroupInfo
,
parallel_launch_with_config
def
do_profile
(
fn
:
Callable
,
fn_kwargs
:
dict
[
Any
,
Any
],
pgi
:
ProcessGroupInfo
,
config
:
Config
,
num_warmups
:
int
=
5
):
for
_
in
range
(
num_warmups
):
fn
(
**
fn_kwargs
)
with
torch
.
profiler
.
profile
(
activities
=
[
torch
.
profiler
.
ProfilerActivity
.
CPU
,
torch
.
profiler
.
ProfilerActivity
.
CUDA
,
],
with_stack
=
True
,
record_shapes
=
True
,
)
as
tprof
:
fn
(
**
fn_kwargs
)
torch
.
cuda
.
synchronize
(
torch
.
cuda
.
current_device
())
# TODO (varun): Add a descriptive trace file name
tprof
.
export_chrome_trace
(
f
"
{
config
.
torch_trace_dir_path
}
/m
{
config
.
M
}
_
{
pgi
.
rank
}
_trace.json"
)
def
profile_modular_kernel
(
pgi
:
ProcessGroupInfo
,
vllm_config
:
VllmConfig
,
config
:
Config
,
weights
:
WeightTensors
,
rank_tensors
:
RankTensors
,
)
->
None
:
assert
isinstance
(
config
.
Ms
,
int
)
assert
isinstance
(
config
.
topks
,
int
)
# weights for rank
rank_weights
=
weights
.
slice_weights
(
pgi
.
rank
,
config
.
num_local_experts
)
# make modular kernel
mk
=
make_modular_kernel
(
config
,
vllm_config
)
mk_kwargs
=
{
"hidden_states"
:
rank_tensors
.
hidden_states
,
"w1"
:
rank_weights
.
w1
,
"w2"
:
rank_weights
.
w2
,
"topk_weights"
:
rank_tensors
.
topk_weights
,
"topk_ids"
:
rank_tensors
.
topk_ids
,
"expert_map"
:
rank_tensors
.
expert_map
,
"w1_scale"
:
rank_weights
.
w1_scale
,
"w2_scale"
:
rank_weights
.
w2_scale
,
"a1_scale"
:
rank_tensors
.
hidden_states_scale
,
"global_num_experts"
:
config
.
E
,
"apply_router_weight_on_input"
:
config
.
topk
==
1
,
}
do_profile
(
mk
.
forward
,
mk_kwargs
,
pgi
,
config
)
def
rank_worker
(
pgi
:
ProcessGroupInfo
,
vllm_config
:
VllmConfig
,
cpu_group
,
config
:
Config
,
weights
:
WeightTensors
,
):
current_platform
.
seed_everything
(
pgi
.
rank
)
# sanity check
from
vllm
import
envs
if
config
.
fused_moe_chunk_size
is
not
None
:
assert
(
config
.
fused_moe_chunk_size
==
envs
.
VLLM_FUSED_MOE_CHUNK_SIZE
)
# get weights to this device
weights
.
to_current_device
()
Ms
=
config
.
Ms
assert
isinstance
(
Ms
,
list
)
TOPKs
=
config
.
topks
assert
isinstance
(
TOPKs
,
list
)
for
m
,
topk
in
product
(
Ms
,
TOPKs
):
print
(
f
"Running m=
{
m
}
, topk=
{
topk
}
..."
)
# override m and topk
cfgx
=
copy
.
deepcopy
(
config
)
cfgx
.
Ms
=
m
cfgx
.
topks
=
topk
# inputs for rank
rank_tensors
=
RankTensors
.
make
(
cfgx
,
pgi
)
profile_modular_kernel
(
pgi
,
vllm_config
,
cfgx
,
weights
,
rank_tensors
)
def
run
(
config
:
Config
):
weights
:
WeightTensors
=
WeightTensors
.
make
(
config
)
vllm_config
,
env_dict
=
config
.
make_env_data
()
parallel_launch_with_config
(
config
.
world_size
,
rank_worker
,
vllm_config
,
env_dict
,
config
,
weights
)
if
__name__
==
'__main__'
:
from
.cli_args
import
make_config
,
make_config_arg_parser
parser
=
make_config_arg_parser
(
description
=
(
"Run single prepare-finalize & fused-experts combination test"
"Example : python3 -m tests.kernels.moe.modular_kernel_tools.profile_modular_kernel "
#noqa: E501
"--pf-type PplxPrepareAndFinalize --experts-type BatchedTritonExperts"
))
args
=
parser
.
parse_args
()
assert
args
.
torch_trace_dir_path
is
not
None
,
(
"Please pass in a directory to store torch traces"
)
config
=
make_config
(
args
)
run
(
config
)
Prev
1
…
11
12
13
14
15
16
17
18
19
…
26
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment