Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
31f6b24f
Commit
31f6b24f
authored
Mar 26, 2025
by
zhuwenwen
Browse files
Merge remote-tracking branch 'mirror/v0.8.2' into v0.8.2-ori
parents
89d1dd57
25f560a6
Changes
88
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
780 additions
and
255 deletions
+780
-255
tests/build_cython.py
tests/build_cython.py
+38
-0
tests/compile/test_pass_manager.py
tests/compile/test_pass_manager.py
+46
-16
tests/distributed/test_pipeline_parallel.py
tests/distributed/test_pipeline_parallel.py
+3
-0
tests/entrypoints/openai/test_chat_template.py
tests/entrypoints/openai/test_chat_template.py
+2
-0
tests/entrypoints/openai/test_video.py
tests/entrypoints/openai/test_video.py
+2
-2
tests/entrypoints/test_chat_utils.py
tests/entrypoints/test_chat_utils.py
+62
-2
tests/fastsafetensors_loader/__init__.py
tests/fastsafetensors_loader/__init__.py
+0
-0
tests/fastsafetensors_loader/test_fastsafetensors_loader.py
tests/fastsafetensors_loader/test_fastsafetensors_loader.py
+22
-0
tests/fastsafetensors_loader/test_weight_utils.py
tests/fastsafetensors_loader/test_weight_utils.py
+46
-0
tests/kernels/test_marlin_gemm.py
tests/kernels/test_marlin_gemm.py
+45
-0
tests/kernels/test_moe.py
tests/kernels/test_moe.py
+35
-10
tests/tool_use/utils.py
tests/tool_use/utils.py
+4
-1
tests/tpu/test_compilation.py
tests/tpu/test_compilation.py
+90
-86
tests/v1/engine/test_output_processor.py
tests/v1/engine/test_output_processor.py
+88
-1
tests/v1/entrypoints/llm/test_struct_output_generate.py
tests/v1/entrypoints/llm/test_struct_output_generate.py
+102
-67
tests/v1/sample/test_rejection_sampler.py
tests/v1/sample/test_rejection_sampler.py
+148
-2
vllm/attention/backends/flash_attn.py
vllm/attention/backends/flash_attn.py
+12
-4
vllm/attention/backends/mla/common.py
vllm/attention/backends/mla/common.py
+1
-1
vllm/compilation/inductor_pass.py
vllm/compilation/inductor_pass.py
+25
-28
vllm/compilation/pass_manager.py
vllm/compilation/pass_manager.py
+9
-35
No files found.
tests/build_cython.py
0 → 100644
View file @
31f6b24f
# SPDX-License-Identifier: Apache-2.0
import
Cython.Compiler.Options
from
Cython.Build
import
cythonize
from
setuptools
import
setup
Cython
.
Compiler
.
Options
.
annotate
=
True
infiles
=
[]
infiles
+=
[
"vllm/engine/llm_engine.py"
,
"vllm/transformers_utils/detokenizer.py"
,
"vllm/engine/output_processor/single_step.py"
,
"vllm/outputs.py"
,
"vllm/engine/output_processor/stop_checker.py"
,
]
infiles
+=
[
"vllm/core/scheduler.py"
,
"vllm/sequence.py"
,
"vllm/core/block_manager.py"
,
]
infiles
+=
[
"vllm/model_executor/layers/sampler.py"
,
"vllm/sampling_params.py"
,
"vllm/utils.py"
,
]
setup
(
ext_modules
=
cythonize
(
infiles
,
annotate
=
False
,
force
=
True
,
compiler_directives
=
{
'language_level'
:
"3"
,
'infer_types'
:
True
}))
# example usage: python3 build_cython.py build_ext --inplace
tests/compile/test_pass_manager.py
View file @
31f6b24f
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
import
copy
import
pickle
import
pytest
import
pytest
import
torch
import
torch
...
@@ -10,32 +9,63 @@ from vllm.compilation.pass_manager import PostGradPassManager
...
@@ -10,32 +9,63 @@ from vllm.compilation.pass_manager import PostGradPassManager
from
vllm.config
import
CompilationConfig
from
vllm.config
import
CompilationConfig
# dummy custom pass that doesn't inherit
def
simple_callable
(
graph
:
torch
.
fx
.
Graph
):
def
simple_callable
(
graph
:
torch
.
fx
.
Graph
):
pass
pass
callable_uuid
=
CallableInductorPass
(
simple_callable
,
# Should fail to add directly to the pass manager
InductorPass
.
hash_source
(
__file__
))
def
test_bad_callable
():
config
=
CompilationConfig
().
pass_config
pass_manager
=
PostGradPassManager
()
pass_manager
.
configure
(
config
)
with
pytest
.
raises
(
AssertionError
):
pass_manager
.
add
(
simple_callable
)
# noqa, type wrong on purpose
# Pass that inherits from InductorPass
class
ProperPass
(
InductorPass
):
def
__call__
(
self
,
graph
:
torch
.
fx
.
graph
.
Graph
)
->
None
:
pass
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
"
works,
callable"
,
"callable"
,
[
[
(
False
,
simple_callable
),
ProperPass
(),
(
True
,
callable_uuid
),
# Can also wrap callables in CallableInductorPass for compliance
(
True
,
CallableInductorPass
(
simple_callable
)),
CallableInductorPass
(
simple_callable
),
CallableInductorPass
(
simple_callable
,
InductorPass
.
hash_source
(
__file__
))
],
],
)
)
def
test_pass_manager
(
works
:
bool
,
callable
):
def
test_pass_manager
_uuid
(
callable
):
config
=
CompilationConfig
().
pass_config
config
=
CompilationConfig
().
pass_config
pass_manager
=
PostGradPassManager
()
pass_manager
=
PostGradPassManager
()
pass_manager
.
configure
(
config
)
pass_manager
.
configure
(
config
)
# Try to add the callable to the pass manager
# Check that UUID is different if the same pass is added 2x
if
works
:
pass_manager
.
add
(
callable
)
pass_manager
.
add
(
callable
)
uuid1
=
pass_manager
.
uuid
()
pickle
.
dumps
(
pass_manager
)
pass_manager
.
add
(
callable
)
else
:
uuid2
=
pass_manager
.
uuid
()
with
pytest
.
raises
(
AssertionError
):
assert
uuid1
!=
uuid2
pass_manager
.
add
(
callable
)
# UUID should be the same as the original one,
# as we constructed in the same way.
pass_manager2
=
PostGradPassManager
()
pass_manager2
.
configure
(
config
)
pass_manager2
.
add
(
callable
)
assert
uuid1
==
pass_manager2
.
uuid
()
# UUID should be different due to config change
config2
=
copy
.
deepcopy
(
config
)
config2
.
enable_fusion
=
not
config2
.
enable_fusion
pass_manager3
=
PostGradPassManager
()
pass_manager3
.
configure
(
config2
)
pass_manager3
.
add
(
callable
)
assert
uuid1
!=
pass_manager3
.
uuid
()
tests/distributed/test_pipeline_parallel.py
View file @
31f6b24f
...
@@ -175,6 +175,8 @@ TEXT_GENERATION_MODELS = {
...
@@ -175,6 +175,8 @@ TEXT_GENERATION_MODELS = {
"inceptionai/jais-13b-chat"
:
PPTestSettings
.
fast
(),
"inceptionai/jais-13b-chat"
:
PPTestSettings
.
fast
(),
"ai21labs/Jamba-tiny-dev"
:
PPTestSettings
.
fast
(),
"ai21labs/Jamba-tiny-dev"
:
PPTestSettings
.
fast
(),
"meta-llama/Llama-3.2-1B-Instruct"
:
PPTestSettings
.
detailed
(),
"meta-llama/Llama-3.2-1B-Instruct"
:
PPTestSettings
.
detailed
(),
# Tests TransformersModel
"ArthurZ/Ilama-3.2-1B"
:
PPTestSettings
.
fast
(),
"openbmb/MiniCPM-2B-sft-bf16"
:
PPTestSettings
.
fast
(),
"openbmb/MiniCPM-2B-sft-bf16"
:
PPTestSettings
.
fast
(),
"openbmb/MiniCPM3-4B"
:
PPTestSettings
.
fast
(),
"openbmb/MiniCPM3-4B"
:
PPTestSettings
.
fast
(),
# Uses Llama
# Uses Llama
...
@@ -243,6 +245,7 @@ TEST_MODELS = [
...
@@ -243,6 +245,7 @@ TEST_MODELS = [
# [LANGUAGE GENERATION]
# [LANGUAGE GENERATION]
"microsoft/Phi-3.5-MoE-instruct"
,
"microsoft/Phi-3.5-MoE-instruct"
,
"meta-llama/Llama-3.2-1B-Instruct"
,
"meta-llama/Llama-3.2-1B-Instruct"
,
# "ArthurZ/Ilama-3.2-1B", NOTE: Uncomment after #13905
"ibm/PowerLM-3b"
,
"ibm/PowerLM-3b"
,
# [LANGUAGE EMBEDDING]
# [LANGUAGE EMBEDDING]
"intfloat/e5-mistral-7b-instruct"
,
"intfloat/e5-mistral-7b-instruct"
,
...
...
tests/entrypoints/openai/test_chat_template.py
View file @
31f6b24f
...
@@ -107,8 +107,10 @@ def test_get_gen_prompt(model, template, add_generation_prompt,
...
@@ -107,8 +107,10 @@ def test_get_gen_prompt(model, template, add_generation_prompt,
# Call the function and get the result
# Call the function and get the result
result
=
apply_hf_chat_template
(
result
=
apply_hf_chat_template
(
tokenizer
,
tokenizer
,
trust_remote_code
=
True
,
conversation
=
mock_request
.
messages
,
conversation
=
mock_request
.
messages
,
chat_template
=
mock_request
.
chat_template
or
template_content
,
chat_template
=
mock_request
.
chat_template
or
template_content
,
tools
=
None
,
add_generation_prompt
=
mock_request
.
add_generation_prompt
,
add_generation_prompt
=
mock_request
.
add_generation_prompt
,
continue_final_message
=
mock_request
.
continue_final_message
,
continue_final_message
=
mock_request
.
continue_final_message
,
)
)
...
...
tests/entrypoints/openai/test_video.py
View file @
31f6b24f
...
@@ -87,7 +87,7 @@ async def test_single_chat_session_video(client: openai.AsyncOpenAI,
...
@@ -87,7 +87,7 @@ async def test_single_chat_session_video(client: openai.AsyncOpenAI,
choice
=
chat_completion
.
choices
[
0
]
choice
=
chat_completion
.
choices
[
0
]
assert
choice
.
finish_reason
==
"length"
assert
choice
.
finish_reason
==
"length"
assert
chat_completion
.
usage
==
openai
.
types
.
CompletionUsage
(
assert
chat_completion
.
usage
==
openai
.
types
.
CompletionUsage
(
completion_tokens
=
10
,
prompt_tokens
=
62
99
,
total_tokens
=
6
309
)
completion_tokens
=
10
,
prompt_tokens
=
62
87
,
total_tokens
=
6
297
)
message
=
choice
.
message
message
=
choice
.
message
message
=
chat_completion
.
choices
[
0
].
message
message
=
chat_completion
.
choices
[
0
].
message
...
@@ -180,7 +180,7 @@ async def test_single_chat_session_video_base64encoded(
...
@@ -180,7 +180,7 @@ async def test_single_chat_session_video_base64encoded(
choice
=
chat_completion
.
choices
[
0
]
choice
=
chat_completion
.
choices
[
0
]
assert
choice
.
finish_reason
==
"length"
assert
choice
.
finish_reason
==
"length"
assert
chat_completion
.
usage
==
openai
.
types
.
CompletionUsage
(
assert
chat_completion
.
usage
==
openai
.
types
.
CompletionUsage
(
completion_tokens
=
10
,
prompt_tokens
=
62
99
,
total_tokens
=
6
309
)
completion_tokens
=
10
,
prompt_tokens
=
62
87
,
total_tokens
=
6
297
)
message
=
choice
.
message
message
=
choice
.
message
message
=
chat_completion
.
choices
[
0
].
message
message
=
chat_completion
.
choices
[
0
].
message
...
...
tests/entrypoints/test_chat_utils.py
View file @
31f6b24f
...
@@ -4,10 +4,13 @@ import warnings
...
@@ -4,10 +4,13 @@ import warnings
from
typing
import
Optional
from
typing
import
Optional
import
pytest
import
pytest
from
packaging.version
import
Version
from
transformers
import
__version__
as
TRANSFORMERS_VERSION
from
vllm.assets.image
import
ImageAsset
from
vllm.assets.image
import
ImageAsset
from
vllm.config
import
ModelConfig
from
vllm.config
import
ModelConfig
from
vllm.entrypoints.chat_utils
import
(
_try_extract_ast
,
load_chat_template
,
from
vllm.entrypoints.chat_utils
import
(
_resolve_hf_chat_template
,
_try_extract_ast
,
load_chat_template
,
parse_chat_messages
,
parse_chat_messages
,
parse_chat_messages_futures
,
parse_chat_messages_futures
,
resolve_chat_template_content_format
)
resolve_chat_template_content_format
)
...
@@ -23,8 +26,10 @@ EXAMPLES_DIR = VLLM_PATH / "examples"
...
@@ -23,8 +26,10 @@ EXAMPLES_DIR = VLLM_PATH / "examples"
PHI3V_MODEL_ID
=
"microsoft/Phi-3.5-vision-instruct"
PHI3V_MODEL_ID
=
"microsoft/Phi-3.5-vision-instruct"
ULTRAVOX_MODEL_ID
=
"fixie-ai/ultravox-v0_5-llama-3_2-1b"
ULTRAVOX_MODEL_ID
=
"fixie-ai/ultravox-v0_5-llama-3_2-1b"
QWEN2VL_MODEL_ID
=
"Qwen/Qwen2-VL-2B-Instruct"
QWEN2VL_MODEL_ID
=
"Qwen/Qwen2-VL-2B-Instruct"
QWEN25VL_MODEL_ID
=
"Qwen/Qwen2.5-VL-3B-Instruct"
MLLAMA_MODEL_ID
=
"meta-llama/Llama-3.2-11B-Vision-Instruct"
MLLAMA_MODEL_ID
=
"meta-llama/Llama-3.2-11B-Vision-Instruct"
LLAMA_GUARD_MODEL_ID
=
"meta-llama/Llama-Guard-3-1B"
LLAMA_GUARD_MODEL_ID
=
"meta-llama/Llama-Guard-3-1B"
HERMES_MODEL_ID
=
"NousResearch/Hermes-3-Llama-3.1-8B"
@
pytest
.
fixture
(
scope
=
"function"
)
@
pytest
.
fixture
(
scope
=
"function"
)
...
@@ -703,25 +708,70 @@ def test_multimodal_image_parsing_matches_hf(model, image_url):
...
@@ -703,25 +708,70 @@ def test_multimodal_image_parsing_matches_hf(model, image_url):
vllm_result
=
apply_hf_chat_template
(
vllm_result
=
apply_hf_chat_template
(
tokenizer
,
tokenizer
,
trust_remote_code
=
model_config
.
trust_remote_code
,
conversation
=
conversation
,
conversation
=
conversation
,
chat_template
=
None
,
chat_template
=
None
,
tools
=
None
,
add_generation_prompt
=
True
,
add_generation_prompt
=
True
,
)
)
assert
hf_result
==
vllm_result
assert
hf_result
==
vllm_result
@
pytest
.
mark
.
parametrize
(
"model"
,
[
QWEN2VL_MODEL_ID
,
# tokenizer.chat_template is of type str
HERMES_MODEL_ID
,
# tokenizer.chat_template is of type dict
])
@
pytest
.
mark
.
parametrize
(
"use_tools"
,
[
True
,
False
])
def
test_resolve_hf_chat_template
(
sample_json_schema
,
model
,
use_tools
):
"""checks that chat_template is a dict type for HF models."""
# Build the tokenizer group and grab the underlying tokenizer
tokenizer_group
=
TokenizerGroup
(
model
,
enable_lora
=
False
,
max_num_seqs
=
5
,
max_input_length
=
None
,
)
tokenizer
=
tokenizer_group
.
tokenizer
tools
=
[{
"type"
:
"function"
,
"function"
:
{
"name"
:
"dummy_function_name"
,
"description"
:
"This is a dummy function"
,
"parameters"
:
sample_json_schema
}
}]
if
use_tools
else
None
# Test detecting the tokenizer's chat_template
chat_template
=
_resolve_hf_chat_template
(
tokenizer
,
chat_template
=
None
,
tools
=
tools
,
trust_remote_code
=
True
,
)
assert
isinstance
(
chat_template
,
str
)
# yapf: disable
# yapf: disable
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
(
"model"
,
"expected_format"
),
(
"model"
,
"expected_format"
),
[(
PHI3V_MODEL_ID
,
"string"
),
[(
PHI3V_MODEL_ID
,
"string"
),
(
QWEN2VL_MODEL_ID
,
"openai"
),
(
QWEN2VL_MODEL_ID
,
"openai"
),
(
QWEN25VL_MODEL_ID
,
"openai"
),
(
ULTRAVOX_MODEL_ID
,
"string"
),
(
ULTRAVOX_MODEL_ID
,
"string"
),
(
MLLAMA_MODEL_ID
,
"openai"
),
(
MLLAMA_MODEL_ID
,
"openai"
),
(
LLAMA_GUARD_MODEL_ID
,
"openai"
)],
(
LLAMA_GUARD_MODEL_ID
,
"openai"
)],
)
)
# yapf: enable
# yapf: enable
def
test_resolve_content_format_hf_defined
(
model
,
expected_format
):
def
test_resolve_content_format_hf_defined
(
model
,
expected_format
):
if
model
==
QWEN25VL_MODEL_ID
and
Version
(
TRANSFORMERS_VERSION
)
<
Version
(
"4.49.0"
):
pytest
.
skip
(
"Qwen2.5-VL requires transformers>=4.49.0"
)
tokenizer_group
=
TokenizerGroup
(
tokenizer_group
=
TokenizerGroup
(
model
,
model
,
enable_lora
=
False
,
enable_lora
=
False
,
...
@@ -730,7 +780,13 @@ def test_resolve_content_format_hf_defined(model, expected_format):
...
@@ -730,7 +780,13 @@ def test_resolve_content_format_hf_defined(model, expected_format):
)
)
tokenizer
=
tokenizer_group
.
tokenizer
tokenizer
=
tokenizer_group
.
tokenizer
chat_template
=
tokenizer
.
chat_template
# Test detecting the tokenizer's chat_template
chat_template
=
_resolve_hf_chat_template
(
tokenizer
,
chat_template
=
None
,
tools
=
None
,
trust_remote_code
=
True
,
)
assert
isinstance
(
chat_template
,
str
)
assert
isinstance
(
chat_template
,
str
)
print
(
"[TEXT]"
)
print
(
"[TEXT]"
)
...
@@ -740,8 +796,10 @@ def test_resolve_content_format_hf_defined(model, expected_format):
...
@@ -740,8 +796,10 @@ def test_resolve_content_format_hf_defined(model, expected_format):
resolved_format
=
resolve_chat_template_content_format
(
resolved_format
=
resolve_chat_template_content_format
(
None
,
# Test detecting the tokenizer's chat_template
None
,
# Test detecting the tokenizer's chat_template
None
,
"auto"
,
"auto"
,
tokenizer
,
tokenizer
,
trust_remote_code
=
True
,
)
)
assert
resolved_format
==
expected_format
assert
resolved_format
==
expected_format
...
@@ -791,8 +849,10 @@ def test_resolve_content_format_examples(template_path, expected_format):
...
@@ -791,8 +849,10 @@ def test_resolve_content_format_examples(template_path, expected_format):
resolved_format
=
resolve_chat_template_content_format
(
resolved_format
=
resolve_chat_template_content_format
(
chat_template
,
chat_template
,
None
,
"auto"
,
"auto"
,
dummy_tokenizer
,
dummy_tokenizer
,
trust_remote_code
=
True
,
)
)
assert
resolved_format
==
expected_format
assert
resolved_format
==
expected_format
tests/fastsafetensors_loader/__init__.py
0 → 100644
View file @
31f6b24f
tests/fastsafetensors_loader/test_fastsafetensors_loader.py
0 → 100644
View file @
31f6b24f
# SPDX-License-Identifier: Apache-2.0
from
vllm
import
SamplingParams
from
vllm.config
import
LoadFormat
test_model
=
"openai-community/gpt2"
prompts
=
[
"Hello, my name is"
,
"The president of the United States is"
,
"The capital of France is"
,
"The future of AI is"
,
]
# Create a sampling params object.
sampling_params
=
SamplingParams
(
temperature
=
0.8
,
top_p
=
0.95
,
seed
=
0
)
def
test_model_loader_download_files
(
vllm_runner
):
with
vllm_runner
(
test_model
,
load_format
=
LoadFormat
.
FASTSAFETENSORS
)
as
llm
:
deserialized_outputs
=
llm
.
generate
(
prompts
,
sampling_params
)
assert
deserialized_outputs
tests/fastsafetensors_loader/test_weight_utils.py
0 → 100644
View file @
31f6b24f
# SPDX-License-Identifier: Apache-2.0
import
glob
import
tempfile
import
huggingface_hub.constants
import
torch
from
vllm.model_executor.model_loader.weight_utils
import
(
download_weights_from_hf
,
fastsafetensors_weights_iterator
,
safetensors_weights_iterator
)
def
test_fastsafetensors_model_loader
():
with
tempfile
.
TemporaryDirectory
()
as
tmpdir
:
huggingface_hub
.
constants
.
HF_HUB_OFFLINE
=
False
download_weights_from_hf
(
"openai-community/gpt2"
,
allow_patterns
=
[
"*.safetensors"
],
cache_dir
=
tmpdir
)
safetensors
=
glob
.
glob
(
f
"
{
tmpdir
}
/**/*.safetensors"
,
recursive
=
True
)
assert
len
(
safetensors
)
>
0
fastsafetensors_tensors
=
{}
hf_safetensors_tensors
=
{}
for
name
,
tensor
in
fastsafetensors_weights_iterator
(
safetensors
,
True
):
fastsafetensors_tensors
[
name
]
=
tensor
for
name
,
tensor
in
safetensors_weights_iterator
(
safetensors
,
True
):
hf_safetensors_tensors
[
name
]
=
tensor
assert
len
(
fastsafetensors_tensors
)
==
len
(
hf_safetensors_tensors
)
for
name
,
fastsafetensors_tensor
in
fastsafetensors_tensors
.
items
():
fastsafetensors_tensor
=
fastsafetensors_tensor
.
to
(
'cpu'
)
assert
fastsafetensors_tensor
.
dtype
==
hf_safetensors_tensors
[
name
].
dtype
assert
fastsafetensors_tensor
.
shape
==
hf_safetensors_tensors
[
name
].
shape
assert
torch
.
all
(
fastsafetensors_tensor
.
eq
(
hf_safetensors_tensors
[
name
]))
if
__name__
==
"__main__"
:
test_fastsafetensors_model_loader
()
tests/kernels/test_marlin_gemm.py
View file @
31f6b24f
...
@@ -606,6 +606,51 @@ def test_marlin_qqq_gemm(
...
@@ -606,6 +606,51 @@ def test_marlin_qqq_gemm(
assert
max_diff
<
0.04
assert
max_diff
<
0.04
def
test_marlin_gemm_subset_input
():
quant_type
=
scalar_types
.
uint4b8
group_size
=
128
size_m
,
size_k
,
size_n
=
32
,
1024
,
2048
big_m
=
size_m
*
2
big_k
=
size_k
*
2
a_input
=
rand_data
((
big_m
,
big_k
))[
8
:
size_m
+
8
,
8
:
size_k
+
8
]
b_weight
=
rand_data
((
size_k
,
size_n
))
w_ref
,
marlin_q_w
,
marlin_s
,
g_idx
,
sort_indices
,
_
=
marlin_quantize
(
b_weight
,
quant_type
,
group_size
,
False
)
marlin_zp
=
marlin_make_empty_g_idx
(
marlin_s
.
device
)
workspace
=
MarlinWorkspace
(
size_n
,
GPTQ_MARLIN_MIN_THREAD_N
,
GPTQ_MARLIN_MAX_PARALLEL
)
output
=
ops
.
gptq_marlin_gemm
(
a_input
,
marlin_q_w
,
marlin_s
,
marlin_zp
,
g_idx
,
sort_indices
,
workspace
.
scratch
,
quant_type
,
a_input
.
shape
[
0
],
b_weight
.
shape
[
1
],
a_input
.
shape
[
1
],
is_k_full
=
True
,
has_zp
=
False
,
use_atomic_add
=
False
,
use_fp32_reduce
=
True
,
is_zp_float
=
False
,
)
output_ref
=
torch
.
matmul
(
a_input
,
w_ref
)
torch
.
cuda
.
synchronize
()
max_diff
=
compute_max_diff
(
output
,
output_ref
)
assert
max_diff
<
0.04
def
test_marlin_gemm_opcheck
():
def
test_marlin_gemm_opcheck
():
size_m
=
2048
size_m
=
2048
size_n
=
4096
size_n
=
4096
...
...
tests/kernels/test_moe.py
View file @
31f6b24f
...
@@ -3,8 +3,11 @@
...
@@ -3,8 +3,11 @@
Run `pytest tests/kernels/test_moe.py`.
Run `pytest tests/kernels/test_moe.py`.
"""
"""
import
pytest
import
pytest
import
torch
import
torch
from
torch.nn
import
Parameter
from
torch.nn
import
functional
as
F
from
transformers
import
MixtralConfig
from
transformers
import
MixtralConfig
from
transformers.models.mixtral.modeling_mixtral
import
MixtralSparseMoeBlock
from
transformers.models.mixtral.modeling_mixtral
import
MixtralSparseMoeBlock
...
@@ -37,6 +40,7 @@ TOP_KS = [2, 6]
...
@@ -37,6 +40,7 @@ TOP_KS = [2, 6]
@
pytest
.
mark
.
parametrize
(
"topk"
,
TOP_KS
)
@
pytest
.
mark
.
parametrize
(
"topk"
,
TOP_KS
)
@
pytest
.
mark
.
parametrize
(
"ep_size"
,
EP_SIZE
)
@
pytest
.
mark
.
parametrize
(
"ep_size"
,
EP_SIZE
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
float16
,
torch
.
bfloat16
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
float16
,
torch
.
bfloat16
])
@
pytest
.
mark
.
parametrize
(
"padding"
,
[
True
,
False
])
def
test_fused_moe
(
def
test_fused_moe
(
m
:
int
,
m
:
int
,
n
:
int
,
n
:
int
,
...
@@ -45,6 +49,7 @@ def test_fused_moe(
...
@@ -45,6 +49,7 @@ def test_fused_moe(
topk
:
int
,
topk
:
int
,
ep_size
:
int
,
ep_size
:
int
,
dtype
:
torch
.
dtype
,
dtype
:
torch
.
dtype
,
padding
:
bool
,
):
):
a
=
torch
.
randn
((
m
,
k
),
device
=
"cuda"
,
dtype
=
dtype
)
/
10
a
=
torch
.
randn
((
m
,
k
),
device
=
"cuda"
,
dtype
=
dtype
)
/
10
w1
=
torch
.
randn
((
e
,
2
*
n
,
k
),
device
=
"cuda"
,
dtype
=
dtype
)
/
10
w1
=
torch
.
randn
((
e
,
2
*
n
,
k
),
device
=
"cuda"
,
dtype
=
dtype
)
/
10
...
@@ -65,16 +70,7 @@ def test_fused_moe(
...
@@ -65,16 +70,7 @@ def test_fused_moe(
else
:
else
:
e_map
=
None
e_map
=
None
triton_output
=
fused_moe
(
a
,
w1
,
w2
,
score
,
topk
,
global_num_experts
=
e
,
expert_map
=
e_map
,
renormalize
=
False
)
torch_output
=
torch_moe
(
a
,
w1
,
w2
,
score
,
topk
,
e_map
)
torch_output
=
torch_moe
(
a
,
w1
,
w2
,
score
,
topk
,
e_map
)
torch
.
testing
.
assert_close
(
triton_output
,
torch_output
,
atol
=
2e-2
,
rtol
=
0
)
iterative_output
=
iterative_moe
(
a
,
iterative_output
=
iterative_moe
(
a
,
w1
,
w1
,
w2
,
w2
,
...
@@ -83,6 +79,23 @@ def test_fused_moe(
...
@@ -83,6 +79,23 @@ def test_fused_moe(
global_num_experts
=
e
,
global_num_experts
=
e
,
expert_map
=
e_map
,
expert_map
=
e_map
,
renormalize
=
False
)
renormalize
=
False
)
# Pad the weight if moe padding is enabled
if
padding
:
w1
=
F
.
pad
(
w1
,
(
0
,
128
),
"constant"
,
0
)[...,
0
:
-
128
]
torch
.
cuda
.
empty_cache
()
w2
=
F
.
pad
(
w2
,
(
0
,
128
),
"constant"
,
0
)[...,
0
:
-
128
]
torch
.
cuda
.
empty_cache
()
triton_output
=
fused_moe
(
a
,
w1
,
w2
,
score
,
topk
,
global_num_experts
=
e
,
expert_map
=
e_map
,
renormalize
=
False
)
torch
.
testing
.
assert_close
(
triton_output
,
torch_output
,
atol
=
2e-2
,
rtol
=
0
)
torch
.
testing
.
assert_close
(
iterative_output
,
torch
.
testing
.
assert_close
(
iterative_output
,
torch_output
,
torch_output
,
atol
=
2e-2
,
atol
=
2e-2
,
...
@@ -202,8 +215,9 @@ def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int,
...
@@ -202,8 +215,9 @@ def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int,
@
pytest
.
mark
.
parametrize
(
"dtype"
,
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
float32
,
torch
.
float16
,
torch
.
bfloat16
])
[
torch
.
float32
,
torch
.
float16
,
torch
.
bfloat16
])
@
pytest
.
mark
.
parametrize
(
"padding"
,
[
True
,
False
])
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
def
test_mixtral_moe
(
dtype
:
torch
.
dtype
):
def
test_mixtral_moe
(
dtype
:
torch
.
dtype
,
padding
:
bool
):
"""Make sure our Mixtral MoE implementation agrees with the one from
"""Make sure our Mixtral MoE implementation agrees with the one from
huggingface."""
huggingface."""
...
@@ -233,6 +247,17 @@ def test_mixtral_moe(dtype: torch.dtype):
...
@@ -233,6 +247,17 @@ def test_mixtral_moe(dtype: torch.dtype):
# vLLM uses 1D query [num_tokens, hidden_dim]
# vLLM uses 1D query [num_tokens, hidden_dim]
vllm_inputs
=
hf_inputs
.
flatten
(
0
,
1
)
vllm_inputs
=
hf_inputs
.
flatten
(
0
,
1
)
# Pad the weight if moe padding is enabled
if
padding
:
vllm_moe
.
experts
.
w13_weight
=
Parameter
(
F
.
pad
(
vllm_moe
.
experts
.
w13_weight
,
(
0
,
128
),
"constant"
,
0
)[...,
0
:
-
128
],
requires_grad
=
False
)
torch
.
cuda
.
empty_cache
()
vllm_moe
.
experts
.
w2_weight
=
Parameter
(
F
.
pad
(
vllm_moe
.
experts
.
w2_weight
,
(
0
,
128
),
"constant"
,
0
)[...,
0
:
-
128
],
requires_grad
=
False
)
torch
.
cuda
.
empty_cache
()
# Run forward passes for both MoE blocks
# Run forward passes for both MoE blocks
hf_states
,
_
=
hf_moe
.
forward
(
hf_inputs
)
hf_states
,
_
=
hf_moe
.
forward
(
hf_inputs
)
vllm_states
=
vllm_moe
.
forward
(
vllm_inputs
)
vllm_states
=
vllm_moe
.
forward
(
vllm_inputs
)
...
...
tests/tool_use/utils.py
View file @
31f6b24f
...
@@ -39,7 +39,10 @@ def ensure_system_prompt(messages: list[dict[str, Any]],
...
@@ -39,7 +39,10 @@ def ensure_system_prompt(messages: list[dict[str, Any]],
# universal args for all models go here. also good if you need to test locally
# universal args for all models go here. also good if you need to test locally
# and change type or KV cache quantization or something.
# and change type or KV cache quantization or something.
ARGS
:
list
[
str
]
=
[
"--enable-auto-tool-choice"
,
"--max-model-len"
,
"1024"
]
ARGS
:
list
[
str
]
=
[
"--enable-auto-tool-choice"
,
"--max-model-len"
,
"1024"
,
"--max-num-seqs"
,
"256"
]
CONFIGS
:
dict
[
str
,
ServerConfig
]
=
{
CONFIGS
:
dict
[
str
,
ServerConfig
]
=
{
"hermes"
:
{
"hermes"
:
{
...
...
tests/tpu/test_compilation.py
View file @
31f6b24f
...
@@ -5,92 +5,96 @@ import os
...
@@ -5,92 +5,96 @@ import os
import
tempfile
import
tempfile
import
depyf
import
depyf
import
pytest
from
vllm.config
import
CompilationLevel
from
vllm.config
import
CompilationLevel
temp_dir
=
tempfile
.
mkdtemp
()
with
depyf
.
prepare_debug
(
temp_dir
):
@
pytest
.
mark
.
skip
(
reason
=
"Not working; needs investigation."
)
from
vllm
import
LLM
,
SamplingParams
def
test_tpu_compilation
():
temp_dir
=
tempfile
.
mkdtemp
()
prompts
=
[
with
depyf
.
prepare_debug
(
temp_dir
):
"A robot may not injure a human being"
,
from
vllm
import
LLM
,
SamplingParams
"It is only with the heart that one can see rightly;"
,
"The greatest glory in living lies not in never falling,"
,
prompts
=
[
]
"A robot may not injure a human being"
,
answers
=
[
"It is only with the heart that one can see rightly;"
,
" or, through inaction, allow a human being to come to harm."
,
"The greatest glory in living lies not in never falling,"
,
" what is essential is invisible to the eye."
,
]
" but in rising every time we fall."
,
answers
=
[
]
" or, through inaction, allow a human being to come to harm."
,
N
=
1
" what is essential is invisible to the eye."
,
# Currently, top-p sampling is disabled. `top_p` should be 1.0.
" but in rising every time we fall."
,
sampling_params
=
SamplingParams
(
temperature
=
0.7
,
]
top_p
=
1.0
,
N
=
1
n
=
N
,
# Currently, top-p sampling is disabled. `top_p` should be 1.0.
max_tokens
=
16
)
sampling_params
=
SamplingParams
(
temperature
=
0.7
,
top_p
=
1.0
,
# Set `enforce_eager=True` to avoid ahead-of-time compilation.
n
=
N
,
# In real workloads, `enforace_eager` should be `False`.
max_tokens
=
16
)
# disable custom dispatcher, let Dynamo takes over
# Set `enforce_eager=True` to avoid ahead-of-time compilation.
# all the control
# In real workloads, `enforace_eager` should be `False`.
llm
=
LLM
(
model
=
"Qwen/Qwen2.5-1.5B-Instruct"
,
max_model_len
=
512
,
# disable custom dispatcher, let Dynamo takes over
max_num_seqs
=
64
,
# all the control
enforce_eager
=
True
,
llm
=
LLM
(
model
=
"Qwen/Qwen2.5-1.5B-Instruct"
,
compilation_config
=
{
"level"
:
CompilationLevel
.
DYNAMO_AS_IS
})
max_model_len
=
512
,
outputs
=
llm
.
generate
(
prompts
,
sampling_params
)
max_num_seqs
=
64
,
for
output
,
answer
in
zip
(
outputs
,
answers
):
enforce_eager
=
True
,
prompt
=
output
.
prompt
compilation_config
=
{
"level"
:
CompilationLevel
.
DYNAMO_AS_IS
})
generated_text
=
output
.
outputs
[
0
].
text
outputs
=
llm
.
generate
(
prompts
,
sampling_params
)
print
(
f
"Prompt:
{
prompt
!
r
}
, Generated text:
{
generated_text
!
r
}
"
)
for
output
,
answer
in
zip
(
outputs
,
answers
):
assert
generated_text
.
startswith
(
answer
)
prompt
=
output
.
prompt
generated_text
=
output
.
outputs
[
0
].
text
compiled_codes
=
sorted
(
print
(
f
"Prompt:
{
prompt
!
r
}
, Generated text:
{
generated_text
!
r
}
"
)
glob
.
glob
(
os
.
path
.
join
(
temp_dir
,
"__transformed_code*.py"
)))
assert
generated_text
.
startswith
(
answer
)
for
i
,
compiled_code
in
enumerate
(
compiled_codes
):
compiled_codes
=
sorted
(
print
(
"{} file: {}"
.
format
(
i
+
1
,
compiled_code
))
glob
.
glob
(
os
.
path
.
join
(
temp_dir
,
"__transformed_code*.py"
)))
# We should only trigger Dynamo compilation 4 times:
for
i
,
compiled_code
in
enumerate
(
compiled_codes
):
# 1. forward pass (symbolic)
print
(
"{} file: {}"
.
format
(
i
+
1
,
compiled_code
))
# 2. compute_logits (symbolic)
# 3. forward pass (shape 16)
# We should only trigger Dynamo compilation 4 times:
# 4. forward pass (shape 32)
# 1. forward pass (symbolic)
# and later calls should not trigger Dynamo compilation again.
# 2. compute_logits (symbolic)
# NOTE: It might still trigger XLA compilation.
# 3. forward pass (shape 16)
# 4. forward pass (shape 32)
# Check we have 4 compiled codes
# and later calls should not trigger Dynamo compilation again.
assert
len
(
compiled_codes
)
==
4
# NOTE: It might still trigger XLA compilation.
kv_cache_prefix
=
"kv_cache"
# Check we have 4 compiled codes
attn_prefix
=
"ragged_paged_attention"
assert
len
(
compiled_codes
)
==
4
# Check all the compilations are as expected
kv_cache_prefix
=
"kv_cache"
compiled_fns
=
sorted
(
attn_prefix
=
"ragged_paged_attention"
glob
.
glob
(
os
.
path
.
join
(
temp_dir
,
"__compiled_fn*Captured*.py"
)))
# Check all the compilations are as expected
for
i
,
compiled_fn
in
enumerate
(
compiled_fns
):
compiled_fns
=
sorted
(
print
(
"{} file: {}"
.
format
(
i
+
1
,
compiled_fn
))
glob
.
glob
(
os
.
path
.
join
(
temp_dir
,
"__compiled_fn*Captured*.py"
)))
# The first compilation is symbolic, so it should not have any kv_caches
for
i
,
compiled_fn
in
enumerate
(
compiled_fns
):
with
open
(
compiled_fns
[
0
])
as
f
:
print
(
"{} file: {}"
.
format
(
i
+
1
,
compiled_fn
))
content
=
f
.
read
()
assert
kv_cache_prefix
not
in
content
# The first compilation is symbolic, so it should not have any kv_caches
with
open
(
compiled_fns
[
0
])
as
f
:
# The second compilation is symbolic, so it should not have any kv_caches
content
=
f
.
read
()
with
open
(
compiled_fns
[
1
])
as
f
:
assert
kv_cache_prefix
not
in
content
content
=
f
.
read
()
assert
kv_cache_prefix
not
in
content
# The second compilation is symbolic, so it should not have any kv_caches
with
open
(
compiled_fns
[
1
])
as
f
:
# The third compilation is shape 16, so it should have kv_caches and the
content
=
f
.
read
()
# ragged_paged_attention
assert
kv_cache_prefix
not
in
content
with
open
(
compiled_fns
[
2
])
as
f
:
content
=
f
.
read
()
# The third compilation is shape 16, so it should have kv_caches and the
assert
(
kv_cache_prefix
in
content
and
attn_prefix
in
content
)
# ragged_paged_attention
with
open
(
compiled_fns
[
2
])
as
f
:
# The forth compilation is shape 32, so it should have kv_caches and the
content
=
f
.
read
()
# ragged_paged_attention
assert
(
kv_cache_prefix
in
content
and
attn_prefix
in
content
)
with
open
(
compiled_fns
[
3
])
as
f
:
content
=
f
.
read
()
# The forth compilation is shape 32, so it should have kv_caches and the
assert
(
kv_cache_prefix
in
content
and
attn_prefix
in
content
)
# ragged_paged_attention
with
open
(
compiled_fns
[
3
])
as
f
:
content
=
f
.
read
()
assert
(
kv_cache_prefix
in
content
and
attn_prefix
in
content
)
tests/v1/engine/test_output_processor.py
View file @
31f6b24f
...
@@ -11,11 +11,13 @@ from tests.v1.engine.utils import (NUM_PROMPT_LOGPROBS_UNDER_TEST,
...
@@ -11,11 +11,13 @@ from tests.v1.engine.utils import (NUM_PROMPT_LOGPROBS_UNDER_TEST,
STOP_STRINGS
,
STOP_STRINGS
,
DummyOutputProcessorTestVectors
,
DummyOutputProcessorTestVectors
,
MockEngineCore
)
MockEngineCore
)
from
vllm.outputs
import
CompletionOutput
,
RequestOutput
from
vllm.sampling_params
import
RequestOutputKind
,
SamplingParams
from
vllm.sampling_params
import
RequestOutputKind
,
SamplingParams
from
vllm.sequence
import
PromptLogprobs
,
SampleLogprobs
from
vllm.sequence
import
PromptLogprobs
,
SampleLogprobs
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
from
vllm.v1.engine
import
EngineCoreRequest
from
vllm.v1.engine
import
EngineCoreRequest
from
vllm.v1.engine.output_processor
import
OutputProcessor
from
vllm.v1.engine.output_processor
import
(
OutputProcessor
,
RequestOutputCollector
)
from
vllm.v1.metrics.stats
import
IterationStats
from
vllm.v1.metrics.stats
import
IterationStats
...
@@ -834,3 +836,88 @@ def test_iteration_stats(dummy_test_vectors):
...
@@ -834,3 +836,88 @@ def test_iteration_stats(dummy_test_vectors):
assert
iteration_stats
.
num_prompt_tokens
==
0
assert
iteration_stats
.
num_prompt_tokens
==
0
assert
iteration_stats
.
num_generation_tokens
==
num_active
assert
iteration_stats
.
num_generation_tokens
==
num_active
@
pytest
.
mark
.
asyncio
async
def
test_request_output_collector
():
NUM_REQS
=
3
TEXT
=
"a"
def
make_outputs
()
->
list
[
RequestOutput
]:
return
[
RequestOutput
(
request_id
=
"my-request-id"
,
prompt
=
None
,
prompt_token_ids
=
[
1
,
2
,
3
],
prompt_logprobs
=
None
,
outputs
=
[
CompletionOutput
(
index
=
0
,
text
=
TEXT
,
token_ids
=
[
idx
],
cumulative_logprob
=
(
idx
+
1
*
1.0
),
logprobs
=
[{
"a"
:
idx
,
"b"
:
idx
}],
finish_reason
=
"length"
if
(
idx
==
NUM_REQS
-
1
)
else
None
,
)
],
finished
=
(
idx
==
NUM_REQS
-
1
),
)
for
idx
in
range
(
NUM_REQS
)
]
collector
=
RequestOutputCollector
(
RequestOutputKind
.
DELTA
)
# CASE 1: Put then get.
outputs
=
make_outputs
()
collector
.
put
(
outputs
[
0
])
output
=
await
collector
.
get
()
assert
not
collector
.
ready
.
is_set
()
assert
collector
.
output
is
None
assert
output
.
outputs
[
0
].
text
==
"a"
assert
output
.
outputs
[
0
].
token_ids
==
[
0
]
# CASE 2: 2 puts then get.
num_to_put
=
2
outputs
=
make_outputs
()
for
i
in
range
(
num_to_put
):
collector
.
put
(
outputs
[
i
])
output
=
await
collector
.
get
()
assert
not
collector
.
ready
.
is_set
()
assert
collector
.
output
is
None
assert
not
output
.
finished
# Text, token_ids, and logprobs should get merged.
assert
output
.
outputs
[
0
].
text
==
TEXT
*
num_to_put
for
tok_0
,
tok_1
in
zip
(
output
.
outputs
[
0
].
token_ids
,
list
(
range
(
num_to_put
))):
assert
tok_0
==
tok_1
assert
len
(
output
.
outputs
[
0
].
logprobs
)
==
num_to_put
# Cumulative logprobs should be the last one.
cumulative_logprob_expected
=
1.0
*
num_to_put
assert
output
.
outputs
[
0
].
cumulative_logprob
==
cumulative_logprob_expected
# CASE 3: Put all 3 (including a finished).
num_to_put
=
3
outputs
=
make_outputs
()
for
i
in
range
(
num_to_put
):
collector
.
put
(
outputs
[
i
])
output
=
await
collector
.
get
()
assert
not
collector
.
ready
.
is_set
()
assert
collector
.
output
is
None
assert
output
.
finished
assert
output
.
outputs
[
0
].
finish_reason
==
"length"
# Text, token_ids, and logprobs should get merged.
assert
output
.
outputs
[
0
].
text
==
TEXT
*
num_to_put
for
tok_0
,
tok_1
in
zip
(
output
.
outputs
[
0
].
token_ids
,
list
(
range
(
num_to_put
))):
assert
tok_0
==
tok_1
assert
len
(
output
.
outputs
[
0
].
logprobs
)
==
num_to_put
# Cumulative logprobs should be the last one.
cumulative_logprob_expected
=
1.0
*
num_to_put
assert
output
.
outputs
[
0
].
cumulative_logprob
==
cumulative_logprob_expected
tests/v1/entrypoints/llm/test_struct_output_generate.py
View file @
31f6b24f
...
@@ -13,7 +13,7 @@ from vllm.entrypoints.llm import LLM
...
@@ -13,7 +13,7 @@ from vllm.entrypoints.llm import LLM
from
vllm.outputs
import
RequestOutput
from
vllm.outputs
import
RequestOutput
from
vllm.sampling_params
import
GuidedDecodingParams
,
SamplingParams
from
vllm.sampling_params
import
GuidedDecodingParams
,
SamplingParams
GUIDED_DECODING_BACKENDS_V1
=
[
"xgrammar"
]
GUIDED_DECODING_BACKENDS_V1
=
[
"xgrammar"
,
"guidance"
]
MODELS_TO_TEST
=
[
MODELS_TO_TEST
=
[
"Qwen/Qwen2.5-1.5B-Instruct"
,
"mistralai/Ministral-8B-Instruct-2410"
"Qwen/Qwen2.5-1.5B-Instruct"
,
"mistralai/Ministral-8B-Instruct-2410"
]
]
...
@@ -30,12 +30,13 @@ def test_guided_json_completion(
...
@@ -30,12 +30,13 @@ def test_guided_json_completion(
model_name
:
str
,
model_name
:
str
,
):
):
monkeypatch
.
setenv
(
"VLLM_USE_V1"
,
"1"
)
monkeypatch
.
setenv
(
"VLLM_USE_V1"
,
"1"
)
llm
=
LLM
(
model
=
model_name
,
max_model_len
=
1024
)
llm
=
LLM
(
model
=
model_name
,
sampling_params
=
SamplingParams
(
temperature
=
1.0
,
max_model_len
=
1024
,
max_tokens
=
1000
,
guided_decoding_backend
=
guided_decoding_backend
)
guided_decoding
=
GuidedDecodingParams
(
sampling_params
=
SamplingParams
(
json
=
sample_json_schema
,
temperature
=
1.0
,
backend
=
guided_decoding_backend
))
max_tokens
=
1000
,
guided_decoding
=
GuidedDecodingParams
(
json
=
sample_json_schema
))
outputs
=
llm
.
generate
(
prompts
=
[
outputs
=
llm
.
generate
(
prompts
=
[
f
"Give an example JSON for an employee profile "
f
"Give an example JSON for an employee profile "
f
"that fits this schema:
{
sample_json_schema
}
"
f
"that fits this schema:
{
sample_json_schema
}
"
...
@@ -111,13 +112,14 @@ def test_guided_json_object(
...
@@ -111,13 +112,14 @@ def test_guided_json_object(
model_name
:
str
,
model_name
:
str
,
):
):
monkeypatch
.
setenv
(
"VLLM_USE_V1"
,
"1"
)
monkeypatch
.
setenv
(
"VLLM_USE_V1"
,
"1"
)
llm
=
LLM
(
model
=
model_name
,
max_model_len
=
1024
)
llm
=
LLM
(
model
=
model_name
,
sampling_params
=
SamplingParams
(
temperature
=
1.0
,
max_model_len
=
1024
,
max_tokens
=
100
,
guided_decoding_backend
=
guided_decoding_backend
)
n
=
2
,
sampling_params
=
SamplingParams
(
guided_decoding
=
GuidedDecodingParams
(
temperature
=
1.0
,
json_object
=
True
,
max_tokens
=
100
,
backend
=
guided_decoding_backend
))
n
=
2
,
guided_decoding
=
GuidedDecodingParams
(
json_object
=
True
))
outputs
=
llm
.
generate
(
outputs
=
llm
.
generate
(
prompts
=
(
"Generate a JSON object with curly braces for a person with "
prompts
=
(
"Generate a JSON object with curly braces for a person with "
...
@@ -137,12 +139,20 @@ def test_guided_json_object(
...
@@ -137,12 +139,20 @@ def test_guided_json_object(
# Parse to verify it is valid JSON
# Parse to verify it is valid JSON
parsed_json
=
json
.
loads
(
generated_text
)
parsed_json
=
json
.
loads
(
generated_text
)
assert
isinstance
(
parsed_json
,
dict
)
allowed_types
:
tuple
[
type
,
...]
=
(
dict
,
)
if
guided_decoding_backend
==
"xgrammar"
:
# TODO - we are currently too permissive with xgrammar and
# allow # any valid json (typically comes back as a list or
# object). We can fix this by specifying a jsonschema of
# {"type": "object"}, # but we need this fix in a release
# first: https://github.com/mlc-ai/xgrammar/pull/264
allowed_types
=
(
dict
,
list
)
assert
isinstance
(
parsed_json
,
allowed_types
)
@
pytest
.
mark
.
skip_global_cleanup
@
pytest
.
mark
.
skip_global_cleanup
@
pytest
.
mark
.
parametrize
(
"guided_decoding_backend"
,
@
pytest
.
mark
.
parametrize
(
"guided_decoding_backend"
,
GUIDED_DECODING_BACKENDS_V1
)
GUIDED_DECODING_BACKENDS_V1
+
[
"auto"
]
)
@
pytest
.
mark
.
parametrize
(
"model_name"
,
MODELS_TO_TEST
)
@
pytest
.
mark
.
parametrize
(
"model_name"
,
MODELS_TO_TEST
)
def
test_guided_json_unsupported_schema
(
def
test_guided_json_unsupported_schema
(
monkeypatch
:
pytest
.
MonkeyPatch
,
monkeypatch
:
pytest
.
MonkeyPatch
,
...
@@ -151,21 +161,43 @@ def test_guided_json_unsupported_schema(
...
@@ -151,21 +161,43 @@ def test_guided_json_unsupported_schema(
model_name
:
str
,
model_name
:
str
,
):
):
monkeypatch
.
setenv
(
"VLLM_USE_V1"
,
"1"
)
monkeypatch
.
setenv
(
"VLLM_USE_V1"
,
"1"
)
llm
=
LLM
(
model
=
model_name
,
max_model_len
=
1024
)
llm
=
LLM
(
model
=
model_name
,
sampling_params
=
SamplingParams
(
temperature
=
1.0
,
max_model_len
=
1024
,
max_tokens
=
1000
,
guided_decoding_backend
=
guided_decoding_backend
)
guided_decoding
=
GuidedDecodingParams
(
sampling_params
=
SamplingParams
(
json
=
unsupported_json_schema
,
temperature
=
1.0
,
backend
=
guided_decoding_backend
))
max_tokens
=
1000
,
with
pytest
.
raises
(
ValueError
,
guided_decoding
=
GuidedDecodingParams
(
json
=
unsupported_json_schema
))
match
=
"The provided JSON schema contains features "
if
guided_decoding_backend
==
"xgrammar"
:
"not supported by xgrammar."
):
with
pytest
.
raises
(
ValueError
,
llm
.
generate
(
prompts
=
[
match
=
"The provided JSON schema contains features "
f
"Give an example JSON for an employee profile "
"not supported by xgrammar."
):
f
"that fits this schema:
{
unsupported_json_schema
}
"
llm
.
generate
(
prompts
=
[
]
*
2
,
f
"Give an example JSON for an employee profile "
sampling_params
=
sampling_params
,
f
"that fits this schema:
{
unsupported_json_schema
}
"
use_tqdm
=
True
)
]
*
2
,
sampling_params
=
sampling_params
,
use_tqdm
=
True
)
else
:
# This should work for both "guidance" and "auto".
outputs
=
llm
.
generate
(
prompts
=
(
"Give an example JSON object for a grade "
"that fits this schema: "
f
"
{
unsupported_json_schema
}
"
),
sampling_params
=
sampling_params
,
use_tqdm
=
True
)
assert
outputs
is
not
None
for
output
in
outputs
:
assert
output
is
not
None
assert
isinstance
(
output
,
RequestOutput
)
generated_text
=
output
.
outputs
[
0
].
text
assert
generated_text
is
not
None
print
(
generated_text
)
# Parse to verify it is valid JSON
parsed_json
=
json
.
loads
(
generated_text
)
assert
isinstance
(
parsed_json
,
dict
)
@
pytest
.
mark
.
skip_global_cleanup
@
pytest
.
mark
.
skip_global_cleanup
...
@@ -179,13 +211,14 @@ def test_guided_grammar_ebnf(
...
@@ -179,13 +211,14 @@ def test_guided_grammar_ebnf(
model_name
:
str
,
model_name
:
str
,
):
):
monkeypatch
.
setenv
(
"VLLM_USE_V1"
,
"1"
)
monkeypatch
.
setenv
(
"VLLM_USE_V1"
,
"1"
)
llm
=
LLM
(
model
=
model_name
,
max_model_len
=
1024
)
llm
=
LLM
(
model
=
model_name
,
sampling_params
=
SamplingParams
(
temperature
=
0.8
,
max_model_len
=
1024
,
top_p
=
0.95
,
guided_decoding_backend
=
guided_decoding_backend
)
max_tokens
=
1000
,
sampling_params
=
SamplingParams
(
guided_decoding
=
GuidedDecodingParams
(
temperature
=
0.8
,
grammar
=
sample_sql_ebnf
,
top_p
=
0.95
,
backend
=
guided_decoding_backend
))
max_tokens
=
1000
,
guided_decoding
=
GuidedDecodingParams
(
grammar
=
sample_sql_ebnf
))
outputs
=
llm
.
generate
(
outputs
=
llm
.
generate
(
prompts
=
(
"Generate a sql statement that selects col_1 from "
prompts
=
(
"Generate a sql statement that selects col_1 from "
"table_1 where it is equal to 1"
),
"table_1 where it is equal to 1"
),
...
@@ -222,13 +255,14 @@ def test_guided_grammar_lark(
...
@@ -222,13 +255,14 @@ def test_guided_grammar_lark(
model_name
:
str
,
model_name
:
str
,
):
):
monkeypatch
.
setenv
(
"VLLM_USE_V1"
,
"1"
)
monkeypatch
.
setenv
(
"VLLM_USE_V1"
,
"1"
)
llm
=
LLM
(
model
=
model_name
,
max_model_len
=
1024
)
llm
=
LLM
(
model
=
model_name
,
sampling_params
=
SamplingParams
(
temperature
=
0.8
,
max_model_len
=
1024
,
top_p
=
0.95
,
guided_decoding_backend
=
guided_decoding_backend
)
max_tokens
=
1000
,
sampling_params
=
SamplingParams
(
guided_decoding
=
GuidedDecodingParams
(
temperature
=
0.8
,
grammar
=
sample_sql_lark
,
top_p
=
0.95
,
backend
=
guided_decoding_backend
))
max_tokens
=
1000
,
guided_decoding
=
GuidedDecodingParams
(
grammar
=
sample_sql_lark
))
outputs
=
llm
.
generate
(
outputs
=
llm
.
generate
(
prompts
=
(
"Generate a sql statement that selects col_1 from "
prompts
=
(
"Generate a sql statement that selects col_1 from "
"table_1 where it is equal to 1"
),
"table_1 where it is equal to 1"
),
...
@@ -269,16 +303,15 @@ def test_guided_grammar_ebnf_invalid(
...
@@ -269,16 +303,15 @@ def test_guided_grammar_ebnf_invalid(
model_name
:
str
,
model_name
:
str
,
):
):
monkeypatch
.
setenv
(
"VLLM_USE_V1"
,
"1"
)
monkeypatch
.
setenv
(
"VLLM_USE_V1"
,
"1"
)
llm
=
LLM
(
model
=
model_name
,
max_model_len
=
1024
)
llm
=
LLM
(
model
=
model_name
,
sampling_params
=
SamplingParams
(
temperature
=
0.8
,
max_model_len
=
1024
,
top_p
=
0.95
,
guided_decoding_backend
=
guided_decoding_backend
)
max_tokens
=
1000
,
sampling_params
=
SamplingParams
(
guided_decoding
=
GuidedDecodingParams
(
temperature
=
0.8
,
grammar
=
"not a grammar"
,
top_p
=
0.95
,
backend
=
guided_decoding_backend
))
max_tokens
=
1000
,
with
pytest
.
raises
(
ValueError
,
guided_decoding
=
GuidedDecodingParams
(
grammar
=
"not a grammar"
))
match
=
"Failed to convert the grammar "
with
pytest
.
raises
(
ValueError
,
match
=
"Failed to convert the grammar "
):
"from Lark to EBNF."
):
llm
.
generate
(
llm
.
generate
(
prompts
=
(
"Generate a sql statement that selects col_1 from "
prompts
=
(
"Generate a sql statement that selects col_1 from "
"table_1 where it is equal to 1"
),
"table_1 where it is equal to 1"
),
...
@@ -298,12 +331,13 @@ def test_guided_regex(
...
@@ -298,12 +331,13 @@ def test_guided_regex(
model_name
:
str
,
model_name
:
str
,
):
):
monkeypatch
.
setenv
(
"VLLM_USE_V1"
,
"1"
)
monkeypatch
.
setenv
(
"VLLM_USE_V1"
,
"1"
)
llm
=
LLM
(
model
=
model_name
,
max_model_len
=
1024
)
llm
=
LLM
(
model
=
model_name
,
sampling_params
=
SamplingParams
(
temperature
=
0.8
,
max_model_len
=
1024
,
top_p
=
0.95
,
guided_decoding_backend
=
guided_decoding_backend
)
guided_decoding
=
GuidedDecodingParams
(
sampling_params
=
SamplingParams
(
regex
=
sample_regex
,
temperature
=
0.8
,
backend
=
guided_decoding_backend
))
top_p
=
0.95
,
guided_decoding
=
GuidedDecodingParams
(
regex
=
sample_regex
))
outputs
=
llm
.
generate
(
outputs
=
llm
.
generate
(
prompts
=
[
prompts
=
[
f
"Give an example IPv4 address with this regex:
{
sample_regex
}
"
f
"Give an example IPv4 address with this regex:
{
sample_regex
}
"
...
@@ -335,12 +369,13 @@ def test_guided_choice_completion(
...
@@ -335,12 +369,13 @@ def test_guided_choice_completion(
model_name
:
str
,
model_name
:
str
,
):
):
monkeypatch
.
setenv
(
"VLLM_USE_V1"
,
"1"
)
monkeypatch
.
setenv
(
"VLLM_USE_V1"
,
"1"
)
llm
=
LLM
(
model
=
model_name
,
max_model_len
=
1024
)
llm
=
LLM
(
model
=
model_name
,
sampling_params
=
SamplingParams
(
temperature
=
0.8
,
max_model_len
=
1024
,
top_p
=
0.95
,
guided_decoding_backend
=
guided_decoding_backend
)
guided_decoding
=
GuidedDecodingParams
(
sampling_params
=
SamplingParams
(
choice
=
sample_guided_choice
,
temperature
=
0.8
,
backend
=
guided_decoding_backend
))
top_p
=
0.95
,
guided_decoding
=
GuidedDecodingParams
(
choice
=
sample_guided_choice
))
outputs
=
llm
.
generate
(
outputs
=
llm
.
generate
(
prompts
=
"The best language for type-safe systems programming is "
,
prompts
=
"The best language for type-safe systems programming is "
,
sampling_params
=
sampling_params
,
sampling_params
=
sampling_params
,
...
...
tests/v1/sample/test_rejection_sampler.py
View file @
31f6b24f
...
@@ -36,6 +36,8 @@ def create_logits_tensor(output_token_ids: list[list[int]],
...
@@ -36,6 +36,8 @@ def create_logits_tensor(output_token_ids: list[list[int]],
def
create_sampling_metadata
(
def
create_sampling_metadata
(
all_greedy
:
bool
,
all_greedy
:
bool
,
temperature
:
Optional
[
torch
.
Tensor
]
=
None
,
temperature
:
Optional
[
torch
.
Tensor
]
=
None
,
top_k
:
Optional
[
torch
.
Tensor
]
=
None
,
top_p
:
Optional
[
torch
.
Tensor
]
=
None
,
generators
:
Optional
[
dict
[
int
,
Any
]]
=
None
,
generators
:
Optional
[
dict
[
int
,
Any
]]
=
None
,
)
->
SamplingMetadata
:
)
->
SamplingMetadata
:
"""Create a v1 sampling metadata object with all_greedy set
"""Create a v1 sampling metadata object with all_greedy set
...
@@ -52,8 +54,8 @@ def create_sampling_metadata(
...
@@ -52,8 +54,8 @@ def create_sampling_metadata(
temperature
=
temperature
,
temperature
=
temperature
,
all_greedy
=
all_greedy
,
all_greedy
=
all_greedy
,
all_random
=
not
all_greedy
,
all_random
=
not
all_greedy
,
top_p
=
None
,
top_p
=
top_p
,
top_k
=
None
,
top_k
=
top_k
,
min_p
=
torch
.
empty
(
1
,
),
min_p
=
torch
.
empty
(
1
,
),
generators
=
generators
,
generators
=
generators
,
max_num_logprobs
=
0
,
max_num_logprobs
=
0
,
...
@@ -462,3 +464,147 @@ def estimate_rejection_sampling_pdf(
...
@@ -462,3 +464,147 @@ def estimate_rejection_sampling_pdf(
density
=
True
)
density
=
True
)
return
hist
.
hist
return
hist
.
hist
def
_test_masked_logits
(
rejection_sampler
,
batch_size
:
int
,
num_draft_tokens
:
int
,
vocab_size
:
int
,
target_logits
:
torch
.
Tensor
,
unmasked_indices
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
):
# Set up test parameters
num_tokens
=
batch_size
*
num_draft_tokens
# Create random draft probabilities.
draft_probs
=
torch
.
rand
((
num_tokens
,
vocab_size
),
dtype
=
torch
.
float32
,
device
=
DEVICE
)
draft_probs
=
F
.
softmax
(
draft_probs
,
dim
=-
1
)
# Randomly sample draft token ids from draft probs
draft_token_ids
=
torch
.
multinomial
(
draft_probs
,
num_samples
=
1
)
draft_token_ids
=
draft_token_ids
.
reshape
(
batch_size
,
num_draft_tokens
)
draft_token_ids
=
draft_token_ids
.
tolist
()
# Bonus tokens not used but required
bonus_token_ids
=
torch
.
zeros
((
batch_size
,
1
),
dtype
=
torch
.
int64
,
device
=
DEVICE
)
# Create spec decode metadata
spec_decode_metadata
=
SpecDecodeMetadata
.
make_dummy
(
draft_token_ids
,
device
=
DEVICE
,
)
# Run rejection sampling
output_token_ids
=
rejection_sampler
(
spec_decode_metadata
,
draft_probs
=
draft_probs
,
target_logits
=
target_logits
,
bonus_token_ids
=
bonus_token_ids
,
sampling_metadata
=
sampling_metadata
,
)
# Remove bonus tokens and reshape
output_token_ids
=
output_token_ids
[:,
:
-
1
].
flatten
().
tolist
()
# Check that all sampled tokens are within the unmasked indices.
for
i
in
range
(
num_tokens
):
token_id
=
output_token_ids
[
i
]
if
token_id
==
PLACEHOLDER_TOKEN_ID
:
continue
assert
token_id
in
unmasked_indices
[
i
]
@
pytest
.
mark
.
parametrize
(
"top_k"
,
[
1
,
5
,
99
])
def
test_top_k
(
rejection_sampler
,
top_k
):
"""Test rejection sampling with top-k sampling"""
vocab_size
=
100
batch_size
=
100
num_draft_tokens
=
3
num_tokens
=
batch_size
*
num_draft_tokens
# Randomly create top-k indices.
top_k_indices
=
[
torch
.
randperm
(
vocab_size
,
device
=
DEVICE
)[:
top_k
]
for
_
in
range
(
num_tokens
)
]
top_k_indices
=
torch
.
stack
(
top_k_indices
)
# Create logits with the uniform distribution.
target_logits
=
torch
.
zeros
((
num_tokens
,
vocab_size
),
device
=
DEVICE
)
# Increment the logits for top-k indices, a little bit more than the other
# ones. If the masking is effective, the non-topk indices will never be
# sampled despite the small difference in logits.
for
i
in
range
(
num_tokens
):
target_logits
[
i
,
top_k_indices
[
i
]]
+=
0.1
# Create sampling metadata
temperature
=
torch
.
ones
(
batch_size
,
dtype
=
torch
.
float32
,
device
=
DEVICE
)
sampling_metadata
=
create_sampling_metadata
(
all_greedy
=
False
,
temperature
=
temperature
,
top_k
=
torch
.
tensor
([
top_k
]
*
batch_size
,
device
=
DEVICE
,
dtype
=
torch
.
int64
),
)
_test_masked_logits
(
rejection_sampler
,
batch_size
=
batch_size
,
num_draft_tokens
=
num_draft_tokens
,
vocab_size
=
vocab_size
,
target_logits
=
target_logits
,
unmasked_indices
=
top_k_indices
,
sampling_metadata
=
sampling_metadata
,
)
@
pytest
.
mark
.
parametrize
(
"top_p"
,
[
0.5
,
0.9
,
0.99
])
def
test_top_p
(
rejection_sampler
,
top_p
):
"""Test rejection sampling with top-p sampling"""
vocab_size
=
100
batch_size
=
100
num_draft_tokens
=
3
num_tokens
=
batch_size
*
num_draft_tokens
# Create logits with the uniform distribution.
target_logits
=
torch
.
randn
((
num_tokens
,
vocab_size
),
device
=
DEVICE
)
temperature
=
torch
.
ones
(
batch_size
,
dtype
=
torch
.
float32
,
device
=
DEVICE
)
rescaled_logits
=
target_logits
/
temperature
logits_sort
,
logits_idx
=
rescaled_logits
.
sort
(
dim
=-
1
,
descending
=
False
)
probs_sort
=
logits_sort
.
softmax
(
dim
=-
1
)
probs_sum
=
probs_sort
.
cumsum
(
dim
=-
1
)
top_p_mask
=
probs_sum
<=
1
-
top_p
# at least one
top_p_mask
[:,
-
1
]
=
False
# Get the top-p indices.
top_p_indices
=
[]
for
i
in
range
(
num_tokens
):
top_p_indices
.
append
(
logits_idx
[
i
][
~
top_p_mask
[
i
]].
tolist
())
# Create sampling metadata
sampling_metadata
=
create_sampling_metadata
(
all_greedy
=
False
,
temperature
=
temperature
,
top_p
=
torch
.
tensor
([
top_p
]
*
batch_size
,
device
=
DEVICE
,
dtype
=
torch
.
float32
),
)
_test_masked_logits
(
rejection_sampler
,
batch_size
=
batch_size
,
num_draft_tokens
=
num_draft_tokens
,
vocab_size
=
vocab_size
,
target_logits
=
target_logits
,
unmasked_indices
=
top_p_indices
,
sampling_metadata
=
sampling_metadata
,
)
vllm/attention/backends/flash_attn.py
View file @
31f6b24f
...
@@ -22,12 +22,13 @@ from vllm.attention.backends.utils import (
...
@@ -22,12 +22,13 @@ from vllm.attention.backends.utils import (
compute_slot_mapping_start_idx
,
get_num_prefill_decode_query_kv_tokens
,
compute_slot_mapping_start_idx
,
get_num_prefill_decode_query_kv_tokens
,
get_seq_len_block_table_args
,
is_all_cross_attn_metadata_set
,
get_seq_len_block_table_args
,
is_all_cross_attn_metadata_set
,
is_all_encoder_attn_metadata_set
,
is_block_tables_empty
)
is_all_encoder_attn_metadata_set
,
is_block_tables_empty
)
from
vllm.fa_utils
import
get_flash_attn_version
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.multimodal
import
MultiModalPlaceholderMap
from
vllm.multimodal
import
MultiModalPlaceholderMap
from
vllm.utils
import
async_tensor_h2d
,
make_tensor_with_pad
from
vllm.utils
import
async_tensor_h2d
,
make_tensor_with_pad
from
vllm.vllm_flash_attn
import
(
flash_attn_varlen_func
,
from
vllm.vllm_flash_attn
import
(
flash_attn_varlen_func
,
flash_attn_with_kvcache
)
flash_attn_with_kvcache
)
from
vllm.vllm_flash_attn.fa_utils
import
(
flash_attn_supports_fp8
,
get_flash_attn_version
)
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
from
vllm.worker.model_runner
import
(
ModelInputForGPUBuilder
,
from
vllm.worker.model_runner
import
(
ModelInputForGPUBuilder
,
...
@@ -632,10 +633,13 @@ class FlashAttentionImpl(AttentionImpl):
...
@@ -632,10 +633,13 @@ class FlashAttentionImpl(AttentionImpl):
self
.
kv_cache_dtype
=
kv_cache_dtype
self
.
kv_cache_dtype
=
kv_cache_dtype
self
.
vllm_flash_attn_version
=
get_flash_attn_version
(
self
.
vllm_flash_attn_version
=
get_flash_attn_version
(
requires_alibi
=
self
.
alibi_slopes
is
not
None
)
requires_alibi
=
self
.
alibi_slopes
is
not
None
)
if
(
is_quantized_kv_cache
(
self
.
kv_cache_dtype
)
if
is_quantized_kv_cache
(
self
.
kv_cache_dtype
)
and
(
and
self
.
vllm_flash_attn_version
!=
3
):
not
self
.
kv_cache_dtype
.
startswith
(
"fp8"
)
or
not
flash_attn_supports_fp8
()):
raise
NotImplementedError
(
raise
NotImplementedError
(
"Only FlashAttention3 supports FP8 KV cache"
)
f
"FlashAttention does not support
{
self
.
kv_cache_dtype
}
"
"kv-cache on this device "
f
"(FA supports fp8 =
{
flash_attn_supports_fp8
()
}
)."
)
if
logits_soft_cap
is
None
:
if
logits_soft_cap
is
None
:
# In flash-attn, setting logits_soft_cap as 0 means no soft cap.
# In flash-attn, setting logits_soft_cap as 0 means no soft cap.
logits_soft_cap
=
0
logits_soft_cap
=
0
...
@@ -704,6 +708,10 @@ class FlashAttentionImpl(AttentionImpl):
...
@@ -704,6 +708,10 @@ class FlashAttentionImpl(AttentionImpl):
logits_soft_cap
:
Optional
[
float
]
=
self
.
logits_soft_cap
logits_soft_cap
:
Optional
[
float
]
=
self
.
logits_soft_cap
fp8_attention
=
kv_cache_dtype
.
startswith
(
"fp8"
)
fp8_attention
=
kv_cache_dtype
.
startswith
(
"fp8"
)
if
fp8_attention
and
not
flash_attn_supports_fp8
():
raise
NotImplementedError
(
"FlashAttention does not support FP8 kv-cache on this device."
)
if
kv_cache
.
numel
()
>
0
:
if
kv_cache
.
numel
()
>
0
:
key_cache
=
kv_cache
[
0
]
key_cache
=
kv_cache
[
0
]
value_cache
=
kv_cache
[
1
]
value_cache
=
kv_cache
[
1
]
...
...
vllm/attention/backends/mla/common.py
View file @
31f6b24f
...
@@ -206,7 +206,6 @@ from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping,
...
@@ -206,7 +206,6 @@ from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping,
compute_slot_mapping_start_idx
,
compute_slot_mapping_start_idx
,
is_block_tables_empty
)
is_block_tables_empty
)
from
vllm.attention.ops.triton_merge_attn_states
import
merge_attn_states
from
vllm.attention.ops.triton_merge_attn_states
import
merge_attn_states
from
vllm.fa_utils
import
get_flash_attn_version
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
LinearBase
,
RowParallelLinear
,
LinearBase
,
RowParallelLinear
,
UnquantizedLinearMethod
)
UnquantizedLinearMethod
)
...
@@ -215,6 +214,7 @@ from vllm.model_executor.layers.rotary_embedding import (
...
@@ -215,6 +214,7 @@ from vllm.model_executor.layers.rotary_embedding import (
from
vllm.multimodal
import
MultiModalPlaceholderMap
from
vllm.multimodal
import
MultiModalPlaceholderMap
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.utils
import
async_tensor_h2d
,
cdiv
,
make_tensor_with_pad
,
round_down
from
vllm.utils
import
async_tensor_h2d
,
cdiv
,
make_tensor_with_pad
,
round_down
from
vllm.vllm_flash_attn.fa_utils
import
get_flash_attn_version
try
:
try
:
from
vllm.vllm_flash_attn
import
flash_attn_varlen_func
from
vllm.vllm_flash_attn
import
flash_attn_varlen_func
...
...
vllm/compilation/inductor_pass.py
View file @
31f6b24f
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
import
hashlib
import
hashlib
import
importlib.metadata
import
inspect
import
inspect
import
json
import
types
import
types
from
abc
import
ABC
,
abstractmethod
from
typing
import
Any
,
Callable
,
Dict
,
Optional
,
Union
from
typing
import
Any
,
Callable
,
Optional
,
Union
import
torch
import
torch
from
packaging.version
import
Version
from
torch
import
fx
from
torch
import
fx
if
Version
(
importlib
.
metadata
.
version
(
'torch'
))
>=
Version
(
"2.6"
):
from
torch._inductor.custom_graph_pass
import
CustomGraphPass
else
:
# CustomGraphPass is not present in 2.5 or lower, import our version
from
.torch25_custom_graph_pass
import
(
# noqa: yapf
Torch25CustomGraphPass
as
CustomGraphPass
)
class
InductorPass
(
ABC
):
class
InductorPass
(
CustomGraphPass
):
"""
"""
General custom inductor pass interface.
A custom graph pass that uses a hash of its source as the UUID.
This is defined as a convenience and should work in most cases.
"""
"""
@
abstractmethod
def
__call__
(
self
,
graph
:
torch
.
fx
.
Graph
):
"""
Execute the pass on the given graph.
"""
raise
NotImplementedError
def
uuid
(
self
)
->
Any
:
def
uuid
(
self
)
->
Any
:
"""
"""
Provide a unique identifier for the pass, used in Inductor code cache.
Provide a unique identifier for the pass, used in Inductor code cache.
...
@@ -48,7 +51,16 @@ class InductorPass(ABC):
...
@@ -48,7 +51,16 @@ class InductorPass(ABC):
else
:
else
:
src_str
=
inspect
.
getsource
(
src
.
__class__
)
src_str
=
inspect
.
getsource
(
src
.
__class__
)
hasher
.
update
(
src_str
.
encode
(
"utf-8"
))
hasher
.
update
(
src_str
.
encode
(
"utf-8"
))
return
hasher
.
digest
()
return
hasher
.
hexdigest
()
@
staticmethod
def
hash_dict
(
dict_
:
Dict
[
Any
,
Any
]):
"""
Utility method to hash a dictionary, can alternatively be used for uuid.
:return: A sha256 hash of the json rep of the dictionary.
"""
encoded
=
json
.
dumps
(
dict_
,
sort_keys
=
True
).
encode
(
"utf-8"
)
return
hashlib
.
sha256
(
encoded
).
hexdigest
()
class
CallableInductorPass
(
InductorPass
):
class
CallableInductorPass
(
InductorPass
):
...
@@ -61,25 +73,10 @@ class CallableInductorPass(InductorPass):
...
@@ -61,25 +73,10 @@ class CallableInductorPass(InductorPass):
callable
:
Callable
[[
fx
.
Graph
],
None
],
callable
:
Callable
[[
fx
.
Graph
],
None
],
uuid
:
Optional
[
Any
]
=
None
):
uuid
:
Optional
[
Any
]
=
None
):
self
.
callable
=
callable
self
.
callable
=
callable
if
uuid
is
None
:
self
.
_uuid
=
self
.
hash_source
(
callable
)
if
uuid
is
None
else
uuid
uuid
=
InductorPass
.
hash_source
(
callable
)
self
.
_uuid
=
uuid
def
__call__
(
self
,
graph
:
torch
.
fx
.
Graph
):
def
__call__
(
self
,
graph
:
torch
.
fx
.
Graph
):
self
.
callable
(
graph
)
self
.
callable
(
graph
)
def
uuid
(
self
)
->
Any
:
def
uuid
(
self
)
->
Any
:
return
self
.
_uuid
return
self
.
_uuid
def
__getstate__
(
self
):
"""
Pickling occurs in the Inductor code cache if a pass is not given to
the pass manager but is instead directly added to config as a pass.
See PostGradPassManager for more.
TODO(torch==2.6), use the `uuid` method in CustomGraphPass instead.
"""
return
self
.
_uuid
def
__setstate__
(
self
,
state
):
raise
ValueError
(
"Cannot unpickle CallableInductorPass"
)
vllm/compilation/pass_manager.py
View file @
31f6b24f
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
from
typing
import
Any
,
Dict
,
List
from
typing
import
List
import
torch
from
torch
import
fx
as
fx
from
torch
import
fx
as
fx
from
vllm.config
import
CompilationConfig
from
vllm.config
import
CompilationConfig
...
@@ -10,29 +9,18 @@ from vllm.logger import init_logger
...
@@ -10,29 +9,18 @@ from vllm.logger import init_logger
from
.fix_functionalization
import
FixFunctionalizationPass
from
.fix_functionalization
import
FixFunctionalizationPass
from
.fusion
import
FusionPass
from
.fusion
import
FusionPass
from
.inductor_pass
import
InductorPass
from
.inductor_pass
import
CustomGraphPass
,
InductorPass
from
.noop_elimination
import
NoOpEliminationPass
from
.noop_elimination
import
NoOpEliminationPass
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
class
PlaceHolder
:
class
PostGradPassManager
(
CustomGraphPass
):
pass
if
torch
.
__version__
<
"2.6"
:
Parent
=
PlaceHolder
# type: ignore
else
:
Parent
=
torch
.
_inductor
.
custom_graph_pass
.
CustomGraphPass
# type: ignore
class
PostGradPassManager
(
Parent
):
"""
"""
The pass manager for post-grad passes.
The pass manager for post-grad passes.
It handles configuration, adding custom passes, and running passes.
It handles configuration, adding custom passes, and running passes.
It also supports pickling, which is used by the Inductor code cache.
It supports uuid for the Inductor code cache. That includes torch<2.6
TODO(torch==2.6), use CustomGraphPass
support using pickling (in .inductor_pass.CustomGraphPass).
(torch._inductor.custom_graph_pass.CustomGraphPass)
The order of the post-grad post-passes is:
The order of the post-grad post-passes is:
1. passes (constructor parameter)
1. passes (constructor parameter)
...
@@ -67,27 +55,13 @@ class PostGradPassManager(Parent):
...
@@ -67,27 +55,13 @@ class PostGradPassManager(Parent):
self
.
passes
.
append
(
pass_
)
self
.
passes
.
append
(
pass_
)
def
uuid
(
self
):
def
uuid
(
self
):
return
self
.
__getstate__
()
def
__getstate__
(
self
)
->
Dict
[
str
,
List
[
Any
]]:
"""
"""
Custom pickling for the pass manager, as some passes cannot be pickled.
The PostGradPassManager is set as a custom pass in the Inductor and
Pickling occurs because the pass manager is set as the value of
affects compilation caching. Its uuid depends on the UUIDs of all
`config["post_grad_custom_post_pass"]` in the Inductor config.
dependent passes and the pass config. See InductorPass for more info.
The config is pickled to act as a key in the Inductor code cache.
Any other passes in the config are pickled as well.
TODO(torch==2.6), use the `uuid` method in CustomGraphPass instead.
"""
"""
state
=
{
"pass_config"
:
self
.
pass_config
.
uuid
(),
"passes"
:
[]}
state
=
{
"pass_config"
:
self
.
pass_config
.
uuid
(),
"passes"
:
[]}
for
pass_
in
self
.
passes
:
for
pass_
in
self
.
passes
:
state
[
"passes"
].
append
(
pass_
.
uuid
())
state
[
"passes"
].
append
(
pass_
.
uuid
())
state
[
"passes"
].
append
(
self
.
fix_functionalization
.
uuid
())
state
[
"passes"
].
append
(
self
.
fix_functionalization
.
uuid
())
return
state
return
InductorPass
.
hash_dict
(
state
)
def
__setstate__
(
self
,
state
):
"""
Do not allow unpickling of the pass manager.
If this is needed in the future, it should properly pickle the passes.
"""
raise
ValueError
(
"Cannot unpickle PostGradPassManager"
)
Prev
1
2
3
4
5
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment