Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
53076d70
Commit
53076d70
authored
Mar 24, 2025
by
zhuwenwen
Browse files
Merge tag 'v0.8.2' into v0.8.2-ori
parents
322a0be6
9c5c81b0
Changes
219
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
424 additions
and
150 deletions
+424
-150
tests/entrypoints/llm/test_collective_rpc.py
tests/entrypoints/llm/test_collective_rpc.py
+2
-11
tests/entrypoints/llm/test_guided_generate.py
tests/entrypoints/llm/test_guided_generate.py
+3
-1
tests/entrypoints/openai/test_chat_with_tool_reasoning.py
tests/entrypoints/openai/test_chat_with_tool_reasoning.py
+145
-0
tests/kernels/test_attention_selector.py
tests/kernels/test_attention_selector.py
+4
-12
tests/kernels/test_flash_attn.py
tests/kernels/test_flash_attn.py
+68
-8
tests/kernels/test_rocm_attention_selector.py
tests/kernels/test_rocm_attention_selector.py
+1
-1
tests/lora/test_add_lora.py
tests/lora/test_add_lora.py
+8
-40
tests/lora/test_llama_tp.py
tests/lora/test_llama_tp.py
+8
-6
tests/lora/test_tokenizer_group.py
tests/lora/test_tokenizer_group.py
+2
-4
tests/model_executor/test_enabled_custom_ops.py
tests/model_executor/test_enabled_custom_ops.py
+28
-1
tests/model_executor/test_guided_processors.py
tests/model_executor/test_guided_processors.py
+3
-1
tests/models/decoder_only/language/test_models.py
tests/models/decoder_only/language/test_models.py
+40
-10
tests/models/decoder_only/vision_language/test_models.py
tests/models/decoder_only/vision_language/test_models.py
+13
-0
tests/models/decoder_only/vision_language/vlm_utils/custom_inputs.py
...s/decoder_only/vision_language/vlm_utils/custom_inputs.py
+18
-0
tests/models/encoder_decoder/vision_language/test_mllama.py
tests/models/encoder_decoder/vision_language/test_mllama.py
+0
-4
tests/models/registry.py
tests/models/registry.py
+2
-0
tests/plugins_tests/test_scheduler_plugins.py
tests/plugins_tests/test_scheduler_plugins.py
+1
-1
tests/spec_decode/e2e/conftest.py
tests/spec_decode/e2e/conftest.py
+1
-1
tests/spec_decode/e2e/test_compatibility.py
tests/spec_decode/e2e/test_compatibility.py
+20
-9
tests/spec_decode/e2e/test_eagle_correctness.py
tests/spec_decode/e2e/test_eagle_correctness.py
+57
-40
No files found.
tests/entrypoints/llm/test_collective_rpc.py
View file @
53076d70
...
...
@@ -21,18 +21,9 @@ def test_collective_rpc(tp_size, backend):
def
echo_rank
(
self
):
return
self
.
rank
from
vllm.worker.worker
import
Worker
class
MyWorker
(
Worker
):
def
echo_rank
(
self
):
return
self
.
rank
llm
=
LLM
(
model
=
"meta-llama/Llama-3.2-1B-Instruct"
,
enforce_eager
=
True
,
load_format
=
"dummy"
,
tensor_parallel_size
=
tp_size
,
distributed_executor_backend
=
backend
,
worker_cls
=
MyWorker
)
for
method
in
[
"echo_rank"
,
echo_rank
]:
assert
llm
.
collective_rpc
(
method
)
==
list
(
range
(
tp_size
))
distributed_executor_backend
=
backend
)
assert
llm
.
collective_rpc
(
echo_rank
)
==
list
(
range
(
tp_size
))
tests/entrypoints/llm/test_guided_generate.py
View file @
53076d70
...
...
@@ -14,7 +14,9 @@ from vllm.outputs import RequestOutput
from
vllm.sampling_params
import
GuidedDecodingParams
,
SamplingParams
MODEL_NAME
=
"Qwen/Qwen2.5-1.5B-Instruct"
GUIDED_DECODING_BACKENDS
=
[
"outlines"
,
"lm-format-enforcer"
,
"xgrammar"
]
GUIDED_DECODING_BACKENDS
=
[
"outlines"
,
"lm-format-enforcer"
,
"xgrammar"
,
"guidance"
]
@
pytest
.
fixture
(
scope
=
"module"
)
...
...
tests/entrypoints/openai/test_chat_with_tool_reasoning.py
0 → 100644
View file @
53076d70
# SPDX-License-Identifier: Apache-2.0
import
openai
# use the official client for correctness check
import
pytest
import
pytest_asyncio
from
...utils
import
RemoteOpenAIServer
# a reasoning and tool calling model
MODEL_NAME
=
"Qwen/QwQ-32B"
@
pytest
.
fixture
(
scope
=
"module"
)
def
server
():
# noqa: F811
args
=
[
"--max-model-len"
,
"8192"
,
"--enforce-eager"
,
"--enable-reasoning"
,
"--reasoning-parser"
,
"deepseek_r1"
,
"--enable-auto-tool-choice"
,
"--tool-call-parser"
,
"hermes"
]
with
RemoteOpenAIServer
(
MODEL_NAME
,
args
)
as
remote_server
:
yield
remote_server
@
pytest_asyncio
.
fixture
async
def
client
(
server
):
async
with
server
.
get_async_client
()
as
async_client
:
yield
async_client
TOOLS
=
[{
"type"
:
"function"
,
"function"
:
{
"name"
:
"get_current_weather"
,
"description"
:
"Get the current weather in a given location"
,
"parameters"
:
{
"type"
:
"object"
,
"properties"
:
{
"city"
:
{
"type"
:
"string"
,
"description"
:
"The city to find the weather for, e.g. 'San Francisco'"
},
"state"
:
{
"type"
:
"string"
,
"description"
:
"the two-letter abbreviation for the state that the city is"
" in, e.g. 'CA' which would mean 'California'"
},
"unit"
:
{
"type"
:
"string"
,
"description"
:
"The unit to fetch the temperature in"
,
"enum"
:
[
"celsius"
,
"fahrenheit"
]
}
},
"required"
:
[
"city"
,
"state"
,
"unit"
]
}
}
}]
MESSAGES
=
[{
"role"
:
"user"
,
"content"
:
"Hi! How are you doing today?"
},
{
"role"
:
"assistant"
,
"content"
:
"I'm doing well! How can I help you?"
},
{
"role"
:
"user"
,
"content"
:
"Can you tell me what the temperate will be in Dallas, in fahrenheit?"
}]
FUNC_NAME
=
"get_current_weather"
FUNC_ARGS
=
"""{"city": "Dallas", "state": "TX", "unit": "fahrenheit"}"""
def
extract_reasoning_and_calls
(
chunks
:
list
):
reasoning_content
=
""
tool_call_idx
=
-
1
arguments
=
[]
function_names
=
[]
for
chunk
in
chunks
:
if
chunk
.
choices
[
0
].
delta
.
tool_calls
:
tool_call
=
chunk
.
choices
[
0
].
delta
.
tool_calls
[
0
]
if
tool_call
.
index
!=
tool_call_idx
:
tool_call_idx
=
chunk
.
choices
[
0
].
delta
.
tool_calls
[
0
].
index
arguments
.
append
(
""
)
function_names
.
append
(
""
)
if
tool_call
.
function
:
if
tool_call
.
function
.
name
:
function_names
[
tool_call_idx
]
=
tool_call
.
function
.
name
if
tool_call
.
function
.
arguments
:
arguments
[
tool_call_idx
]
+=
tool_call
.
function
.
arguments
else
:
if
hasattr
(
chunk
.
choices
[
0
].
delta
,
"reasoning_content"
):
reasoning_content
+=
chunk
.
choices
[
0
].
delta
.
reasoning_content
return
reasoning_content
,
arguments
,
function_names
# test streaming
@
pytest
.
mark
.
asyncio
async
def
test_chat_streaming_of_tool_and_reasoning
(
client
:
openai
.
AsyncOpenAI
):
stream
=
await
client
.
chat
.
completions
.
create
(
model
=
MODEL_NAME
,
messages
=
MESSAGES
,
tools
=
TOOLS
,
temperature
=
0.0
,
stream
=
True
,
)
chunks
=
[]
async
for
chunk
in
stream
:
chunks
.
append
(
chunk
)
reasoning_content
,
arguments
,
function_names
=
extract_reasoning_and_calls
(
chunks
)
assert
len
(
reasoning_content
)
>
0
assert
len
(
function_names
)
>
0
and
function_names
[
0
]
==
FUNC_NAME
assert
len
(
arguments
)
>
0
and
arguments
[
0
]
==
FUNC_ARGS
# test full generate
@
pytest
.
mark
.
asyncio
async
def
test_chat_full_of_tool_and_reasoning
(
client
:
openai
.
AsyncOpenAI
):
tool_calls
=
await
client
.
chat
.
completions
.
create
(
model
=
MODEL_NAME
,
messages
=
MESSAGES
,
tools
=
TOOLS
,
temperature
=
0.0
,
stream
=
False
,
)
assert
len
(
tool_calls
.
choices
[
0
].
message
.
reasoning_content
)
>
0
assert
tool_calls
.
choices
[
0
].
message
.
tool_calls
[
0
].
function
.
name
\
==
FUNC_NAME
assert
tool_calls
.
choices
[
0
].
message
.
tool_calls
[
0
].
function
.
arguments
\
==
FUNC_ARGS
tests/kernels/test_attention_selector.py
View file @
53076d70
# SPDX-License-Identifier: Apache-2.0
from
unittest.mock
import
Mock
,
patch
from
unittest.mock
import
patch
import
pytest
import
torch
...
...
@@ -8,7 +8,6 @@ import torch
from
vllm.attention.selector
import
_cached_get_attn_backend
,
get_attn_backend
from
vllm.platforms.cpu
import
CpuPlatform
from
vllm.platforms.cuda
import
CudaPlatform
from
vllm.platforms.openvino
import
OpenVinoPlatform
from
vllm.platforms.rocm
import
RocmPlatform
from
vllm.utils
import
STR_BACKEND_ENV_VAR
,
STR_FLASH_ATTN_VAL
,
STR_INVALID_VAL
...
...
@@ -21,9 +20,9 @@ def clear_cache():
@
pytest
.
mark
.
parametrize
(
"name"
,
[
"TORCH_SDPA"
,
"ROCM_FLASH"
,
"XFORMERS"
,
"FLASHINFER"
,
"OPENVINO"
])
"name"
,
[
"TORCH_SDPA"
,
"ROCM_FLASH"
,
"XFORMERS"
,
"FLASHINFER"
])
@
pytest
.
mark
.
parametrize
(
"use_v1"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"device"
,
[
"cpu"
,
"openvino"
,
"hip"
,
"cuda"
])
@
pytest
.
mark
.
parametrize
(
"device"
,
[
"cpu"
,
"hip"
,
"cuda"
])
def
test_env
(
name
:
str
,
use_v1
:
bool
,
...
...
@@ -49,15 +48,8 @@ def test_env(
RocmPlatform
()):
backend
=
get_attn_backend
(
16
,
torch
.
float16
,
torch
.
float16
,
16
,
False
)
EXPECTED
=
"
ROCM
_ATTN_VLLM_V1"
if
use_v1
else
"ROCM_FLASH"
EXPECTED
=
"
TRITON
_ATTN_VLLM_V1"
if
use_v1
else
"ROCM_FLASH"
assert
backend
.
get_name
()
==
EXPECTED
elif
device
==
"openvino"
:
with
patch
(
"vllm.attention.selector.current_platform"
,
OpenVinoPlatform
()),
patch
.
dict
(
'sys.modules'
,
{
'openvino'
:
Mock
()}):
backend
=
get_attn_backend
(
16
,
torch
.
float16
,
torch
.
float16
,
16
,
False
)
assert
backend
.
get_name
()
==
"OPENVINO"
else
:
if
name
in
[
"XFORMERS"
,
"FLASHINFER"
]:
with
patch
(
"vllm.attention.selector.current_platform"
,
...
...
tests/kernels/test_flash_attn.py
View file @
53076d70
...
...
@@ -15,6 +15,7 @@ NUM_HEADS = [(4, 4), (8, 2), (16, 2)]
HEAD_SIZES
=
[
128
,
256
]
BLOCK_SIZES
=
[
16
,
32
]
DTYPES
=
[
torch
.
float16
,
torch
.
bfloat16
]
QDTYPES
=
[
None
,
torch
.
float8_e4m3fn
]
# one value large enough to test overflow in index calculation.
# one value small enough to test the schema op check
NUM_BLOCKS
=
[
32768
,
2048
]
...
...
@@ -85,6 +86,7 @@ def ref_paged_attn(
@
pytest
.
mark
.
parametrize
(
"num_blocks"
,
NUM_BLOCKS
)
@
pytest
.
mark
.
parametrize
(
"sliding_window"
,
[
None
,
256
])
@
pytest
.
mark
.
parametrize
(
"fa_version"
,
[
2
,
3
])
@
pytest
.
mark
.
parametrize
(
"q_dtype"
,
QDTYPES
)
@
torch
.
inference_mode
()
def
test_flash_attn_with_paged_kv
(
use_out
:
bool
,
...
...
@@ -97,11 +99,15 @@ def test_flash_attn_with_paged_kv(
num_blocks
:
int
,
sliding_window
:
Optional
[
int
],
fa_version
:
int
,
q_dtype
:
Optional
[
torch
.
dtype
],
)
->
None
:
torch
.
set_default_device
(
"cuda"
)
if
not
is_fa_version_supported
(
fa_version
):
pytest
.
skip
(
f
"Flash attention version
{
fa_version
}
not supported due "
f
"to:
\"
{
fa_version_unsupported_reason
(
fa_version
)
}
\"
"
)
if
q_dtype
is
not
None
and
(
dtype
!=
torch
.
bfloat16
or
fa_version
==
2
):
pytest
.
skip
(
"Flash attention with quantized inputs is only "
"supported on version 3 with bfloat16 base type"
)
current_platform
.
seed_everything
(
0
)
num_seqs
=
len
(
kv_lens
)
...
...
@@ -130,10 +136,28 @@ def test_flash_attn_with_paged_kv(
q
=
query
.
unsqueeze
(
1
)
out
=
torch
.
empty_like
(
q
)
if
use_out
else
None
maybe_quantized_query
=
q
maybe_quantized_key_cache
=
key_cache
maybe_quantized_value_cache
=
value_cache
q_descale
=
None
k_descale
=
None
v_descale
=
None
if
q_dtype
is
not
None
:
# QKV are drawn from N(0, 1): no need for a fp8 scaling factor
maybe_quantized_query
=
query
.
to
(
q_dtype
)
maybe_quantized_key_cache
=
key_cache
.
to
(
q_dtype
)
maybe_quantized_value_cache
=
value_cache
.
to
(
q_dtype
)
scale_shape
=
(
num_seqs
,
num_kv_heads
)
q_descale
=
torch
.
ones
(
scale_shape
,
dtype
=
torch
.
float32
)
k_descale
=
torch
.
ones
(
scale_shape
,
dtype
=
torch
.
float32
)
v_descale
=
torch
.
ones
(
scale_shape
,
dtype
=
torch
.
float32
)
output
=
flash_attn_with_kvcache
(
q
=
q
,
k_cache
=
key_cache
,
v_cache
=
value_cache
,
q
=
maybe_quantized_query
,
k_cache
=
maybe_quantized_
key_cache
,
v_cache
=
maybe_quantized_
value_cache
,
out
=
out
,
softmax_scale
=
scale
,
causal
=
True
,
...
...
@@ -142,10 +166,17 @@ def test_flash_attn_with_paged_kv(
softcap
=
soft_cap
if
soft_cap
is
not
None
else
0
,
window_size
=
window_size
,
fa_version
=
fa_version
,
q_descale
=
q_descale
,
k_descale
=
k_descale
,
v_descale
=
v_descale
,
)
output
=
output
if
not
use_out
else
out
output
=
output
.
squeeze
(
1
)
atol
,
rtol
=
1.5e-2
,
1e-2
if
q_dtype
is
not
None
:
atol
,
rtol
=
1.5e-1
,
1.5e-1
ref_output
=
ref_paged_attn
(
query
=
query
,
key_cache
=
key_cache
,
value_cache
=
value_cache
,
...
...
@@ -155,7 +186,7 @@ def test_flash_attn_with_paged_kv(
scale
=
scale
,
soft_cap
=
soft_cap
,
sliding_window
=
sliding_window
)
torch
.
testing
.
assert_close
(
output
,
ref_output
,
atol
=
2e-2
,
rtol
=
1e-2
),
\
torch
.
testing
.
assert_close
(
output
,
ref_output
,
atol
=
atol
,
rtol
=
rtol
),
\
f
"
{
torch
.
max
(
torch
.
abs
(
output
-
ref_output
))
}
"
...
...
@@ -171,6 +202,7 @@ def test_flash_attn_with_paged_kv(
@
pytest
.
mark
.
parametrize
(
"soft_cap"
,
[
None
,
10.0
,
50.0
])
@
pytest
.
mark
.
parametrize
(
"num_blocks"
,
NUM_BLOCKS
)
@
pytest
.
mark
.
parametrize
(
"fa_version"
,
[
2
,
3
])
@
pytest
.
mark
.
parametrize
(
"q_dtype"
,
QDTYPES
)
@
torch
.
inference_mode
()
def
test_varlen_with_paged_kv
(
use_out
:
bool
,
...
...
@@ -183,11 +215,15 @@ def test_varlen_with_paged_kv(
soft_cap
:
Optional
[
float
],
num_blocks
:
int
,
fa_version
:
int
,
q_dtype
:
Optional
[
torch
.
dtype
],
)
->
None
:
torch
.
set_default_device
(
"cuda"
)
if
not
is_fa_version_supported
(
fa_version
):
pytest
.
skip
(
f
"Flash attention version
{
fa_version
}
not supported due "
f
"to:
\"
{
fa_version_unsupported_reason
(
fa_version
)
}
\"
"
)
if
q_dtype
is
not
None
and
(
dtype
!=
torch
.
bfloat16
or
fa_version
==
2
):
pytest
.
skip
(
"Flash attention with quantized inputs is only "
"supported on version 3 with bfloat16 base type"
)
current_platform
.
seed_everything
(
0
)
num_seqs
=
len
(
seq_lens
)
query_lens
=
[
x
[
0
]
for
x
in
seq_lens
]
...
...
@@ -223,10 +259,28 @@ def test_varlen_with_paged_kv(
dtype
=
torch
.
int32
)
out
=
torch
.
empty_like
(
query
)
if
use_out
else
None
maybe_quantized_query
=
query
maybe_quantized_key_cache
=
key_cache
maybe_quantized_value_cache
=
value_cache
q_descale
=
None
k_descale
=
None
v_descale
=
None
if
q_dtype
is
not
None
:
# QKV are drawn from N(0, 1): no need for a fp8 scaling factor
maybe_quantized_query
=
query
.
to
(
q_dtype
)
maybe_quantized_key_cache
=
key_cache
.
to
(
q_dtype
)
maybe_quantized_value_cache
=
value_cache
.
to
(
q_dtype
)
scale_shape
=
(
num_seqs
,
num_kv_heads
)
q_descale
=
torch
.
ones
(
scale_shape
,
dtype
=
torch
.
float32
)
k_descale
=
torch
.
ones
(
scale_shape
,
dtype
=
torch
.
float32
)
v_descale
=
torch
.
ones
(
scale_shape
,
dtype
=
torch
.
float32
)
output
=
flash_attn_varlen_func
(
q
=
query
,
k
=
key_cache
,
v
=
value_cache
,
q
=
maybe_quantized_
query
,
k
=
maybe_quantized_
key_cache
,
v
=
maybe_quantized_
value_cache
,
out
=
out
,
cu_seqlens_q
=
cu_query_lens
,
seqused_k
=
kv_lens
,
...
...
@@ -238,6 +292,9 @@ def test_varlen_with_paged_kv(
block_table
=
block_tables
,
softcap
=
soft_cap
if
soft_cap
is
not
None
else
0
,
fa_version
=
fa_version
,
q_descale
=
q_descale
,
k_descale
=
k_descale
,
v_descale
=
v_descale
,
)
output
=
output
if
not
use_out
else
out
...
...
@@ -252,5 +309,8 @@ def test_varlen_with_paged_kv(
sliding_window
=
sliding_window
,
soft_cap
=
soft_cap
,
)
torch
.
testing
.
assert_close
(
output
,
ref_output
,
atol
=
2e-2
,
rtol
=
1e-2
),
\
atol
,
rtol
=
1.5e-2
,
1e-2
if
q_dtype
is
not
None
:
atol
,
rtol
=
1.5e-1
,
1.5e-1
torch
.
testing
.
assert_close
(
output
,
ref_output
,
atol
=
atol
,
rtol
=
rtol
),
\
f
"
{
torch
.
max
(
torch
.
abs
(
output
-
ref_output
))
}
"
tests/kernels/test_rocm_attention_selector.py
View file @
53076d70
...
...
@@ -26,7 +26,7 @@ def test_selector(monkeypatch: pytest.MonkeyPatch):
# Test standard ROCm attention
backend
=
get_attn_backend
(
16
,
torch
.
float16
,
torch
.
float16
,
16
,
False
)
assert
(
backend
.
get_name
()
==
"ROCM_FLASH"
or
backend
.
get_name
()
==
"
ROCM
_ATTN_VLLM_V1"
)
or
backend
.
get_name
()
==
"
TRITON
_ATTN_VLLM_V1"
)
# mla test for deepseek related
backend
=
get_attn_backend
(
576
,
torch
.
bfloat16
,
"auto"
,
16
,
False
,
...
...
tests/lora/test_add_lora.py
View file @
53076d70
# SPDX-License-Identifier: Apache-2.0
import
asyncio
import
time
from
pathlib
import
Path
import
pytest
from
huggingface_hub
import
snapshot_download
import
vllm.envs
as
env
from
vllm.engine.arg_utils
import
AsyncEngineArgs
...
...
@@ -13,35 +11,9 @@ from vllm.lora.request import LoRARequest
from
vllm.sampling_params
import
SamplingParams
from
vllm.utils
import
merge_async_iterators
MODEL_PATH
=
"meta-llama/Llama-2-7b-hf"
LORA_MODULE_DOWNLOAD_PATH
=
None
# Populated by download_and_prepare_lora_module() #noqa
LORA_RANK
=
8
DEFAULT_MAX_LORAS
=
16
*
3
def
download_and_prepare_lora_module
():
"""
Request submission is expensive when the LoRA adapters have their own
tokenizers. This is because, for each request with a new LoRA adapter ID,
the front-end loads the tokenizer from disk.
In this test, as we are comparing request processing times, we want to
minimize any extra activity. To this effect, we download the LoRA
adapter and remove all the tokenizer files, so the engine will default
to the base model tokenizer.
"""
global
LORA_MODULE_DOWNLOAD_PATH
LORA_MODULE_HF_PATH
=
"yard1/llama-2-7b-sql-lora-test"
LORA_MODULE_DOWNLOAD_PATH
=
snapshot_download
(
repo_id
=
LORA_MODULE_HF_PATH
)
tokenizer_files
=
[
'added_tokens.json'
,
'tokenizer_config.json'
,
'tokenizer.json'
,
'tokenizer.model'
]
for
tokenizer_file
in
tokenizer_files
:
del_path
=
Path
(
LORA_MODULE_DOWNLOAD_PATH
)
/
tokenizer_file
del_path
.
unlink
(
missing_ok
=
True
)
MODEL_PATH
=
"THUDM/chatglm3-6b"
LORA_RANK
=
64
DEFAULT_MAX_LORAS
=
4
*
3
@
pytest
.
fixture
(
autouse
=
True
)
...
...
@@ -52,11 +24,9 @@ def v1(run_with_both_engines_lora):
pass
def
get_lora_requests
()
->
list
[
LoRARequest
]:
def
get_lora_requests
(
lora_path
)
->
list
[
LoRARequest
]:
lora_requests
:
list
[
LoRARequest
]
=
[
LoRARequest
(
lora_name
=
f
"
{
i
}
"
,
lora_int_id
=
i
,
lora_path
=
LORA_MODULE_DOWNLOAD_PATH
)
LoRARequest
(
lora_name
=
f
"
{
i
}
"
,
lora_int_id
=
i
,
lora_path
=
lora_path
)
for
i
in
range
(
1
,
DEFAULT_MAX_LORAS
+
1
)
]
return
lora_requests
...
...
@@ -93,7 +63,7 @@ async def requests_processing_time(llm,
@
pytest
.
mark
.
asyncio
async
def
test_add_lora
():
async
def
test_add_lora
(
chatglm3_lora_files
):
"""
The add_lora function is used to pre-load some LoRA adapters into the
engine in anticipation of future requests using these adapters. To test
...
...
@@ -103,10 +73,7 @@ async def test_add_lora():
We measure the request processing time in both cases and expect the time
to be lesser in the case with add_lora() calls.
"""
download_and_prepare_lora_module
()
lora_requests
:
list
[
LoRARequest
]
=
get_lora_requests
()
lora_requests
:
list
[
LoRARequest
]
=
get_lora_requests
(
chatglm3_lora_files
)
max_loras
=
len
(
set
([
lr
.
lora_int_id
for
lr
in
lora_requests
]))
# Create engine in eager-mode. Due to high max_loras, the CI can
...
...
@@ -118,6 +85,7 @@ async def test_add_lora():
max_lora_rank
=
LORA_RANK
,
max_model_len
=
128
,
gpu_memory_utilization
=
0.8
,
#avoid OOM
trust_remote_code
=
True
,
enforce_eager
=
True
)
# The run_with_both_engines_lora fixture sets up the `VLLM_USE_V1`
...
...
tests/lora/test_llama_tp.py
View file @
53076d70
...
...
@@ -84,12 +84,14 @@ def v1(run_with_both_engines_lora):
@
create_new_process_for_each_test
()
def
test_llama_lora
(
sql_lora_files
):
llm
=
vllm
.
LLM
(
MODEL_PATH
,
enable_lora
=
True
,
max_num_seqs
=
16
,
max_loras
=
4
,
tensor_parallel_size
=
1
,
enable_chunked_prefill
=
True
)
llm
=
vllm
.
LLM
(
MODEL_PATH
,
enable_lora
=
True
,
# also test odd max_num_seqs
max_num_seqs
=
13
,
max_loras
=
4
,
tensor_parallel_size
=
1
,
enable_chunked_prefill
=
True
)
generate_and_test
(
llm
,
sql_lora_files
)
...
...
tests/lora/test_tokenizer_group.py
View file @
53076d70
...
...
@@ -24,12 +24,10 @@ async def test_tokenizer_group_lora(sql_lora_files, tokenizer_group_type):
)
lora_request
=
LoRARequest
(
"1"
,
1
,
sql_lora_files
)
assert
reference_tokenizer
.
encode
(
"prompt"
)
==
tokenizer_group
.
encode
(
request_id
=
"request_id"
,
prompt
=
"prompt"
,
lora_request
=
lora_request
)
prompt
=
"prompt"
,
lora_request
=
lora_request
)
assert
reference_tokenizer
.
encode
(
"prompt"
)
==
await
tokenizer_group
.
encode_async
(
request_id
=
"request_id"
,
prompt
=
"prompt"
,
lora_request
=
lora_request
)
prompt
=
"prompt"
,
lora_request
=
lora_request
)
assert
isinstance
(
tokenizer_group
.
get_lora_tokenizer
(
None
),
PreTrainedTokenizerBase
)
assert
tokenizer_group
.
get_lora_tokenizer
(
...
...
tests/model_executor/test_enabled_custom_ops.py
View file @
53076d70
...
...
@@ -7,7 +7,10 @@ from vllm.model_executor.custom_op import CustomOp
from
vllm.model_executor.layers.activation
import
(
GeluAndMul
,
ReLUSquaredActivation
,
SiluAndMul
)
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.layernorm
import
(
RMSNorm
,
dispatch_cuda_rmsnorm_func
,
fused_add_rms_norm
,
rms_norm
,
rocm_aiter_fused_add_rms_norm
,
rocm_aiter_rms_norm
)
from
vllm.platforms
import
current_platform
# Registered subclass for test
...
...
@@ -87,3 +90,27 @@ def test_enabled_ops_invalid(env: str):
custom_ops
=
env
.
split
(
","
)))
with
set_current_vllm_config
(
vllm_config
):
RMSNorm
(
1024
).
enabled
()
@
pytest
.
mark
.
parametrize
(
"add_residual"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"use_rocm_aiter"
,
[
"0"
,
"1"
])
@
pytest
.
mark
.
parametrize
(
"use_rocm_aiter_norm"
,
[
"0"
,
"1"
])
@
pytest
.
mark
.
skipif
(
not
current_platform
.
is_rocm
(),
reason
=
"AITER is a feature exclusive for ROCm"
)
def
test_rms_norm_dispatch
(
add_residual
:
bool
,
use_rocm_aiter
:
str
,
use_rocm_aiter_norm
:
str
,
monkeypatch
):
monkeypatch
.
setenv
(
"VLLM_ROCM_USE_AITER"
,
use_rocm_aiter
)
monkeypatch
.
setenv
(
"VLLM_ROCM_USE_AITER_RMSNORM"
,
use_rocm_aiter_norm
)
rms_norm_func
=
dispatch_cuda_rmsnorm_func
(
add_residual
)
if
not
add_residual
:
if
current_platform
.
is_rocm
()
and
int
(
use_rocm_aiter
)
and
int
(
use_rocm_aiter_norm
):
assert
rms_norm_func
==
rocm_aiter_rms_norm
else
:
assert
rms_norm_func
==
rms_norm
elif
current_platform
.
is_rocm
()
and
int
(
use_rocm_aiter
)
and
int
(
use_rocm_aiter_norm
):
assert
rms_norm_func
==
rocm_aiter_fused_add_rms_norm
else
:
assert
rms_norm_func
==
fused_add_rms_norm
tests/model_executor/test_guided_processors.py
View file @
53076d70
...
...
@@ -16,7 +16,9 @@ from vllm.model_executor.guided_decoding.outlines_logits_processors import (
from
vllm.sampling_params
import
GuidedDecodingParams
MODEL_NAME
=
'HuggingFaceH4/zephyr-7b-beta'
GUIDED_DECODING_BACKENDS
=
[
"outlines"
,
"lm-format-enforcer"
,
"xgrammar"
]
GUIDED_DECODING_BACKENDS
=
[
"outlines"
,
"lm-format-enforcer"
,
"xgrammar"
,
"guidance"
]
GUIDED_DECODING_BACKENDS_WITH_REASONING_SUPPORT
=
[
"outlines"
,
"xgrammar"
]
REASONING_MODEL_NAME
=
"deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
...
...
tests/models/decoder_only/language/test_models.py
View file @
53076d70
...
...
@@ -3,7 +3,11 @@
Run `pytest tests/models/test_models.py`.
"""
import
pytest
import
torch
from
vllm.platforms
import
current_platform
from
...utils
import
check_logprobs_close
...
...
@@ -13,7 +17,21 @@ from ...utils import check_logprobs_close
# https://github.com/vllm-project/vllm/issues/14524
REQUIRES_V0
=
[
"microsoft/phi-2"
,
"stabilityai/stablelm-3b-4e1t"
]
# This list contains the model that are using AITER kernel.
# Skip model that are not using AITER tests.
# When more AITER kernels are added, this list will not be
# needed as all the models will be calling AITER kernels
# in parts of the operators
AITER_MODEL_LIST
=
[
"meta-llama/Llama-3.2-1B-Instruct"
,
"openbmb/MiniCPM3-4B"
,
"Qwen/Qwen-7B"
,
"Qwen/Qwen2.5-0.5B-Instruct"
,
"ehristoforu/Falcon3-MoE-2x7B-Insruct"
,
]
# @maybe_test_rocm_aiter
@
pytest
.
mark
.
parametrize
(
"model"
,
[
...
...
@@ -69,19 +87,24 @@ REQUIRES_V0 = ["microsoft/phi-2", "stabilityai/stablelm-3b-4e1t"]
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"half"
])
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
32
])
@
pytest
.
mark
.
parametrize
(
"num_logprobs"
,
[
5
])
def
test_models
(
hf_runner
,
vllm_runner
,
example_prompts
,
model
:
str
,
dtype
:
str
,
max_tokens
:
int
,
num_logprobs
:
int
,
monkeypatch
,
)
->
None
:
@
pytest
.
mark
.
parametrize
(
"use_rocm_aiter"
,
[
True
,
False
]
if
current_platform
.
is_rocm
()
else
[
False
])
def
test_models
(
hf_runner
,
vllm_runner
,
example_prompts
,
model
:
str
,
dtype
:
str
,
max_tokens
:
int
,
num_logprobs
:
int
,
use_rocm_aiter
:
bool
,
monkeypatch
)
->
None
:
if
model
in
REQUIRES_V0
:
monkeypatch
.
setenv
(
"VLLM_USE_V1"
,
"0"
)
if
use_rocm_aiter
and
(
model
in
AITER_MODEL_LIST
):
monkeypatch
.
setenv
(
"VLLM_ROCM_USE_AITER"
,
"1"
)
elif
use_rocm_aiter
and
model
not
in
AITER_MODEL_LIST
:
# Skip model that are not using AITER tests.
# When more AITER kernels are added, this list will not be
# needed as all the models will be calling AITER kernels
# in parts of the operators
pytest
.
skip
(
f
"Skipping '
{
model
}
' model test with AITER kernel."
)
with
hf_runner
(
model
,
dtype
=
dtype
)
as
hf_model
:
if
model
.
startswith
(
"THUDM/chatglm3"
):
hf_model
.
model
.
get_output_embeddings
=
lambda
:
\
...
...
@@ -100,3 +123,10 @@ def test_models(
name_0
=
"hf"
,
name_1
=
"vllm"
,
)
if
use_rocm_aiter
:
# this is to ensure that vllm engine
# has deallocated the memory before running the next
# unit tests. On ROCm, when using AITER
# the memory might not be deallocated completely
# before running the next test case
torch
.
cuda
.
synchronize
()
tests/models/decoder_only/vision_language/test_models.py
View file @
53076d70
...
...
@@ -508,6 +508,19 @@ VLM_TEST_SETTINGS = {
limit_mm_per_prompt
=
{
"image"
:
4
},
)],
),
# regression test for https://github.com/vllm-project/vllm/issues/15122
"qwen2_5_vl-windows-attention"
:
VLMTestInfo
(
models
=
[
"Qwen/Qwen2.5-VL-3B-Instruct"
],
test_type
=
VLMTestType
.
CUSTOM_INPUTS
,
max_model_len
=
4096
,
max_num_seqs
=
2
,
auto_cls
=
AutoModelForVision2Seq
,
vllm_output_post_proc
=
model_utils
.
qwen2_vllm_to_hf_output
,
custom_test_opts
=
[
CustomTestOptions
(
inputs
=
custom_inputs
.
windows_attention_image_qwen2_5_vl
(),
limit_mm_per_prompt
=
{
"image"
:
1
},
)],
),
}
# yapf: enable
...
...
tests/models/decoder_only/vision_language/vlm_utils/custom_inputs.py
View file @
53076d70
# SPDX-License-Identifier: Apache-2.0
"""Custom input builders for edge-cases in different models."""
from
io
import
BytesIO
from
typing
import
Callable
import
requests
from
PIL
import
Image
from
vllm.multimodal.image
import
rescale_image_size
from
vllm.multimodal.video
import
(
rescale_video_size
,
resize_video
,
sample_frames_from_video
)
...
...
@@ -102,3 +106,17 @@ def different_patch_input_cases_internvl():
build_single_image_inputs
(
images
,
formatted_sprompts
,
wrapped_sf
),
build_multi_image_inputs
([
images
],
formatted_mprompts
,
wrapped_sf
),
]
def
windows_attention_image_qwen2_5_vl
():
# image from regression issue: https://github.com/vllm-project/vllm/issues/15122
image_url
=
"https://aomediacodec.github.io/av1-avif/testFiles/Link-U/hato.jpg"
image
=
Image
.
open
(
BytesIO
(
requests
.
get
(
image_url
).
content
))
question
=
"Describe the image."
img_prompt
=
"<|vision_start|><|image_pad|><|vision_end|>"
prompt
=
(
f
"<|im_start|>User
\n
{
img_prompt
}{
question
}
<|im_end|>
\n
"
"<|im_start|>assistant
\n
"
)
wrapped_sf
=
ImageSizeWrapper
(
type
=
SizeType
.
SIZE_FACTOR
,
data
=
[
0.5
])
return
build_single_image_inputs
([
image
],
[
prompt
],
wrapped_sf
)
tests/models/encoder_decoder/vision_language/test_mllama.py
View file @
53076d70
...
...
@@ -215,7 +215,6 @@ def _run_test(
max_num_seqs
=
2
,
tensor_parallel_size
=
tensor_parallel_size
,
distributed_executor_backend
=
distributed_executor_backend
,
enforce_eager
=
True
,
limit_mm_per_prompt
=
{
"image"
:
_LIMIT_IMAGE_PER_PROMPT
})
as
vllm_model
:
vllm_outputs_per_image
=
[
...
...
@@ -425,7 +424,6 @@ def test_bnb_regression(
dtype
=
dtype
,
max_model_len
=
4096
,
max_num_seqs
=
2
,
enforce_eager
=
True
,
quantization
=
"bitsandbytes"
,
load_format
=
"bitsandbytes"
,
)
...
...
@@ -481,7 +479,6 @@ def test_explicit_implicit_prompt(
max_model_len
=
4096
,
max_num_seqs
=
2
,
tensor_parallel_size
=
1
,
enforce_eager
=
True
,
)
sampling_params
=
SamplingParams
(
temperature
=
0
,
...
...
@@ -513,7 +510,6 @@ def test_regression(vllm_runner, image_assets, model, dtype, max_tokens,
max_model_len
=
4096
,
max_num_seqs
=
2
,
tensor_parallel_size
=
1
,
enforce_eager
=
True
,
limit_mm_per_prompt
=
{
"image"
:
_LIMIT_IMAGE_PER_PROMPT
})
as
vllm_model
:
...
...
tests/models/registry.py
View file @
53076d70
...
...
@@ -192,6 +192,8 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
"SolarForCausalLM"
:
_HfExamplesInfo
(
"upstage/solar-pro-preview-instruct"
),
"TeleChat2ForCausalLM"
:
_HfExamplesInfo
(
"Tele-AI/TeleChat2-3B"
,
trust_remote_code
=
True
),
"TeleFLMForCausalLM"
:
_HfExamplesInfo
(
"CofeAI/FLM-2-52B-Instruct-2407"
,
trust_remote_code
=
True
),
"XverseForCausalLM"
:
_HfExamplesInfo
(
"xverse/XVERSE-7B-Chat"
,
is_available_online
=
False
,
trust_remote_code
=
True
),
...
...
tests/plugins_tests/test_scheduler_plugins.py
View file @
53076d70
...
...
@@ -6,7 +6,7 @@ from vllm.core.scheduler import Scheduler
from
vllm.engine.arg_utils
import
EngineArgs
from
vllm.engine.llm_engine
import
LLMEngine
from
vllm.sampling_params
import
SamplingParams
from
vllm.v1.core.scheduler
import
Scheduler
as
V1Scheduler
from
vllm.v1.core.
sched.
scheduler
import
Scheduler
as
V1Scheduler
from
vllm.v1.engine.llm_engine
import
LLMEngine
as
V1LLMEngine
...
...
tests/spec_decode/e2e/conftest.py
View file @
53076d70
...
...
@@ -56,7 +56,7 @@ def test_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs,
def
maybe_assert_ngram_worker
(
llm
):
# Verify the proposer worker is ngram if ngram is specified.
if
(
llm
.
llm_engine
.
speculative_config
is
not
None
and
llm
.
llm_engine
.
speculative_config
.
ngram_prompt_lookup_max
>
0
):
and
llm
.
llm_engine
.
speculative_config
.
method
==
"ngram"
):
from
vllm.spec_decode.ngram_worker
import
NGramWorker
assert
isinstance
(
llm
.
llm_engine
.
model_executor
.
driver_worker
.
proposer_worker
,
...
...
tests/spec_decode/e2e/test_compatibility.py
View file @
53076d70
...
...
@@ -7,28 +7,39 @@ from vllm import SamplingParams
from
.conftest
import
get_output_from_llm_generator
@
pytest
.
mark
.
parametrize
(
"common_llm_kwargs"
,
[{
"model"
:
"meta-llama/Llama-3.2-1B-Instruct"
,
"speculative_model"
:
"JackFram/llama-68m"
,
"num_speculative_tokens"
:
5
,
}])
@
pytest
.
mark
.
parametrize
(
"common_llm_kwargs"
,
[{
"model"
:
"meta-llama/Llama-3.2-1B-Instruct"
,
}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[
{
# Speculative max model len > overridden max model len should raise.
"speculative_config"
:
{
"model"
:
"JackFram/llama-68m"
,
"num_speculative_tokens"
:
5
,
"max_model_len"
:
129
,
},
"max_model_len"
:
128
,
"speculative_max_model_len"
:
129
,
},
{
# Speculative max model len > draft max model len should raise.
# https://huggingface.co/JackFram/llama-68m/blob/3b606af5198a0b26762d589a3ee3d26ee6fa6c85/config.json#L12
"speculative_max_model_len"
:
2048
+
1
,
"speculative_config"
:
{
"model"
:
"JackFram/llama-68m"
,
"num_speculative_tokens"
:
5
,
"max_model_len"
:
2048
+
1
,
},
},
{
# Speculative max model len > target max model len should raise.
# https://huggingface.co/meta-llama/Llama-3.2-1B-Instruct/blob/9213176726f574b556790deb65791e0c5aa438b6/config.json#L18
"speculative_max_model_len"
:
131072
+
1
,
# https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct/blob/9213176726f574b556790deb65791e0c5aa438b6/config.json#L18
"speculative_config"
:
{
"model"
:
"JackFram/llama-68m"
,
"num_speculative_tokens"
:
5
,
"max_model_len"
:
131072
+
1
,
},
},
])
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[{}])
...
...
tests/spec_decode/e2e/test_eagle_correctness.py
View file @
53076d70
...
...
@@ -57,8 +57,10 @@ PRECISION = "float32"
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[
{
"speculative_model"
:
SPEC_MODEL
,
"num_speculative_tokens"
:
MAX_SPEC_TOKENS
,
"speculative_config"
:
{
"model"
:
SPEC_MODEL
,
"num_speculative_tokens"
:
MAX_SPEC_TOKENS
,
},
},
])
@
pytest
.
mark
.
parametrize
(
"output_len"
,
[
...
...
@@ -95,18 +97,19 @@ def test_eagle_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[
{
"
speculative_
model"
:
SPEC_MODEL
,
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[
{
"speculative_config"
:
{
"model"
:
SPEC_MODEL
,
"num_speculative_tokens"
:
MAX_SPEC_TOKENS
,
"disable_logprobs
_during_spec_decoding
"
:
False
,
"disable_logprobs"
:
False
,
},
{
"speculative_model"
:
SPEC_MODEL
,
},
{
"speculative_config"
:
{
"model"
:
SPEC_MODEL
,
"num_speculative_tokens"
:
MAX_SPEC_TOKENS
,
"disable_logprobs
_during_spec_decoding
"
:
True
,
"disable_logprobs"
:
True
,
},
])
}
])
@
pytest
.
mark
.
parametrize
(
"output_len"
,
[
128
,
])
...
...
@@ -119,18 +122,19 @@ def test_eagle_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
batch_size
:
int
,
output_len
:
int
,
seed
:
int
,
logprobs
:
int
):
run_equality_correctness_test
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
,
output_len
,
seed
,
logprobs
=
logprobs
,
prompt_logprobs
=
logprobs
,
disable_logprobs
=
test_llm_kwargs
[
'disable_logprobs_during_spec_decoding'
])
run_equality_correctness_test
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
,
output_len
,
seed
,
logprobs
=
logprobs
,
prompt_logprobs
=
logprobs
,
disable_logprobs
=
test_llm_kwargs
[
"speculative_config"
]
[
"disable_logprobs"
])
@
pytest
.
mark
.
parametrize
(
...
...
@@ -151,8 +155,10 @@ def test_eagle_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[
{
"speculative_model"
:
SPEC_MODEL
,
"num_speculative_tokens"
:
MAX_SPEC_TOKENS
,
"speculative_config"
:
{
"model"
:
SPEC_MODEL
,
"num_speculative_tokens"
:
MAX_SPEC_TOKENS
,
},
},
])
@
pytest
.
mark
.
parametrize
(
"output_len"
,
[
...
...
@@ -193,8 +199,10 @@ def test_eagle_e2e_greedy_correctness_cuda_graph(
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[
{
"speculative_model"
:
SPEC_MODEL
,
"num_speculative_tokens"
:
MAX_SPEC_TOKENS
,
"speculative_config"
:
{
"model"
:
SPEC_MODEL
,
"num_speculative_tokens"
:
MAX_SPEC_TOKENS
,
},
},
])
@
pytest
.
mark
.
parametrize
(
...
...
@@ -236,8 +244,10 @@ def test_eagle_e2e_greedy_correctness_with_preemption(
"test_llm_kwargs"
,
[
{
"speculative_model"
:
SPEC_MODEL
,
"num_speculative_tokens"
:
k
,
"speculative_config"
:
{
"model"
:
SPEC_MODEL
,
"num_speculative_tokens"
:
k
,
},
}
# Try a range of num. speculative tokens
for
k
in
range
(
1
,
1
+
MAX_SPEC_TOKENS
)
...
...
@@ -277,12 +287,13 @@ def test_eagle_different_k(vllm_runner, common_llm_kwargs,
}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[{
"speculative_model"
:
SPEC_MODEL
,
"num_speculative_tokens"
:
MAX_SPEC_TOKENS
,
"speculative_disable_by_batch_size"
:
4
}])
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[{
"speculative_config"
:
{
"model"
:
SPEC_MODEL
,
"num_speculative_tokens"
:
MAX_SPEC_TOKENS
,
"disable_by_batch_size"
:
4
,
},
}])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
1
,
5
])
@
pytest
.
mark
.
parametrize
(
"output_len"
,
...
...
@@ -324,8 +335,10 @@ def test_eagle_disable_queue(vllm_runner, common_llm_kwargs,
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[
{
"speculative_model"
:
"yuhuili/EAGLE-llama2-chat-7B"
,
"num_speculative_tokens"
:
MAX_SPEC_TOKENS
,
"speculative_config"
:
{
"model"
:
"yuhuili/EAGLE-llama2-chat-7B"
,
"num_speculative_tokens"
:
MAX_SPEC_TOKENS
,
},
},
])
@
pytest
.
mark
.
parametrize
(
...
...
@@ -372,8 +385,10 @@ def test_llama2_eagle_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[
{
"speculative_model"
:
"yuhuili/EAGLE-LLaMA3-Instruct-8B"
,
"num_speculative_tokens"
:
MAX_SPEC_TOKENS
,
"speculative_config"
:
{
"model"
:
"yuhuili/EAGLE-LLaMA3-Instruct-8B"
,
"num_speculative_tokens"
:
MAX_SPEC_TOKENS
,
},
},
])
@
pytest
.
mark
.
parametrize
(
...
...
@@ -420,8 +435,10 @@ def test_llama3_eagle_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[
{
"speculative_model"
:
"yuhuili/EAGLE-Qwen2-7B-Instruct"
,
"num_speculative_tokens"
:
MAX_SPEC_TOKENS
,
"speculative_config"
:
{
"model"
:
"yuhuili/EAGLE-Qwen2-7B-Instruct"
,
"num_speculative_tokens"
:
MAX_SPEC_TOKENS
,
},
},
])
@
pytest
.
mark
.
parametrize
(
...
...
Prev
1
2
3
4
5
6
7
8
…
11
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment