Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
31f6b24f
Commit
31f6b24f
authored
Mar 26, 2025
by
zhuwenwen
Browse files
Merge remote-tracking branch 'mirror/v0.8.2' into v0.8.2-ori
parents
89d1dd57
25f560a6
Changes
88
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
780 additions
and
255 deletions
+780
-255
tests/build_cython.py
tests/build_cython.py
+38
-0
tests/compile/test_pass_manager.py
tests/compile/test_pass_manager.py
+46
-16
tests/distributed/test_pipeline_parallel.py
tests/distributed/test_pipeline_parallel.py
+3
-0
tests/entrypoints/openai/test_chat_template.py
tests/entrypoints/openai/test_chat_template.py
+2
-0
tests/entrypoints/openai/test_video.py
tests/entrypoints/openai/test_video.py
+2
-2
tests/entrypoints/test_chat_utils.py
tests/entrypoints/test_chat_utils.py
+62
-2
tests/fastsafetensors_loader/__init__.py
tests/fastsafetensors_loader/__init__.py
+0
-0
tests/fastsafetensors_loader/test_fastsafetensors_loader.py
tests/fastsafetensors_loader/test_fastsafetensors_loader.py
+22
-0
tests/fastsafetensors_loader/test_weight_utils.py
tests/fastsafetensors_loader/test_weight_utils.py
+46
-0
tests/kernels/test_marlin_gemm.py
tests/kernels/test_marlin_gemm.py
+45
-0
tests/kernels/test_moe.py
tests/kernels/test_moe.py
+35
-10
tests/tool_use/utils.py
tests/tool_use/utils.py
+4
-1
tests/tpu/test_compilation.py
tests/tpu/test_compilation.py
+90
-86
tests/v1/engine/test_output_processor.py
tests/v1/engine/test_output_processor.py
+88
-1
tests/v1/entrypoints/llm/test_struct_output_generate.py
tests/v1/entrypoints/llm/test_struct_output_generate.py
+102
-67
tests/v1/sample/test_rejection_sampler.py
tests/v1/sample/test_rejection_sampler.py
+148
-2
vllm/attention/backends/flash_attn.py
vllm/attention/backends/flash_attn.py
+12
-4
vllm/attention/backends/mla/common.py
vllm/attention/backends/mla/common.py
+1
-1
vllm/compilation/inductor_pass.py
vllm/compilation/inductor_pass.py
+25
-28
vllm/compilation/pass_manager.py
vllm/compilation/pass_manager.py
+9
-35
No files found.
tests/build_cython.py
0 → 100644
View file @
31f6b24f
# SPDX-License-Identifier: Apache-2.0
import
Cython.Compiler.Options
from
Cython.Build
import
cythonize
from
setuptools
import
setup
Cython
.
Compiler
.
Options
.
annotate
=
True
infiles
=
[]
infiles
+=
[
"vllm/engine/llm_engine.py"
,
"vllm/transformers_utils/detokenizer.py"
,
"vllm/engine/output_processor/single_step.py"
,
"vllm/outputs.py"
,
"vllm/engine/output_processor/stop_checker.py"
,
]
infiles
+=
[
"vllm/core/scheduler.py"
,
"vllm/sequence.py"
,
"vllm/core/block_manager.py"
,
]
infiles
+=
[
"vllm/model_executor/layers/sampler.py"
,
"vllm/sampling_params.py"
,
"vllm/utils.py"
,
]
setup
(
ext_modules
=
cythonize
(
infiles
,
annotate
=
False
,
force
=
True
,
compiler_directives
=
{
'language_level'
:
"3"
,
'infer_types'
:
True
}))
# example usage: python3 build_cython.py build_ext --inplace
tests/compile/test_pass_manager.py
View file @
31f6b24f
# SPDX-License-Identifier: Apache-2.0
import
pickle
import
copy
import
pytest
import
torch
...
...
@@ -10,32 +9,63 @@ from vllm.compilation.pass_manager import PostGradPassManager
from
vllm.config
import
CompilationConfig
# dummy custom pass that doesn't inherit
def
simple_callable
(
graph
:
torch
.
fx
.
Graph
):
pass
callable_uuid
=
CallableInductorPass
(
simple_callable
,
InductorPass
.
hash_source
(
__file__
))
# Should fail to add directly to the pass manager
def
test_bad_callable
():
config
=
CompilationConfig
().
pass_config
pass_manager
=
PostGradPassManager
()
pass_manager
.
configure
(
config
)
with
pytest
.
raises
(
AssertionError
):
pass_manager
.
add
(
simple_callable
)
# noqa, type wrong on purpose
# Pass that inherits from InductorPass
class
ProperPass
(
InductorPass
):
def
__call__
(
self
,
graph
:
torch
.
fx
.
graph
.
Graph
)
->
None
:
pass
@
pytest
.
mark
.
parametrize
(
"
works,
callable"
,
"callable"
,
[
(
False
,
simple_callable
),
(
True
,
callable_uuid
),
(
True
,
CallableInductorPass
(
simple_callable
)),
ProperPass
(),
# Can also wrap callables in CallableInductorPass for compliance
CallableInductorPass
(
simple_callable
),
CallableInductorPass
(
simple_callable
,
InductorPass
.
hash_source
(
__file__
))
],
)
def
test_pass_manager
(
works
:
bool
,
callable
):
def
test_pass_manager
_uuid
(
callable
):
config
=
CompilationConfig
().
pass_config
pass_manager
=
PostGradPassManager
()
pass_manager
.
configure
(
config
)
# Try to add the callable to the pass manager
if
works
:
pass_manager
.
add
(
callable
)
pickle
.
dumps
(
pass_manager
)
else
:
with
pytest
.
raises
(
AssertionError
):
pass_manager
.
add
(
callable
)
# Check that UUID is different if the same pass is added 2x
pass_manager
.
add
(
callable
)
uuid1
=
pass_manager
.
uuid
()
pass_manager
.
add
(
callable
)
uuid2
=
pass_manager
.
uuid
()
assert
uuid1
!=
uuid2
# UUID should be the same as the original one,
# as we constructed in the same way.
pass_manager2
=
PostGradPassManager
()
pass_manager2
.
configure
(
config
)
pass_manager2
.
add
(
callable
)
assert
uuid1
==
pass_manager2
.
uuid
()
# UUID should be different due to config change
config2
=
copy
.
deepcopy
(
config
)
config2
.
enable_fusion
=
not
config2
.
enable_fusion
pass_manager3
=
PostGradPassManager
()
pass_manager3
.
configure
(
config2
)
pass_manager3
.
add
(
callable
)
assert
uuid1
!=
pass_manager3
.
uuid
()
tests/distributed/test_pipeline_parallel.py
View file @
31f6b24f
...
...
@@ -175,6 +175,8 @@ TEXT_GENERATION_MODELS = {
"inceptionai/jais-13b-chat"
:
PPTestSettings
.
fast
(),
"ai21labs/Jamba-tiny-dev"
:
PPTestSettings
.
fast
(),
"meta-llama/Llama-3.2-1B-Instruct"
:
PPTestSettings
.
detailed
(),
# Tests TransformersModel
"ArthurZ/Ilama-3.2-1B"
:
PPTestSettings
.
fast
(),
"openbmb/MiniCPM-2B-sft-bf16"
:
PPTestSettings
.
fast
(),
"openbmb/MiniCPM3-4B"
:
PPTestSettings
.
fast
(),
# Uses Llama
...
...
@@ -243,6 +245,7 @@ TEST_MODELS = [
# [LANGUAGE GENERATION]
"microsoft/Phi-3.5-MoE-instruct"
,
"meta-llama/Llama-3.2-1B-Instruct"
,
# "ArthurZ/Ilama-3.2-1B", NOTE: Uncomment after #13905
"ibm/PowerLM-3b"
,
# [LANGUAGE EMBEDDING]
"intfloat/e5-mistral-7b-instruct"
,
...
...
tests/entrypoints/openai/test_chat_template.py
View file @
31f6b24f
...
...
@@ -107,8 +107,10 @@ def test_get_gen_prompt(model, template, add_generation_prompt,
# Call the function and get the result
result
=
apply_hf_chat_template
(
tokenizer
,
trust_remote_code
=
True
,
conversation
=
mock_request
.
messages
,
chat_template
=
mock_request
.
chat_template
or
template_content
,
tools
=
None
,
add_generation_prompt
=
mock_request
.
add_generation_prompt
,
continue_final_message
=
mock_request
.
continue_final_message
,
)
...
...
tests/entrypoints/openai/test_video.py
View file @
31f6b24f
...
...
@@ -87,7 +87,7 @@ async def test_single_chat_session_video(client: openai.AsyncOpenAI,
choice
=
chat_completion
.
choices
[
0
]
assert
choice
.
finish_reason
==
"length"
assert
chat_completion
.
usage
==
openai
.
types
.
CompletionUsage
(
completion_tokens
=
10
,
prompt_tokens
=
62
99
,
total_tokens
=
6
309
)
completion_tokens
=
10
,
prompt_tokens
=
62
87
,
total_tokens
=
6
297
)
message
=
choice
.
message
message
=
chat_completion
.
choices
[
0
].
message
...
...
@@ -180,7 +180,7 @@ async def test_single_chat_session_video_base64encoded(
choice
=
chat_completion
.
choices
[
0
]
assert
choice
.
finish_reason
==
"length"
assert
chat_completion
.
usage
==
openai
.
types
.
CompletionUsage
(
completion_tokens
=
10
,
prompt_tokens
=
62
99
,
total_tokens
=
6
309
)
completion_tokens
=
10
,
prompt_tokens
=
62
87
,
total_tokens
=
6
297
)
message
=
choice
.
message
message
=
chat_completion
.
choices
[
0
].
message
...
...
tests/entrypoints/test_chat_utils.py
View file @
31f6b24f
...
...
@@ -4,10 +4,13 @@ import warnings
from
typing
import
Optional
import
pytest
from
packaging.version
import
Version
from
transformers
import
__version__
as
TRANSFORMERS_VERSION
from
vllm.assets.image
import
ImageAsset
from
vllm.config
import
ModelConfig
from
vllm.entrypoints.chat_utils
import
(
_try_extract_ast
,
load_chat_template
,
from
vllm.entrypoints.chat_utils
import
(
_resolve_hf_chat_template
,
_try_extract_ast
,
load_chat_template
,
parse_chat_messages
,
parse_chat_messages_futures
,
resolve_chat_template_content_format
)
...
...
@@ -23,8 +26,10 @@ EXAMPLES_DIR = VLLM_PATH / "examples"
PHI3V_MODEL_ID
=
"microsoft/Phi-3.5-vision-instruct"
ULTRAVOX_MODEL_ID
=
"fixie-ai/ultravox-v0_5-llama-3_2-1b"
QWEN2VL_MODEL_ID
=
"Qwen/Qwen2-VL-2B-Instruct"
QWEN25VL_MODEL_ID
=
"Qwen/Qwen2.5-VL-3B-Instruct"
MLLAMA_MODEL_ID
=
"meta-llama/Llama-3.2-11B-Vision-Instruct"
LLAMA_GUARD_MODEL_ID
=
"meta-llama/Llama-Guard-3-1B"
HERMES_MODEL_ID
=
"NousResearch/Hermes-3-Llama-3.1-8B"
@
pytest
.
fixture
(
scope
=
"function"
)
...
...
@@ -703,25 +708,70 @@ def test_multimodal_image_parsing_matches_hf(model, image_url):
vllm_result
=
apply_hf_chat_template
(
tokenizer
,
trust_remote_code
=
model_config
.
trust_remote_code
,
conversation
=
conversation
,
chat_template
=
None
,
tools
=
None
,
add_generation_prompt
=
True
,
)
assert
hf_result
==
vllm_result
@
pytest
.
mark
.
parametrize
(
"model"
,
[
QWEN2VL_MODEL_ID
,
# tokenizer.chat_template is of type str
HERMES_MODEL_ID
,
# tokenizer.chat_template is of type dict
])
@
pytest
.
mark
.
parametrize
(
"use_tools"
,
[
True
,
False
])
def
test_resolve_hf_chat_template
(
sample_json_schema
,
model
,
use_tools
):
"""checks that chat_template is a dict type for HF models."""
# Build the tokenizer group and grab the underlying tokenizer
tokenizer_group
=
TokenizerGroup
(
model
,
enable_lora
=
False
,
max_num_seqs
=
5
,
max_input_length
=
None
,
)
tokenizer
=
tokenizer_group
.
tokenizer
tools
=
[{
"type"
:
"function"
,
"function"
:
{
"name"
:
"dummy_function_name"
,
"description"
:
"This is a dummy function"
,
"parameters"
:
sample_json_schema
}
}]
if
use_tools
else
None
# Test detecting the tokenizer's chat_template
chat_template
=
_resolve_hf_chat_template
(
tokenizer
,
chat_template
=
None
,
tools
=
tools
,
trust_remote_code
=
True
,
)
assert
isinstance
(
chat_template
,
str
)
# yapf: disable
@
pytest
.
mark
.
parametrize
(
(
"model"
,
"expected_format"
),
[(
PHI3V_MODEL_ID
,
"string"
),
(
QWEN2VL_MODEL_ID
,
"openai"
),
(
QWEN25VL_MODEL_ID
,
"openai"
),
(
ULTRAVOX_MODEL_ID
,
"string"
),
(
MLLAMA_MODEL_ID
,
"openai"
),
(
LLAMA_GUARD_MODEL_ID
,
"openai"
)],
)
# yapf: enable
def
test_resolve_content_format_hf_defined
(
model
,
expected_format
):
if
model
==
QWEN25VL_MODEL_ID
and
Version
(
TRANSFORMERS_VERSION
)
<
Version
(
"4.49.0"
):
pytest
.
skip
(
"Qwen2.5-VL requires transformers>=4.49.0"
)
tokenizer_group
=
TokenizerGroup
(
model
,
enable_lora
=
False
,
...
...
@@ -730,7 +780,13 @@ def test_resolve_content_format_hf_defined(model, expected_format):
)
tokenizer
=
tokenizer_group
.
tokenizer
chat_template
=
tokenizer
.
chat_template
# Test detecting the tokenizer's chat_template
chat_template
=
_resolve_hf_chat_template
(
tokenizer
,
chat_template
=
None
,
tools
=
None
,
trust_remote_code
=
True
,
)
assert
isinstance
(
chat_template
,
str
)
print
(
"[TEXT]"
)
...
...
@@ -740,8 +796,10 @@ def test_resolve_content_format_hf_defined(model, expected_format):
resolved_format
=
resolve_chat_template_content_format
(
None
,
# Test detecting the tokenizer's chat_template
None
,
"auto"
,
tokenizer
,
trust_remote_code
=
True
,
)
assert
resolved_format
==
expected_format
...
...
@@ -791,8 +849,10 @@ def test_resolve_content_format_examples(template_path, expected_format):
resolved_format
=
resolve_chat_template_content_format
(
chat_template
,
None
,
"auto"
,
dummy_tokenizer
,
trust_remote_code
=
True
,
)
assert
resolved_format
==
expected_format
tests/fastsafetensors_loader/__init__.py
0 → 100644
View file @
31f6b24f
tests/fastsafetensors_loader/test_fastsafetensors_loader.py
0 → 100644
View file @
31f6b24f
# SPDX-License-Identifier: Apache-2.0
from
vllm
import
SamplingParams
from
vllm.config
import
LoadFormat
test_model
=
"openai-community/gpt2"
prompts
=
[
"Hello, my name is"
,
"The president of the United States is"
,
"The capital of France is"
,
"The future of AI is"
,
]
# Create a sampling params object.
sampling_params
=
SamplingParams
(
temperature
=
0.8
,
top_p
=
0.95
,
seed
=
0
)
def
test_model_loader_download_files
(
vllm_runner
):
with
vllm_runner
(
test_model
,
load_format
=
LoadFormat
.
FASTSAFETENSORS
)
as
llm
:
deserialized_outputs
=
llm
.
generate
(
prompts
,
sampling_params
)
assert
deserialized_outputs
tests/fastsafetensors_loader/test_weight_utils.py
0 → 100644
View file @
31f6b24f
# SPDX-License-Identifier: Apache-2.0
import
glob
import
tempfile
import
huggingface_hub.constants
import
torch
from
vllm.model_executor.model_loader.weight_utils
import
(
download_weights_from_hf
,
fastsafetensors_weights_iterator
,
safetensors_weights_iterator
)
def
test_fastsafetensors_model_loader
():
with
tempfile
.
TemporaryDirectory
()
as
tmpdir
:
huggingface_hub
.
constants
.
HF_HUB_OFFLINE
=
False
download_weights_from_hf
(
"openai-community/gpt2"
,
allow_patterns
=
[
"*.safetensors"
],
cache_dir
=
tmpdir
)
safetensors
=
glob
.
glob
(
f
"
{
tmpdir
}
/**/*.safetensors"
,
recursive
=
True
)
assert
len
(
safetensors
)
>
0
fastsafetensors_tensors
=
{}
hf_safetensors_tensors
=
{}
for
name
,
tensor
in
fastsafetensors_weights_iterator
(
safetensors
,
True
):
fastsafetensors_tensors
[
name
]
=
tensor
for
name
,
tensor
in
safetensors_weights_iterator
(
safetensors
,
True
):
hf_safetensors_tensors
[
name
]
=
tensor
assert
len
(
fastsafetensors_tensors
)
==
len
(
hf_safetensors_tensors
)
for
name
,
fastsafetensors_tensor
in
fastsafetensors_tensors
.
items
():
fastsafetensors_tensor
=
fastsafetensors_tensor
.
to
(
'cpu'
)
assert
fastsafetensors_tensor
.
dtype
==
hf_safetensors_tensors
[
name
].
dtype
assert
fastsafetensors_tensor
.
shape
==
hf_safetensors_tensors
[
name
].
shape
assert
torch
.
all
(
fastsafetensors_tensor
.
eq
(
hf_safetensors_tensors
[
name
]))
if
__name__
==
"__main__"
:
test_fastsafetensors_model_loader
()
tests/kernels/test_marlin_gemm.py
View file @
31f6b24f
...
...
@@ -606,6 +606,51 @@ def test_marlin_qqq_gemm(
assert
max_diff
<
0.04
def
test_marlin_gemm_subset_input
():
quant_type
=
scalar_types
.
uint4b8
group_size
=
128
size_m
,
size_k
,
size_n
=
32
,
1024
,
2048
big_m
=
size_m
*
2
big_k
=
size_k
*
2
a_input
=
rand_data
((
big_m
,
big_k
))[
8
:
size_m
+
8
,
8
:
size_k
+
8
]
b_weight
=
rand_data
((
size_k
,
size_n
))
w_ref
,
marlin_q_w
,
marlin_s
,
g_idx
,
sort_indices
,
_
=
marlin_quantize
(
b_weight
,
quant_type
,
group_size
,
False
)
marlin_zp
=
marlin_make_empty_g_idx
(
marlin_s
.
device
)
workspace
=
MarlinWorkspace
(
size_n
,
GPTQ_MARLIN_MIN_THREAD_N
,
GPTQ_MARLIN_MAX_PARALLEL
)
output
=
ops
.
gptq_marlin_gemm
(
a_input
,
marlin_q_w
,
marlin_s
,
marlin_zp
,
g_idx
,
sort_indices
,
workspace
.
scratch
,
quant_type
,
a_input
.
shape
[
0
],
b_weight
.
shape
[
1
],
a_input
.
shape
[
1
],
is_k_full
=
True
,
has_zp
=
False
,
use_atomic_add
=
False
,
use_fp32_reduce
=
True
,
is_zp_float
=
False
,
)
output_ref
=
torch
.
matmul
(
a_input
,
w_ref
)
torch
.
cuda
.
synchronize
()
max_diff
=
compute_max_diff
(
output
,
output_ref
)
assert
max_diff
<
0.04
def
test_marlin_gemm_opcheck
():
size_m
=
2048
size_n
=
4096
...
...
tests/kernels/test_moe.py
View file @
31f6b24f
...
...
@@ -3,8 +3,11 @@
Run `pytest tests/kernels/test_moe.py`.
"""
import
pytest
import
torch
from
torch.nn
import
Parameter
from
torch.nn
import
functional
as
F
from
transformers
import
MixtralConfig
from
transformers.models.mixtral.modeling_mixtral
import
MixtralSparseMoeBlock
...
...
@@ -37,6 +40,7 @@ TOP_KS = [2, 6]
@
pytest
.
mark
.
parametrize
(
"topk"
,
TOP_KS
)
@
pytest
.
mark
.
parametrize
(
"ep_size"
,
EP_SIZE
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
float16
,
torch
.
bfloat16
])
@
pytest
.
mark
.
parametrize
(
"padding"
,
[
True
,
False
])
def
test_fused_moe
(
m
:
int
,
n
:
int
,
...
...
@@ -45,6 +49,7 @@ def test_fused_moe(
topk
:
int
,
ep_size
:
int
,
dtype
:
torch
.
dtype
,
padding
:
bool
,
):
a
=
torch
.
randn
((
m
,
k
),
device
=
"cuda"
,
dtype
=
dtype
)
/
10
w1
=
torch
.
randn
((
e
,
2
*
n
,
k
),
device
=
"cuda"
,
dtype
=
dtype
)
/
10
...
...
@@ -65,16 +70,7 @@ def test_fused_moe(
else
:
e_map
=
None
triton_output
=
fused_moe
(
a
,
w1
,
w2
,
score
,
topk
,
global_num_experts
=
e
,
expert_map
=
e_map
,
renormalize
=
False
)
torch_output
=
torch_moe
(
a
,
w1
,
w2
,
score
,
topk
,
e_map
)
torch
.
testing
.
assert_close
(
triton_output
,
torch_output
,
atol
=
2e-2
,
rtol
=
0
)
iterative_output
=
iterative_moe
(
a
,
w1
,
w2
,
...
...
@@ -83,6 +79,23 @@ def test_fused_moe(
global_num_experts
=
e
,
expert_map
=
e_map
,
renormalize
=
False
)
# Pad the weight if moe padding is enabled
if
padding
:
w1
=
F
.
pad
(
w1
,
(
0
,
128
),
"constant"
,
0
)[...,
0
:
-
128
]
torch
.
cuda
.
empty_cache
()
w2
=
F
.
pad
(
w2
,
(
0
,
128
),
"constant"
,
0
)[...,
0
:
-
128
]
torch
.
cuda
.
empty_cache
()
triton_output
=
fused_moe
(
a
,
w1
,
w2
,
score
,
topk
,
global_num_experts
=
e
,
expert_map
=
e_map
,
renormalize
=
False
)
torch
.
testing
.
assert_close
(
triton_output
,
torch_output
,
atol
=
2e-2
,
rtol
=
0
)
torch
.
testing
.
assert_close
(
iterative_output
,
torch_output
,
atol
=
2e-2
,
...
...
@@ -202,8 +215,9 @@ def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int,
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
float32
,
torch
.
float16
,
torch
.
bfloat16
])
@
pytest
.
mark
.
parametrize
(
"padding"
,
[
True
,
False
])
@
torch
.
inference_mode
()
def
test_mixtral_moe
(
dtype
:
torch
.
dtype
):
def
test_mixtral_moe
(
dtype
:
torch
.
dtype
,
padding
:
bool
):
"""Make sure our Mixtral MoE implementation agrees with the one from
huggingface."""
...
...
@@ -233,6 +247,17 @@ def test_mixtral_moe(dtype: torch.dtype):
# vLLM uses 1D query [num_tokens, hidden_dim]
vllm_inputs
=
hf_inputs
.
flatten
(
0
,
1
)
# Pad the weight if moe padding is enabled
if
padding
:
vllm_moe
.
experts
.
w13_weight
=
Parameter
(
F
.
pad
(
vllm_moe
.
experts
.
w13_weight
,
(
0
,
128
),
"constant"
,
0
)[...,
0
:
-
128
],
requires_grad
=
False
)
torch
.
cuda
.
empty_cache
()
vllm_moe
.
experts
.
w2_weight
=
Parameter
(
F
.
pad
(
vllm_moe
.
experts
.
w2_weight
,
(
0
,
128
),
"constant"
,
0
)[...,
0
:
-
128
],
requires_grad
=
False
)
torch
.
cuda
.
empty_cache
()
# Run forward passes for both MoE blocks
hf_states
,
_
=
hf_moe
.
forward
(
hf_inputs
)
vllm_states
=
vllm_moe
.
forward
(
vllm_inputs
)
...
...
tests/tool_use/utils.py
View file @
31f6b24f
...
...
@@ -39,7 +39,10 @@ def ensure_system_prompt(messages: list[dict[str, Any]],
# universal args for all models go here. also good if you need to test locally
# and change type or KV cache quantization or something.
ARGS
:
list
[
str
]
=
[
"--enable-auto-tool-choice"
,
"--max-model-len"
,
"1024"
]
ARGS
:
list
[
str
]
=
[
"--enable-auto-tool-choice"
,
"--max-model-len"
,
"1024"
,
"--max-num-seqs"
,
"256"
]
CONFIGS
:
dict
[
str
,
ServerConfig
]
=
{
"hermes"
:
{
...
...
tests/tpu/test_compilation.py
View file @
31f6b24f
...
...
@@ -5,92 +5,96 @@ import os
import
tempfile
import
depyf
import
pytest
from
vllm.config
import
CompilationLevel
temp_dir
=
tempfile
.
mkdtemp
()
with
depyf
.
prepare_debug
(
temp_dir
):
from
vllm
import
LLM
,
SamplingParams
prompts
=
[
"A robot may not injure a human being"
,
"It is only with the heart that one can see rightly;"
,
"The greatest glory in living lies not in never falling,"
,
]
answers
=
[
" or, through inaction, allow a human being to come to harm."
,
" what is essential is invisible to the eye."
,
" but in rising every time we fall."
,
]
N
=
1
# Currently, top-p sampling is disabled. `top_p` should be 1.0.
sampling_params
=
SamplingParams
(
temperature
=
0.7
,
top_p
=
1.0
,
n
=
N
,
max_tokens
=
16
)
# Set `enforce_eager=True` to avoid ahead-of-time compilation.
# In real workloads, `enforace_eager` should be `False`.
# disable custom dispatcher, let Dynamo takes over
# all the control
llm
=
LLM
(
model
=
"Qwen/Qwen2.5-1.5B-Instruct"
,
max_model_len
=
512
,
max_num_seqs
=
64
,
enforce_eager
=
True
,
compilation_config
=
{
"level"
:
CompilationLevel
.
DYNAMO_AS_IS
})
outputs
=
llm
.
generate
(
prompts
,
sampling_params
)
for
output
,
answer
in
zip
(
outputs
,
answers
):
prompt
=
output
.
prompt
generated_text
=
output
.
outputs
[
0
].
text
print
(
f
"Prompt:
{
prompt
!
r
}
, Generated text:
{
generated_text
!
r
}
"
)
assert
generated_text
.
startswith
(
answer
)
compiled_codes
=
sorted
(
glob
.
glob
(
os
.
path
.
join
(
temp_dir
,
"__transformed_code*.py"
)))
for
i
,
compiled_code
in
enumerate
(
compiled_codes
):
print
(
"{} file: {}"
.
format
(
i
+
1
,
compiled_code
))
# We should only trigger Dynamo compilation 4 times:
# 1. forward pass (symbolic)
# 2. compute_logits (symbolic)
# 3. forward pass (shape 16)
# 4. forward pass (shape 32)
# and later calls should not trigger Dynamo compilation again.
# NOTE: It might still trigger XLA compilation.
# Check we have 4 compiled codes
assert
len
(
compiled_codes
)
==
4
kv_cache_prefix
=
"kv_cache"
attn_prefix
=
"ragged_paged_attention"
# Check all the compilations are as expected
compiled_fns
=
sorted
(
glob
.
glob
(
os
.
path
.
join
(
temp_dir
,
"__compiled_fn*Captured*.py"
)))
for
i
,
compiled_fn
in
enumerate
(
compiled_fns
):
print
(
"{} file: {}"
.
format
(
i
+
1
,
compiled_fn
))
# The first compilation is symbolic, so it should not have any kv_caches
with
open
(
compiled_fns
[
0
])
as
f
:
content
=
f
.
read
()
assert
kv_cache_prefix
not
in
content
# The second compilation is symbolic, so it should not have any kv_caches
with
open
(
compiled_fns
[
1
])
as
f
:
content
=
f
.
read
()
assert
kv_cache_prefix
not
in
content
# The third compilation is shape 16, so it should have kv_caches and the
# ragged_paged_attention
with
open
(
compiled_fns
[
2
])
as
f
:
content
=
f
.
read
()
assert
(
kv_cache_prefix
in
content
and
attn_prefix
in
content
)
# The forth compilation is shape 32, so it should have kv_caches and the
# ragged_paged_attention
with
open
(
compiled_fns
[
3
])
as
f
:
content
=
f
.
read
()
assert
(
kv_cache_prefix
in
content
and
attn_prefix
in
content
)
@
pytest
.
mark
.
skip
(
reason
=
"Not working; needs investigation."
)
def
test_tpu_compilation
():
temp_dir
=
tempfile
.
mkdtemp
()
with
depyf
.
prepare_debug
(
temp_dir
):
from
vllm
import
LLM
,
SamplingParams
prompts
=
[
"A robot may not injure a human being"
,
"It is only with the heart that one can see rightly;"
,
"The greatest glory in living lies not in never falling,"
,
]
answers
=
[
" or, through inaction, allow a human being to come to harm."
,
" what is essential is invisible to the eye."
,
" but in rising every time we fall."
,
]
N
=
1
# Currently, top-p sampling is disabled. `top_p` should be 1.0.
sampling_params
=
SamplingParams
(
temperature
=
0.7
,
top_p
=
1.0
,
n
=
N
,
max_tokens
=
16
)
# Set `enforce_eager=True` to avoid ahead-of-time compilation.
# In real workloads, `enforace_eager` should be `False`.
# disable custom dispatcher, let Dynamo takes over
# all the control
llm
=
LLM
(
model
=
"Qwen/Qwen2.5-1.5B-Instruct"
,
max_model_len
=
512
,
max_num_seqs
=
64
,
enforce_eager
=
True
,
compilation_config
=
{
"level"
:
CompilationLevel
.
DYNAMO_AS_IS
})
outputs
=
llm
.
generate
(
prompts
,
sampling_params
)
for
output
,
answer
in
zip
(
outputs
,
answers
):
prompt
=
output
.
prompt
generated_text
=
output
.
outputs
[
0
].
text
print
(
f
"Prompt:
{
prompt
!
r
}
, Generated text:
{
generated_text
!
r
}
"
)
assert
generated_text
.
startswith
(
answer
)
compiled_codes
=
sorted
(
glob
.
glob
(
os
.
path
.
join
(
temp_dir
,
"__transformed_code*.py"
)))
for
i
,
compiled_code
in
enumerate
(
compiled_codes
):
print
(
"{} file: {}"
.
format
(
i
+
1
,
compiled_code
))
# We should only trigger Dynamo compilation 4 times:
# 1. forward pass (symbolic)
# 2. compute_logits (symbolic)
# 3. forward pass (shape 16)
# 4. forward pass (shape 32)
# and later calls should not trigger Dynamo compilation again.
# NOTE: It might still trigger XLA compilation.
# Check we have 4 compiled codes
assert
len
(
compiled_codes
)
==
4
kv_cache_prefix
=
"kv_cache"
attn_prefix
=
"ragged_paged_attention"
# Check all the compilations are as expected
compiled_fns
=
sorted
(
glob
.
glob
(
os
.
path
.
join
(
temp_dir
,
"__compiled_fn*Captured*.py"
)))
for
i
,
compiled_fn
in
enumerate
(
compiled_fns
):
print
(
"{} file: {}"
.
format
(
i
+
1
,
compiled_fn
))
# The first compilation is symbolic, so it should not have any kv_caches
with
open
(
compiled_fns
[
0
])
as
f
:
content
=
f
.
read
()
assert
kv_cache_prefix
not
in
content
# The second compilation is symbolic, so it should not have any kv_caches
with
open
(
compiled_fns
[
1
])
as
f
:
content
=
f
.
read
()
assert
kv_cache_prefix
not
in
content
# The third compilation is shape 16, so it should have kv_caches and the
# ragged_paged_attention
with
open
(
compiled_fns
[
2
])
as
f
:
content
=
f
.
read
()
assert
(
kv_cache_prefix
in
content
and
attn_prefix
in
content
)
# The forth compilation is shape 32, so it should have kv_caches and the
# ragged_paged_attention
with
open
(
compiled_fns
[
3
])
as
f
:
content
=
f
.
read
()
assert
(
kv_cache_prefix
in
content
and
attn_prefix
in
content
)
tests/v1/engine/test_output_processor.py
View file @
31f6b24f
...
...
@@ -11,11 +11,13 @@ from tests.v1.engine.utils import (NUM_PROMPT_LOGPROBS_UNDER_TEST,
STOP_STRINGS
,
DummyOutputProcessorTestVectors
,
MockEngineCore
)
from
vllm.outputs
import
CompletionOutput
,
RequestOutput
from
vllm.sampling_params
import
RequestOutputKind
,
SamplingParams
from
vllm.sequence
import
PromptLogprobs
,
SampleLogprobs
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
from
vllm.v1.engine
import
EngineCoreRequest
from
vllm.v1.engine.output_processor
import
OutputProcessor
from
vllm.v1.engine.output_processor
import
(
OutputProcessor
,
RequestOutputCollector
)
from
vllm.v1.metrics.stats
import
IterationStats
...
...
@@ -834,3 +836,88 @@ def test_iteration_stats(dummy_test_vectors):
assert
iteration_stats
.
num_prompt_tokens
==
0
assert
iteration_stats
.
num_generation_tokens
==
num_active
@
pytest
.
mark
.
asyncio
async
def
test_request_output_collector
():
NUM_REQS
=
3
TEXT
=
"a"
def
make_outputs
()
->
list
[
RequestOutput
]:
return
[
RequestOutput
(
request_id
=
"my-request-id"
,
prompt
=
None
,
prompt_token_ids
=
[
1
,
2
,
3
],
prompt_logprobs
=
None
,
outputs
=
[
CompletionOutput
(
index
=
0
,
text
=
TEXT
,
token_ids
=
[
idx
],
cumulative_logprob
=
(
idx
+
1
*
1.0
),
logprobs
=
[{
"a"
:
idx
,
"b"
:
idx
}],
finish_reason
=
"length"
if
(
idx
==
NUM_REQS
-
1
)
else
None
,
)
],
finished
=
(
idx
==
NUM_REQS
-
1
),
)
for
idx
in
range
(
NUM_REQS
)
]
collector
=
RequestOutputCollector
(
RequestOutputKind
.
DELTA
)
# CASE 1: Put then get.
outputs
=
make_outputs
()
collector
.
put
(
outputs
[
0
])
output
=
await
collector
.
get
()
assert
not
collector
.
ready
.
is_set
()
assert
collector
.
output
is
None
assert
output
.
outputs
[
0
].
text
==
"a"
assert
output
.
outputs
[
0
].
token_ids
==
[
0
]
# CASE 2: 2 puts then get.
num_to_put
=
2
outputs
=
make_outputs
()
for
i
in
range
(
num_to_put
):
collector
.
put
(
outputs
[
i
])
output
=
await
collector
.
get
()
assert
not
collector
.
ready
.
is_set
()
assert
collector
.
output
is
None
assert
not
output
.
finished
# Text, token_ids, and logprobs should get merged.
assert
output
.
outputs
[
0
].
text
==
TEXT
*
num_to_put
for
tok_0
,
tok_1
in
zip
(
output
.
outputs
[
0
].
token_ids
,
list
(
range
(
num_to_put
))):
assert
tok_0
==
tok_1
assert
len
(
output
.
outputs
[
0
].
logprobs
)
==
num_to_put
# Cumulative logprobs should be the last one.
cumulative_logprob_expected
=
1.0
*
num_to_put
assert
output
.
outputs
[
0
].
cumulative_logprob
==
cumulative_logprob_expected
# CASE 3: Put all 3 (including a finished).
num_to_put
=
3
outputs
=
make_outputs
()
for
i
in
range
(
num_to_put
):
collector
.
put
(
outputs
[
i
])
output
=
await
collector
.
get
()
assert
not
collector
.
ready
.
is_set
()
assert
collector
.
output
is
None
assert
output
.
finished
assert
output
.
outputs
[
0
].
finish_reason
==
"length"
# Text, token_ids, and logprobs should get merged.
assert
output
.
outputs
[
0
].
text
==
TEXT
*
num_to_put
for
tok_0
,
tok_1
in
zip
(
output
.
outputs
[
0
].
token_ids
,
list
(
range
(
num_to_put
))):
assert
tok_0
==
tok_1
assert
len
(
output
.
outputs
[
0
].
logprobs
)
==
num_to_put
# Cumulative logprobs should be the last one.
cumulative_logprob_expected
=
1.0
*
num_to_put
assert
output
.
outputs
[
0
].
cumulative_logprob
==
cumulative_logprob_expected
tests/v1/entrypoints/llm/test_struct_output_generate.py
View file @
31f6b24f
...
...
@@ -13,7 +13,7 @@ from vllm.entrypoints.llm import LLM
from
vllm.outputs
import
RequestOutput
from
vllm.sampling_params
import
GuidedDecodingParams
,
SamplingParams
GUIDED_DECODING_BACKENDS_V1
=
[
"xgrammar"
]
GUIDED_DECODING_BACKENDS_V1
=
[
"xgrammar"
,
"guidance"
]
MODELS_TO_TEST
=
[
"Qwen/Qwen2.5-1.5B-Instruct"
,
"mistralai/Ministral-8B-Instruct-2410"
]
...
...
@@ -30,12 +30,13 @@ def test_guided_json_completion(
model_name
:
str
,
):
monkeypatch
.
setenv
(
"VLLM_USE_V1"
,
"1"
)
llm
=
LLM
(
model
=
model_name
,
max_model_len
=
1024
)
sampling_params
=
SamplingParams
(
temperature
=
1.0
,
max_tokens
=
1000
,
guided_decoding
=
GuidedDecodingParams
(
json
=
sample_json_schema
,
backend
=
guided_decoding_backend
))
llm
=
LLM
(
model
=
model_name
,
max_model_len
=
1024
,
guided_decoding_backend
=
guided_decoding_backend
)
sampling_params
=
SamplingParams
(
temperature
=
1.0
,
max_tokens
=
1000
,
guided_decoding
=
GuidedDecodingParams
(
json
=
sample_json_schema
))
outputs
=
llm
.
generate
(
prompts
=
[
f
"Give an example JSON for an employee profile "
f
"that fits this schema:
{
sample_json_schema
}
"
...
...
@@ -111,13 +112,14 @@ def test_guided_json_object(
model_name
:
str
,
):
monkeypatch
.
setenv
(
"VLLM_USE_V1"
,
"1"
)
llm
=
LLM
(
model
=
model_name
,
max_model_len
=
1024
)
sampling_params
=
SamplingParams
(
temperature
=
1.0
,
max_tokens
=
100
,
n
=
2
,
guided_decoding
=
GuidedDecodingParams
(
json_object
=
True
,
backend
=
guided_decoding_backend
))
llm
=
LLM
(
model
=
model_name
,
max_model_len
=
1024
,
guided_decoding_backend
=
guided_decoding_backend
)
sampling_params
=
SamplingParams
(
temperature
=
1.0
,
max_tokens
=
100
,
n
=
2
,
guided_decoding
=
GuidedDecodingParams
(
json_object
=
True
))
outputs
=
llm
.
generate
(
prompts
=
(
"Generate a JSON object with curly braces for a person with "
...
...
@@ -137,12 +139,20 @@ def test_guided_json_object(
# Parse to verify it is valid JSON
parsed_json
=
json
.
loads
(
generated_text
)
assert
isinstance
(
parsed_json
,
dict
)
allowed_types
:
tuple
[
type
,
...]
=
(
dict
,
)
if
guided_decoding_backend
==
"xgrammar"
:
# TODO - we are currently too permissive with xgrammar and
# allow # any valid json (typically comes back as a list or
# object). We can fix this by specifying a jsonschema of
# {"type": "object"}, # but we need this fix in a release
# first: https://github.com/mlc-ai/xgrammar/pull/264
allowed_types
=
(
dict
,
list
)
assert
isinstance
(
parsed_json
,
allowed_types
)
@
pytest
.
mark
.
skip_global_cleanup
@
pytest
.
mark
.
parametrize
(
"guided_decoding_backend"
,
GUIDED_DECODING_BACKENDS_V1
)
GUIDED_DECODING_BACKENDS_V1
+
[
"auto"
]
)
@
pytest
.
mark
.
parametrize
(
"model_name"
,
MODELS_TO_TEST
)
def
test_guided_json_unsupported_schema
(
monkeypatch
:
pytest
.
MonkeyPatch
,
...
...
@@ -151,21 +161,43 @@ def test_guided_json_unsupported_schema(
model_name
:
str
,
):
monkeypatch
.
setenv
(
"VLLM_USE_V1"
,
"1"
)
llm
=
LLM
(
model
=
model_name
,
max_model_len
=
1024
)
sampling_params
=
SamplingParams
(
temperature
=
1.0
,
max_tokens
=
1000
,
guided_decoding
=
GuidedDecodingParams
(
json
=
unsupported_json_schema
,
backend
=
guided_decoding_backend
))
with
pytest
.
raises
(
ValueError
,
match
=
"The provided JSON schema contains features "
"not supported by xgrammar."
):
llm
.
generate
(
prompts
=
[
f
"Give an example JSON for an employee profile "
f
"that fits this schema:
{
unsupported_json_schema
}
"
]
*
2
,
sampling_params
=
sampling_params
,
use_tqdm
=
True
)
llm
=
LLM
(
model
=
model_name
,
max_model_len
=
1024
,
guided_decoding_backend
=
guided_decoding_backend
)
sampling_params
=
SamplingParams
(
temperature
=
1.0
,
max_tokens
=
1000
,
guided_decoding
=
GuidedDecodingParams
(
json
=
unsupported_json_schema
))
if
guided_decoding_backend
==
"xgrammar"
:
with
pytest
.
raises
(
ValueError
,
match
=
"The provided JSON schema contains features "
"not supported by xgrammar."
):
llm
.
generate
(
prompts
=
[
f
"Give an example JSON for an employee profile "
f
"that fits this schema:
{
unsupported_json_schema
}
"
]
*
2
,
sampling_params
=
sampling_params
,
use_tqdm
=
True
)
else
:
# This should work for both "guidance" and "auto".
outputs
=
llm
.
generate
(
prompts
=
(
"Give an example JSON object for a grade "
"that fits this schema: "
f
"
{
unsupported_json_schema
}
"
),
sampling_params
=
sampling_params
,
use_tqdm
=
True
)
assert
outputs
is
not
None
for
output
in
outputs
:
assert
output
is
not
None
assert
isinstance
(
output
,
RequestOutput
)
generated_text
=
output
.
outputs
[
0
].
text
assert
generated_text
is
not
None
print
(
generated_text
)
# Parse to verify it is valid JSON
parsed_json
=
json
.
loads
(
generated_text
)
assert
isinstance
(
parsed_json
,
dict
)
@
pytest
.
mark
.
skip_global_cleanup
...
...
@@ -179,13 +211,14 @@ def test_guided_grammar_ebnf(
model_name
:
str
,
):
monkeypatch
.
setenv
(
"VLLM_USE_V1"
,
"1"
)
llm
=
LLM
(
model
=
model_name
,
max_model_len
=
1024
)
sampling_params
=
SamplingParams
(
temperature
=
0.8
,
top_p
=
0.95
,
max_tokens
=
1000
,
guided_decoding
=
GuidedDecodingParams
(
grammar
=
sample_sql_ebnf
,
backend
=
guided_decoding_backend
))
llm
=
LLM
(
model
=
model_name
,
max_model_len
=
1024
,
guided_decoding_backend
=
guided_decoding_backend
)
sampling_params
=
SamplingParams
(
temperature
=
0.8
,
top_p
=
0.95
,
max_tokens
=
1000
,
guided_decoding
=
GuidedDecodingParams
(
grammar
=
sample_sql_ebnf
))
outputs
=
llm
.
generate
(
prompts
=
(
"Generate a sql statement that selects col_1 from "
"table_1 where it is equal to 1"
),
...
...
@@ -222,13 +255,14 @@ def test_guided_grammar_lark(
model_name
:
str
,
):
monkeypatch
.
setenv
(
"VLLM_USE_V1"
,
"1"
)
llm
=
LLM
(
model
=
model_name
,
max_model_len
=
1024
)
sampling_params
=
SamplingParams
(
temperature
=
0.8
,
top_p
=
0.95
,
max_tokens
=
1000
,
guided_decoding
=
GuidedDecodingParams
(
grammar
=
sample_sql_lark
,
backend
=
guided_decoding_backend
))
llm
=
LLM
(
model
=
model_name
,
max_model_len
=
1024
,
guided_decoding_backend
=
guided_decoding_backend
)
sampling_params
=
SamplingParams
(
temperature
=
0.8
,
top_p
=
0.95
,
max_tokens
=
1000
,
guided_decoding
=
GuidedDecodingParams
(
grammar
=
sample_sql_lark
))
outputs
=
llm
.
generate
(
prompts
=
(
"Generate a sql statement that selects col_1 from "
"table_1 where it is equal to 1"
),
...
...
@@ -269,16 +303,15 @@ def test_guided_grammar_ebnf_invalid(
model_name
:
str
,
):
monkeypatch
.
setenv
(
"VLLM_USE_V1"
,
"1"
)
llm
=
LLM
(
model
=
model_name
,
max_model_len
=
1024
)
sampling_params
=
SamplingParams
(
temperature
=
0.8
,
top_p
=
0.95
,
max_tokens
=
1000
,
guided_decoding
=
GuidedDecodingParams
(
grammar
=
"not a grammar"
,
backend
=
guided_decoding_backend
))
with
pytest
.
raises
(
ValueError
,
match
=
"Failed to convert the grammar "
"from Lark to EBNF."
):
llm
=
LLM
(
model
=
model_name
,
max_model_len
=
1024
,
guided_decoding_backend
=
guided_decoding_backend
)
sampling_params
=
SamplingParams
(
temperature
=
0.8
,
top_p
=
0.95
,
max_tokens
=
1000
,
guided_decoding
=
GuidedDecodingParams
(
grammar
=
"not a grammar"
))
with
pytest
.
raises
(
ValueError
,
match
=
"Failed to convert the grammar "
):
llm
.
generate
(
prompts
=
(
"Generate a sql statement that selects col_1 from "
"table_1 where it is equal to 1"
),
...
...
@@ -298,12 +331,13 @@ def test_guided_regex(
model_name
:
str
,
):
monkeypatch
.
setenv
(
"VLLM_USE_V1"
,
"1"
)
llm
=
LLM
(
model
=
model_name
,
max_model_len
=
1024
)
sampling_params
=
SamplingParams
(
temperature
=
0.8
,
top_p
=
0.95
,
guided_decoding
=
GuidedDecodingParams
(
regex
=
sample_regex
,
backend
=
guided_decoding_backend
))
llm
=
LLM
(
model
=
model_name
,
max_model_len
=
1024
,
guided_decoding_backend
=
guided_decoding_backend
)
sampling_params
=
SamplingParams
(
temperature
=
0.8
,
top_p
=
0.95
,
guided_decoding
=
GuidedDecodingParams
(
regex
=
sample_regex
))
outputs
=
llm
.
generate
(
prompts
=
[
f
"Give an example IPv4 address with this regex:
{
sample_regex
}
"
...
...
@@ -335,12 +369,13 @@ def test_guided_choice_completion(
model_name
:
str
,
):
monkeypatch
.
setenv
(
"VLLM_USE_V1"
,
"1"
)
llm
=
LLM
(
model
=
model_name
,
max_model_len
=
1024
)
sampling_params
=
SamplingParams
(
temperature
=
0.8
,
top_p
=
0.95
,
guided_decoding
=
GuidedDecodingParams
(
choice
=
sample_guided_choice
,
backend
=
guided_decoding_backend
))
llm
=
LLM
(
model
=
model_name
,
max_model_len
=
1024
,
guided_decoding_backend
=
guided_decoding_backend
)
sampling_params
=
SamplingParams
(
temperature
=
0.8
,
top_p
=
0.95
,
guided_decoding
=
GuidedDecodingParams
(
choice
=
sample_guided_choice
))
outputs
=
llm
.
generate
(
prompts
=
"The best language for type-safe systems programming is "
,
sampling_params
=
sampling_params
,
...
...
tests/v1/sample/test_rejection_sampler.py
View file @
31f6b24f
...
...
@@ -36,6 +36,8 @@ def create_logits_tensor(output_token_ids: list[list[int]],
def
create_sampling_metadata
(
all_greedy
:
bool
,
temperature
:
Optional
[
torch
.
Tensor
]
=
None
,
top_k
:
Optional
[
torch
.
Tensor
]
=
None
,
top_p
:
Optional
[
torch
.
Tensor
]
=
None
,
generators
:
Optional
[
dict
[
int
,
Any
]]
=
None
,
)
->
SamplingMetadata
:
"""Create a v1 sampling metadata object with all_greedy set
...
...
@@ -52,8 +54,8 @@ def create_sampling_metadata(
temperature
=
temperature
,
all_greedy
=
all_greedy
,
all_random
=
not
all_greedy
,
top_p
=
None
,
top_k
=
None
,
top_p
=
top_p
,
top_k
=
top_k
,
min_p
=
torch
.
empty
(
1
,
),
generators
=
generators
,
max_num_logprobs
=
0
,
...
...
@@ -462,3 +464,147 @@ def estimate_rejection_sampling_pdf(
density
=
True
)
return
hist
.
hist
def _test_masked_logits(
    rejection_sampler,
    batch_size: int,
    num_draft_tokens: int,
    vocab_size: int,
    target_logits: torch.Tensor,
    unmasked_indices: torch.Tensor,
    sampling_metadata: SamplingMetadata,
):
    """Run the rejection sampler once and check every output is permitted.

    Random draft probabilities and draft tokens are generated for
    ``batch_size * num_draft_tokens`` flat positions. The sampler is then
    called with the caller-supplied ``target_logits`` and
    ``sampling_metadata``, and every non-placeholder output token at flat
    position ``i`` must be contained in ``unmasked_indices[i]``.

    Args:
        rejection_sampler: The sampler under test; it is invoked as a
            callable.
        batch_size: Number of rows in the draft-token grid.
        num_draft_tokens: Number of draft tokens per row.
        vocab_size: Width of each draft-probability row.
        target_logits: Logits forwarded to the sampler unchanged.
        unmasked_indices: Indexable by flat position; each entry is the
            collection of token ids allowed at that position.
        sampling_metadata: Sampling configuration forwarded to the sampler.
    """
    total_positions = batch_size * num_draft_tokens

    # Random draft distribution: uniform noise normalised by a softmax.
    noise = torch.rand((total_positions, vocab_size),
                       dtype=torch.float32,
                       device=DEVICE)
    draft_probs = F.softmax(noise, dim=-1)

    # Sample one draft token per position and regroup them row by row as
    # plain nested lists.
    draft_token_ids = torch.multinomial(
        draft_probs, num_samples=1).reshape(batch_size,
                                            num_draft_tokens).tolist()

    # The sampler requires bonus tokens even though this check ignores them.
    bonus_token_ids = torch.zeros((batch_size, 1),
                                  dtype=torch.int64,
                                  device=DEVICE)

    spec_decode_metadata = SpecDecodeMetadata.make_dummy(draft_token_ids,
                                                         device=DEVICE)

    sampler_output = rejection_sampler(
        spec_decode_metadata,
        draft_probs=draft_probs,
        target_logits=target_logits,
        bonus_token_ids=bonus_token_ids,
        sampling_metadata=sampling_metadata,
    )

    # Drop the trailing bonus-token column and flatten to one id per position.
    sampled_tokens = sampler_output[:, :-1].flatten().tolist()

    # Placeholder entries are skipped; every other token must be unmasked.
    for position in range(total_positions):
        sampled = sampled_tokens[position]
        if sampled != PLACEHOLDER_TOKEN_ID:
            assert sampled in unmasked_indices[position]
@pytest.mark.parametrize("top_k", [1, 5, 99])
def test_top_k(rejection_sampler, top_k):
    """Test rejection sampling with top-k sampling.

    Each logits row gets ``top_k`` randomly chosen token ids whose logits are
    raised slightly above the rest; the sampler must only ever emit tokens
    from that chosen set.
    """
    vocab_size = 100
    batch_size = 100
    num_draft_tokens = 3
    num_tokens = batch_size * num_draft_tokens

    # One random set of top_k distinct token ids per logits row.
    top_k_indices = torch.stack([
        torch.randperm(vocab_size, device=DEVICE)[:top_k]
        for _ in range(num_tokens)
    ])

    # Start from all-equal logits (a uniform distribution).
    target_logits = torch.zeros((num_tokens, vocab_size), device=DEVICE)

    # Nudge the chosen ids just above the others. If the masking is
    # effective, the non-top-k ids are never sampled despite the small gap.
    for logits_row, chosen_ids in zip(target_logits, top_k_indices):
        logits_row[chosen_ids] += 0.1

    # Sampling metadata: temperature 1 and the same top_k for every entry.
    temperature = torch.ones(batch_size, dtype=torch.float32, device=DEVICE)
    top_k_tensor = torch.full((batch_size, ),
                              top_k,
                              dtype=torch.int64,
                              device=DEVICE)
    sampling_metadata = create_sampling_metadata(
        all_greedy=False,
        temperature=temperature,
        top_k=top_k_tensor,
    )

    _test_masked_logits(
        rejection_sampler,
        batch_size=batch_size,
        num_draft_tokens=num_draft_tokens,
        vocab_size=vocab_size,
        target_logits=target_logits,
        unmasked_indices=top_k_indices,
        sampling_metadata=sampling_metadata,
    )
@pytest.mark.parametrize("top_p", [0.5, 0.9, 0.99])
def test_top_p(rejection_sampler, top_p):
    """Test rejection sampling with top-p sampling.

    The expected top-p set of every logits row is computed here
    independently, and the sampler must only ever emit tokens from that set.
    """
    vocab_size = 100
    batch_size = 100
    num_draft_tokens = 3
    num_tokens = batch_size * num_draft_tokens

    # Draw the target logits from a standard normal distribution.
    target_logits = torch.randn((num_tokens, vocab_size), device=DEVICE)

    temperature = torch.ones(batch_size, dtype=torch.float32, device=DEVICE)
    # `temperature` has batch_size entries while the logits have one row per
    # draft token, so expand it to one value per row before dividing.
    # Dividing by the (batch_size,) tensor directly only broadcasts when
    # vocab_size happens to equal batch_size, and then scales along the vocab
    # axis. Assumes rows are grouped request-major, as the
    # (batch_size, num_draft_tokens) reshape in _test_masked_logits suggests.
    per_row_temperature = temperature.repeat_interleave(num_draft_tokens)
    rescaled_logits = target_logits / per_row_temperature.unsqueeze(-1)

    # Sort ascending and accumulate probability mass from the least likely
    # token upwards; entries whose cumulative mass stays within 1 - top_p
    # are the ones filtered out.
    logits_sort, logits_idx = rescaled_logits.sort(dim=-1, descending=False)
    probs_sort = logits_sort.softmax(dim=-1)
    probs_sum = probs_sort.cumsum(dim=-1)
    top_p_mask = probs_sum <= 1 - top_p
    # Always keep at least the most likely token of every row.
    top_p_mask[:, -1] = False

    # Token ids that survive the top-p filter, one list per logits row.
    top_p_indices = [
        idx_row[~mask_row].tolist()
        for idx_row, mask_row in zip(logits_idx, top_p_mask)
    ]

    # Create sampling metadata
    sampling_metadata = create_sampling_metadata(
        all_greedy=False,
        temperature=temperature,
        top_p=torch.tensor([top_p] * batch_size,
                           device=DEVICE,
                           dtype=torch.float32),
    )

    _test_masked_logits(
        rejection_sampler,
        batch_size=batch_size,
        num_draft_tokens=num_draft_tokens,
        vocab_size=vocab_size,
        target_logits=target_logits,
        unmasked_indices=top_p_indices,
        sampling_metadata=sampling_metadata,
    )
vllm/attention/backends/flash_attn.py
View file @
31f6b24f
...
...
@@ -22,12 +22,13 @@ from vllm.attention.backends.utils import (
compute_slot_mapping_start_idx
,
get_num_prefill_decode_query_kv_tokens
,
get_seq_len_block_table_args
,
is_all_cross_attn_metadata_set
,
is_all_encoder_attn_metadata_set
,
is_block_tables_empty
)
from
vllm.fa_utils
import
get_flash_attn_version
from
vllm.logger
import
init_logger
from
vllm.multimodal
import
MultiModalPlaceholderMap
from
vllm.utils
import
async_tensor_h2d
,
make_tensor_with_pad
from
vllm.vllm_flash_attn
import
(
flash_attn_varlen_func
,
flash_attn_with_kvcache
)
from
vllm.vllm_flash_attn.fa_utils
import
(
flash_attn_supports_fp8
,
get_flash_attn_version
)
if
TYPE_CHECKING
:
from
vllm.worker.model_runner
import
(
ModelInputForGPUBuilder
,
...
...
@@ -632,10 +633,13 @@ class FlashAttentionImpl(AttentionImpl):
self
.
kv_cache_dtype
=
kv_cache_dtype
self
.
vllm_flash_attn_version
=
get_flash_attn_version
(
requires_alibi
=
self
.
alibi_slopes
is
not
None
)
if
(
is_quantized_kv_cache
(
self
.
kv_cache_dtype
)
and
self
.
vllm_flash_attn_version
!=
3
):
if
is_quantized_kv_cache
(
self
.
kv_cache_dtype
)
and
(
not
self
.
kv_cache_dtype
.
startswith
(
"fp8"
)
or
not
flash_attn_supports_fp8
()):
raise
NotImplementedError
(
"Only FlashAttention3 supports FP8 KV cache"
)
f
"FlashAttention does not support
{
self
.
kv_cache_dtype
}
"
"kv-cache on this device "
f
"(FA supports fp8 =
{
flash_attn_supports_fp8
()
}
)."
)
if
logits_soft_cap
is
None
:
# In flash-attn, setting logits_soft_cap as 0 means no soft cap.
logits_soft_cap
=
0
...
...
@@ -704,6 +708,10 @@ class FlashAttentionImpl(AttentionImpl):
logits_soft_cap
:
Optional
[
float
]
=
self
.
logits_soft_cap
fp8_attention
=
kv_cache_dtype
.
startswith
(
"fp8"
)
if
fp8_attention
and
not
flash_attn_supports_fp8
():
raise
NotImplementedError
(
"FlashAttention does not support FP8 kv-cache on this device."
)
if
kv_cache
.
numel
()
>
0
:
key_cache
=
kv_cache
[
0
]
value_cache
=
kv_cache
[
1
]
...
...
vllm/attention/backends/mla/common.py
View file @
31f6b24f
...
...
@@ -206,7 +206,6 @@ from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping,
compute_slot_mapping_start_idx
,
is_block_tables_empty
)
from
vllm.attention.ops.triton_merge_attn_states
import
merge_attn_states
from
vllm.fa_utils
import
get_flash_attn_version
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
LinearBase
,
RowParallelLinear
,
UnquantizedLinearMethod
)
...
...
@@ -215,6 +214,7 @@ from vllm.model_executor.layers.rotary_embedding import (
from
vllm.multimodal
import
MultiModalPlaceholderMap
from
vllm.platforms
import
current_platform
from
vllm.utils
import
async_tensor_h2d
,
cdiv
,
make_tensor_with_pad
,
round_down
from
vllm.vllm_flash_attn.fa_utils
import
get_flash_attn_version
try
:
from
vllm.vllm_flash_attn
import
flash_attn_varlen_func
...
...
vllm/compilation/inductor_pass.py
View file @
31f6b24f
# SPDX-License-Identifier: Apache-2.0
import
hashlib
import
importlib.metadata
import
inspect
import
json
import
types
from
abc
import
ABC
,
abstractmethod
from
typing
import
Any
,
Callable
,
Optional
,
Union
from
typing
import
Any
,
Callable
,
Dict
,
Optional
,
Union
import
torch
from
packaging.version
import
Version
from
torch
import
fx
if
Version
(
importlib
.
metadata
.
version
(
'torch'
))
>=
Version
(
"2.6"
):
from
torch._inductor.custom_graph_pass
import
CustomGraphPass
else
:
# CustomGraphPass is not present in 2.5 or lower, import our version
from
.torch25_custom_graph_pass
import
(
# noqa: yapf
Torch25CustomGraphPass
as
CustomGraphPass
)
class
InductorPass
(
ABC
):
class
InductorPass
(
CustomGraphPass
):
"""
General custom inductor pass interface.
A custom graph pass that uses a hash of its source as the UUID.
This is defined as a convenience and should work in most cases.
"""
@
abstractmethod
def
__call__
(
self
,
graph
:
torch
.
fx
.
Graph
):
"""
Execute the pass on the given graph.
"""
raise
NotImplementedError
def
uuid
(
self
)
->
Any
:
"""
Provide a unique identifier for the pass, used in Inductor code cache.
...
...
@@ -48,7 +51,16 @@ class InductorPass(ABC):
else
:
src_str
=
inspect
.
getsource
(
src
.
__class__
)
hasher
.
update
(
src_str
.
encode
(
"utf-8"
))
return
hasher
.
digest
()
return
hasher
.
hexdigest
()
@
staticmethod
def
hash_dict
(
dict_
:
Dict
[
Any
,
Any
]):
"""
Utility method to hash a dictionary, can alternatively be used for uuid.
:return: A sha256 hash of the json rep of the dictionary.
"""
encoded
=
json
.
dumps
(
dict_
,
sort_keys
=
True
).
encode
(
"utf-8"
)
return
hashlib
.
sha256
(
encoded
).
hexdigest
()
class
CallableInductorPass
(
InductorPass
):
...
...
@@ -61,25 +73,10 @@ class CallableInductorPass(InductorPass):
callable
:
Callable
[[
fx
.
Graph
],
None
],
uuid
:
Optional
[
Any
]
=
None
):
self
.
callable
=
callable
if
uuid
is
None
:
uuid
=
InductorPass
.
hash_source
(
callable
)
self
.
_uuid
=
uuid
self
.
_uuid
=
self
.
hash_source
(
callable
)
if
uuid
is
None
else
uuid
def
__call__
(
self
,
graph
:
torch
.
fx
.
Graph
):
self
.
callable
(
graph
)
def
uuid
(
self
)
->
Any
:
return
self
.
_uuid
def
__getstate__
(
self
):
"""
Pickling occurs in the Inductor code cache if a pass is not given to
the pass manager but is instead directly added to config as a pass.
See PostGradPassManager for more.
TODO(torch==2.6), use the `uuid` method in CustomGraphPass instead.
"""
return
self
.
_uuid
def
__setstate__
(
self
,
state
):
raise
ValueError
(
"Cannot unpickle CallableInductorPass"
)
vllm/compilation/pass_manager.py
View file @
31f6b24f
# SPDX-License-Identifier: Apache-2.0
from
typing
import
Any
,
Dict
,
List
from
typing
import
List
import
torch
from
torch
import
fx
as
fx
from
vllm.config
import
CompilationConfig
...
...
@@ -10,29 +9,18 @@ from vllm.logger import init_logger
from
.fix_functionalization
import
FixFunctionalizationPass
from
.fusion
import
FusionPass
from
.inductor_pass
import
InductorPass
from
.inductor_pass
import
CustomGraphPass
,
InductorPass
from
.noop_elimination
import
NoOpEliminationPass
logger
=
init_logger
(
__name__
)
class
PlaceHolder
:
pass
if
torch
.
__version__
<
"2.6"
:
Parent
=
PlaceHolder
# type: ignore
else
:
Parent
=
torch
.
_inductor
.
custom_graph_pass
.
CustomGraphPass
# type: ignore
class
PostGradPassManager
(
Parent
):
class
PostGradPassManager
(
CustomGraphPass
):
"""
The pass manager for post-grad passes.
It handles configuration, adding custom passes, and running passes.
It also supports pickling, which is used by the Inductor code cache.
TODO(torch==2.6), use CustomGraphPass
(torch._inductor.custom_graph_pass.CustomGraphPass)
It supports uuid for the Inductor code cache. That includes torch<2.6
support using pickling (in .inductor_pass.CustomGraphPass).
The order of the post-grad post-passes is:
1. passes (constructor parameter)
...
...
@@ -67,27 +55,13 @@ class PostGradPassManager(Parent):
self
.
passes
.
append
(
pass_
)
def
uuid
(
self
):
return
self
.
__getstate__
()
def
__getstate__
(
self
)
->
Dict
[
str
,
List
[
Any
]]:
"""
Custom pickling for the pass manager, as some passes cannot be pickled.
Pickling occurs because the pass manager is set as the value of
`config["post_grad_custom_post_pass"]` in the Inductor config.
The config is pickled to act as a key in the Inductor code cache.
Any other passes in the config are pickled as well.
TODO(torch==2.6), use the `uuid` method in CustomGraphPass instead.
The PostGradPassManager is set as a custom pass in the Inductor and
affects compilation caching. Its uuid depends on the UUIDs of all
dependent passes and the pass config. See InductorPass for more info.
"""
state
=
{
"pass_config"
:
self
.
pass_config
.
uuid
(),
"passes"
:
[]}
for
pass_
in
self
.
passes
:
state
[
"passes"
].
append
(
pass_
.
uuid
())
state
[
"passes"
].
append
(
self
.
fix_functionalization
.
uuid
())
return
state
def
__setstate__
(
self
,
state
):
"""
Do not allow unpickling of the pass manager.
If this is needed in the future, it should properly pickle the passes.
"""
raise
ValueError
(
"Cannot unpickle PostGradPassManager"
)
return
InductorPass
.
hash_dict
(
state
)
Prev
1
2
3
4
5
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment