Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
711aa9d5
Commit
711aa9d5
authored
Jul 30, 2025
by
zhuwenwen
Browse files
Merge tag 'v0.10.0' into v0.10.0-dev
parents
751c492c
6d8d0a24
Changes
519
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1294 additions
and
74 deletions
+1294
-74
tests/multimodal/test_utils.py
tests/multimodal/test_utils.py
+1
-1
tests/multimodal/test_video.py
tests/multimodal/test_video.py
+80
-23
tests/multimodal/utils.py
tests/multimodal/utils.py
+46
-0
tests/neuron/2_core/test_mistral.py
tests/neuron/2_core/test_mistral.py
+0
-1
tests/neuron/2_core/test_multi_lora.py
tests/neuron/2_core/test_multi_lora.py
+0
-2
tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_gemma_embedding.py
...dd_dummy_model/vllm_add_dummy_model/my_gemma_embedding.py
+12
-16
tests/prefix_caching/test_disable_sliding_window.py
tests/prefix_caching/test_disable_sliding_window.py
+11
-11
tests/prefix_caching/test_prefix_caching.py
tests/prefix_caching/test_prefix_caching.py
+3
-3
tests/quantization/reference_mxfp4.py
tests/quantization/reference_mxfp4.py
+287
-0
tests/quantization/test_compressed_tensors.py
tests/quantization/test_compressed_tensors.py
+2
-1
tests/quantization/test_gptq_dynamic.py
tests/quantization/test_gptq_dynamic.py
+1
-1
tests/quantization/test_modelopt.py
tests/quantization/test_modelopt.py
+91
-0
tests/quantization/test_register_quantization_config.py
tests/quantization/test_register_quantization_config.py
+1
-1
tests/quantization/untest_quark.py
tests/quantization/untest_quark.py
+178
-6
tests/reasoning/test_hunyuan_reasoning_parser.py
tests/reasoning/test_hunyuan_reasoning_parser.py
+173
-0
tests/reasoning/test_mistral_reasoning_parser.py
tests/reasoning/test_mistral_reasoning_parser.py
+341
-0
tests/reasoning/utils.py
tests/reasoning/utils.py
+59
-0
tests/samplers/test_ignore_eos.py
tests/samplers/test_ignore_eos.py
+1
-1
tests/samplers/test_logits_processor.py
tests/samplers/test_logits_processor.py
+5
-5
tests/samplers/test_logprobs.py
tests/samplers/test_logprobs.py
+2
-2
No files found.
Too many changes to show.
To preserve performance only
519 of 519+
files are displayed.
Plain diff
Email patch
tests/multimodal/test_utils.py
View file @
711aa9d5
...
...
@@ -41,7 +41,7 @@ TEST_IMAGE_URLS = [
TEST_VIDEO_URLS
=
[
"https://www.bogotobogo.com/python/OpenCV_Python/images/mean_shift_tracking/slow_traffic_small.mp4"
,
"https://
filesamples.com/samples/video/avi/sample_640x360
.avi"
,
"https://
github.com/opencv/opencv/raw/refs/tags/4.12.0/samples/data/vtest
.avi"
,
]
...
...
tests/multimodal/test_video.py
View file @
711aa9d5
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
tempfile
from
pathlib
import
Path
import
numpy
as
np
import
numpy.typing
as
npt
import
pytest
from
PIL
import
Image
from
vllm
import
envs
from
vllm.assets.base
import
get_vllm_public_assets
from
vllm.assets.video
import
video_to_ndarrays
,
video_to_pil_images_list
from
vllm.multimodal.image
import
ImageMediaIO
from
vllm.multimodal.video
import
(
VIDEO_LOADER_REGISTRY
,
VideoLoader
,
VideoMediaIO
)
from
.utils
import
cosine_similarity
,
create_video_from_image
,
normalize_image
NUM_FRAMES
=
10
FAKE_OUTPUT_1
=
np
.
random
.
rand
(
NUM_FRAMES
,
1280
,
720
,
3
)
FAKE_OUTPUT_2
=
np
.
random
.
rand
(
NUM_FRAMES
,
1280
,
720
,
3
)
...
...
@@ -59,30 +67,79 @@ class Assert10Frames1FPSVideoLoader(VideoLoader):
return
FAKE_OUTPUT_2
def
test_video_media_io_kwargs
():
envs
.
VLLM_VIDEO_LOADER_BACKEND
=
"assert_10_frames_1_fps"
imageio
=
ImageMediaIO
()
def
test_video_media_io_kwargs
(
monkeypatch
:
pytest
.
MonkeyPatch
):
with
monkeypatch
.
context
()
as
m
:
m
.
setenv
(
"VLLM_VIDEO_LOADER_BACKEND"
,
"assert_10_frames_1_fps"
)
imageio
=
ImageMediaIO
()
# Verify that different args pass/fail assertions as expected.
videoio
=
VideoMediaIO
(
imageio
,
**
{
"num_frames"
:
10
,
"fps"
:
1.0
})
_
=
videoio
.
load_bytes
(
b
"test"
)
videoio
=
VideoMediaIO
(
imageio
,
**
{
"num_frames"
:
10
,
"fps"
:
1.0
,
"not_used"
:
"not_used"
})
_
=
videoio
.
load_bytes
(
b
"test"
)
with
pytest
.
raises
(
AssertionError
,
match
=
"bad num_frames"
):
videoio
=
VideoMediaIO
(
imageio
,
**
{})
# Verify that different args pass/fail assertions as expected.
videoio
=
VideoMediaIO
(
imageio
,
**
{
"num_frames"
:
10
,
"fps"
:
1.0
})
_
=
videoio
.
load_bytes
(
b
"test"
)
with
pytest
.
raises
(
AssertionError
,
match
=
"bad num_frames"
):
videoio
=
VideoMediaIO
(
imageio
,
**
{
"num_frames"
:
9
,
"fps"
:
1.0
})
videoio
=
VideoMediaIO
(
imageio
,
**
{
"num_frames"
:
10
,
"fps"
:
1.0
,
"not_used"
:
"not_used"
})
_
=
videoio
.
load_bytes
(
b
"test"
)
with
pytest
.
raises
(
AssertionError
,
match
=
"bad fps"
):
videoio
=
VideoMediaIO
(
imageio
,
**
{
"num_frames"
:
10
,
"fps"
:
2.0
})
_
=
videoio
.
load_bytes
(
b
"test"
)
with
pytest
.
raises
(
AssertionError
,
match
=
"bad num_frames"
):
videoio
=
VideoMediaIO
(
imageio
,
**
{})
_
=
videoio
.
load_bytes
(
b
"test"
)
with
pytest
.
raises
(
AssertionError
,
match
=
"bad num_frames"
):
videoio
=
VideoMediaIO
(
imageio
,
**
{
"num_frames"
:
9
,
"fps"
:
1.0
})
_
=
videoio
.
load_bytes
(
b
"test"
)
with
pytest
.
raises
(
AssertionError
,
match
=
"bad fps"
):
videoio
=
VideoMediaIO
(
imageio
,
**
{
"num_frames"
:
10
,
"fps"
:
2.0
})
_
=
videoio
.
load_bytes
(
b
"test"
)
@
pytest
.
mark
.
parametrize
(
"is_color"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"fourcc, ext"
,
[(
"mp4v"
,
"mp4"
),
(
"XVID"
,
"avi"
)])
def
test_opencv_video_io_colorspace
(
is_color
:
bool
,
fourcc
:
str
,
ext
:
str
):
"""
Test all functions that use OpenCV for video I/O return RGB format.
Both RGB and grayscale videos are tested.
"""
image_path
=
get_vllm_public_assets
(
filename
=
"stop_sign.jpg"
,
s3_prefix
=
"vision_model_images"
)
image
=
Image
.
open
(
image_path
)
with
tempfile
.
TemporaryDirectory
()
as
tmpdir
:
if
not
is_color
:
image_path
=
f
"
{
tmpdir
}
/test_grayscale_image.png"
image
=
image
.
convert
(
"L"
)
image
.
save
(
image_path
)
# Convert to gray RGB for comparison
image
=
image
.
convert
(
"RGB"
)
video_path
=
f
"
{
tmpdir
}
/test_RGB_video.
{
ext
}
"
create_video_from_image
(
image_path
,
video_path
,
num_frames
=
2
,
is_color
=
is_color
,
fourcc
=
fourcc
,
)
frames
=
video_to_ndarrays
(
video_path
)
for
frame
in
frames
:
sim
=
cosine_similarity
(
normalize_image
(
np
.
array
(
frame
)),
normalize_image
(
np
.
array
(
image
)))
assert
np
.
sum
(
np
.
isnan
(
sim
))
/
sim
.
size
<
0.001
assert
np
.
nanmean
(
sim
)
>
0.99
pil_frames
=
video_to_pil_images_list
(
video_path
)
for
frame
in
pil_frames
:
sim
=
cosine_similarity
(
normalize_image
(
np
.
array
(
frame
)),
normalize_image
(
np
.
array
(
image
)))
assert
np
.
sum
(
np
.
isnan
(
sim
))
/
sim
.
size
<
0.001
assert
np
.
nanmean
(
sim
)
>
0.99
io_frames
,
_
=
VideoMediaIO
(
ImageMediaIO
()).
load_file
(
Path
(
video_path
))
for
frame
in
io_frames
:
sim
=
cosine_similarity
(
normalize_image
(
np
.
array
(
frame
)),
normalize_image
(
np
.
array
(
image
)))
assert
np
.
sum
(
np
.
isnan
(
sim
))
/
sim
.
size
<
0.001
assert
np
.
nanmean
(
sim
)
>
0.99
tests/multimodal/utils.py
View file @
711aa9d5
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
cv2
import
numpy
as
np
import
numpy.typing
as
npt
from
PIL
import
Image
...
...
@@ -31,3 +33,47 @@ def random_audio(
):
audio_len
=
rng
.
randint
(
min_len
,
max_len
)
return
rng
.
rand
(
audio_len
),
sr
def
create_video_from_image
(
image_path
:
str
,
video_path
:
str
,
num_frames
:
int
=
10
,
fps
:
float
=
1.0
,
is_color
:
bool
=
True
,
fourcc
:
str
=
"mp4v"
,
):
image
=
cv2
.
imread
(
image_path
)
if
not
is_color
:
# Convert to grayscale if is_color is False
image
=
cv2
.
cvtColor
(
image
,
cv2
.
COLOR_BGR2GRAY
)
height
,
width
=
image
.
shape
else
:
height
,
width
,
_
=
image
.
shape
video_writer
=
cv2
.
VideoWriter
(
video_path
,
cv2
.
VideoWriter_fourcc
(
*
fourcc
),
fps
,
(
width
,
height
),
isColor
=
is_color
,
)
for
_
in
range
(
num_frames
):
video_writer
.
write
(
image
)
video_writer
.
release
()
return
video_path
def
cosine_similarity
(
A
:
npt
.
NDArray
,
B
:
npt
.
NDArray
,
axis
:
int
=
-
1
)
->
npt
.
NDArray
:
"""Compute cosine similarity between two vectors."""
return
(
np
.
sum
(
A
*
B
,
axis
=
axis
)
/
(
np
.
linalg
.
norm
(
A
,
axis
=
axis
)
*
np
.
linalg
.
norm
(
B
,
axis
=
axis
)))
def
normalize_image
(
image
:
npt
.
NDArray
)
->
npt
.
NDArray
:
"""Normalize image to [0, 1] range."""
return
image
.
astype
(
np
.
float32
)
/
255.0
\ No newline at end of file
tests/neuron/2_core/test_mistral.py
View file @
711aa9d5
...
...
@@ -9,7 +9,6 @@ def test_mistral():
tensor_parallel_size
=
2
,
max_num_seqs
=
4
,
max_model_len
=
128
,
use_v2_block_manager
=
True
,
override_neuron_config
=
{
"sequence_parallel_enabled"
:
False
,
"skip_warmup"
:
True
...
...
tests/neuron/2_core/test_multi_lora.py
View file @
711aa9d5
...
...
@@ -14,7 +14,6 @@ def test_llama_single_lora():
tensor_parallel_size
=
2
,
max_num_seqs
=
4
,
max_model_len
=
512
,
use_v2_block_manager
=
True
,
override_neuron_config
=
{
"sequence_parallel_enabled"
:
False
,
"skip_warmup"
:
True
,
...
...
@@ -57,7 +56,6 @@ def test_llama_multiple_lora():
tensor_parallel_size
=
2
,
max_num_seqs
=
4
,
max_model_len
=
512
,
use_v2_block_manager
=
True
,
override_neuron_config
=
{
"sequence_parallel_enabled"
:
False
,
...
...
tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_gemma_embedding.py
View file @
711aa9d5
...
...
@@ -8,14 +8,16 @@ import torch
import
torch.nn
as
nn
from
vllm.config
import
VllmConfig
from
vllm.model_executor.layers.pooler
import
Pooler
,
Pool
ingTyp
e
from
vllm.model_executor.layers.pooler
import
Dispatch
Pooler
,
Poole
r
from
vllm.model_executor.models.gemma2
import
Gemma2Model
from
vllm.model_executor.models.utils
import
WeightsMapper
,
maybe_prefix
from
vllm.model_executor.pooling_metadata
import
PoolingMetadata
from
vllm.sequence
import
IntermediateTensors
,
PoolerOutput
from
vllm.sequence
import
IntermediateTensors
class
MyGemma2Embedding
(
nn
.
Module
):
is_pooling_model
=
True
hf_to_vllm_mapper
=
WeightsMapper
(
orig_to_new_prefix
=
{
"model."
:
""
})
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
...
...
@@ -24,12 +26,13 @@ class MyGemma2Embedding(nn.Module):
self
.
model
=
Gemma2Model
(
vllm_config
=
vllm_config
,
prefix
=
maybe_prefix
(
prefix
,
"model"
))
self
.
_pooler
=
Pooler
.
from_config_with_defaults
(
vllm_config
.
model_config
.
pooler_config
,
pooling_type
=
PoolingType
.
LAST
,
normalize
=
True
,
softmax
=
False
,
)
pooler_config
=
vllm_config
.
model_config
.
pooler_config
assert
pooler_config
is
not
None
self
.
pooler
=
DispatchPooler
({
"encode"
:
Pooler
.
for_encode
(
pooler_config
),
"embed"
:
Pooler
.
for_embed
(
pooler_config
),
})
self
.
make_empty_intermediate_tensors
=
(
self
.
model
.
make_empty_intermediate_tensors
)
...
...
@@ -54,13 +57,6 @@ class MyGemma2Embedding(nn.Module):
# Return all-zero embeddings
return
torch
.
zeros_like
(
hidden_states
)
def
pooler
(
self
,
hidden_states
:
torch
.
Tensor
,
pooling_metadata
:
PoolingMetadata
,
)
->
Optional
[
PoolerOutput
]:
return
self
.
_pooler
(
hidden_states
,
pooling_metadata
)
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
torch
.
Tensor
]]):
weights
=
self
.
hf_to_vllm_mapper
.
apply
(
weights
)
...
...
tests/prefix_caching/test_disable_sliding_window.py
View file @
711aa9d5
...
...
@@ -27,25 +27,25 @@ MODEL_LEN_LEN = [
@
pytest
.
mark
.
parametrize
(
"model_len_len"
,
MODEL_LEN_LEN
)
def
test_disable_sliding_window
(
model_len_len
,
):
model
,
sliding_len
,
full_len
=
model_len_len
vllm_
disabled_
model
=
LLM
(
model
,
disable_sliding_window
=
True
)
vllm_
disabled_
model
.
generate
(
"Hi my name is"
)
model_config
=
vllm_
disabled_
model
.
llm_engine
.
model_config
disabled_
llm
=
LLM
(
model
,
disable_sliding_window
=
True
)
disabled_
llm
.
generate
(
"Hi my name is"
)
model_config
=
disabled_
llm
.
llm_engine
.
model_config
assert
model_config
.
max_model_len
==
sliding_len
,
(
"Max len expected to equal sliding_len of %s, but got %s"
,
sliding_len
,
model_config
.
max_model_len
)
del
vllm_
disabled_
model
del
disabled_
llm
cleanup_dist_env_and_memory
()
vllm_
enabled_
model
=
LLM
(
model
,
enforce_eager
=
True
,
disable_sliding_window
=
False
,
enable_prefix_caching
=
False
)
vllm_
enabled_
model
.
generate
(
"Hi my name is"
)
model_config
=
vllm_
enabled_
model
.
llm_engine
.
model_config
enabled_
llm
=
LLM
(
model
,
enforce_eager
=
True
,
disable_sliding_window
=
False
,
enable_prefix_caching
=
False
)
enabled_
llm
.
generate
(
"Hi my name is"
)
model_config
=
enabled_
llm
.
llm_engine
.
model_config
assert
model_config
.
max_model_len
==
full_len
,
(
"Max len expected to equal full_len of %s, but got %s"
,
full_len
,
model_config
.
max_model_len
)
del
vllm_
enabled_
model
del
enabled_
llm
cleanup_dist_env_and_memory
()
tests/prefix_caching/test_prefix_caching.py
View file @
711aa9d5
...
...
@@ -96,8 +96,8 @@ def test_mixed_requests(
# Run all the promopts
greedy_params
=
SamplingParams
(
temperature
=
0.0
,
max_tokens
=
max_tokens
)
req_outputs
=
vllm_model
.
model
.
generate
(
example_prompts
,
greedy_params
)
req_outputs
=
vllm_model
.
llm
.
generate
(
example_prompts
,
greedy_params
)
# Verify number of cached tokens
for
i
in
range
(
len
(
req_outputs
)):
...
...
@@ -164,7 +164,7 @@ def test_fully_cached_prefill_needs_uncached_token(model):
max_num_batched_tokens
=
max_num_batched_tokens
,
max_num_seqs
=
max_num_batched_tokens
,
)
engine
:
LLMEngine
=
runner
.
model
.
llm_engine
engine
:
LLMEngine
=
runner
.
llm
.
llm_engine
scheduler
:
Scheduler
=
SchedulerProxy
(
engine
.
scheduler
[
0
])
# type: ignore
engine
.
scheduler
[
0
]
=
scheduler
...
...
tests/quantization/reference_mxfp4.py
0 → 100644
View file @
711aa9d5
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
torch
BFLOAT16_EXP_BIAS
=
127
BFLOAT16_MANTISSA_BITS
=
7
BFLOAT16_EXP_BITS
=
8
FLOAT16_EXP_BIAS
=
15
FLOAT16_MANTISSA_BITS
=
10
FLOAT16_EXP_BITS
=
5
FLOAT8_E8M0_MAX_EXP
=
127
FLOAT4_EXP_BIAS
=
1
FLOAT4_MANTISSA_BITS
=
1
FLOAT16_VAL_TO_ADD
=
(
1
<<
(
FLOAT16_MANTISSA_BITS
-
FLOAT4_MANTISSA_BITS
-
1
))
FLOAT16_SIGN_EXPONENT_MASK
=
((
(
1
<<
(
FLOAT16_EXP_BITS
+
1
))
-
1
)
<<
FLOAT16_MANTISSA_BITS
)
BFLOAT16_VAL_TO_ADD
=
(
1
<<
(
BFLOAT16_MANTISSA_BITS
-
FLOAT4_MANTISSA_BITS
-
1
))
BFLOAT16_SIGN_EXPONENT_MASK
=
((
(
1
<<
(
BFLOAT16_EXP_BITS
+
1
))
-
1
)
<<
BFLOAT16_MANTISSA_BITS
)
def
e8m0_to_half
(
scale
,
half_dtype
:
torch
.
dtype
):
assert
scale
.
dtype
==
torch
.
uint8
scale_exp
=
scale
.
to
(
torch
.
int16
)
-
127
# This can be implemented with bitwise operations in a proper kernel.
scale_half
=
2.0
**
(
scale_exp
.
to
(
torch
.
float
))
return
scale_half
.
to
(
half_dtype
)
def
upcast_fp4_to_fp16_or_bf16
(
val
,
float_dtype
:
torch
.
dtype
,
half_exp_bias
:
int
,
half_mantissa_bits
:
int
):
assert
val
.
dtype
==
torch
.
uint8
unpacked
=
torch
.
zeros
(
*
val
.
shape
[:
-
1
],
val
.
shape
[
-
1
]
*
2
,
dtype
=
torch
.
uint8
,
device
=
val
.
device
)
unpacked
[...,
1
::
2
]
=
(
val
>>
4
)
&
0x0F
# Extract high 4 bits.
unpacked
[...,
::
2
]
=
val
&
0x0F
# Extract low 4 bits.
# Takes one float4 values represented as b0000xxxx,
# and converts it to the corresponding float16 value.
sign
=
unpacked
>>
3
exp
=
(
unpacked
>>
1
)
&
3
new_mantissa
=
unpacked
&
1
# if exp == 0 and new_mantissa == 0:
# new_exp = 0
# else:
# new_exp = exp - FLOAT4_EXP_BIAS + FLOAT16_EXP_BIAS
# int8_t works with float16, but may overflow with bfloat16.
new_exp
=
exp
-
FLOAT4_EXP_BIAS
+
half_exp_bias
# Cast b0000 to 0. in fp16/bf16.
new_exp
=
new_exp
*
torch
.
logical_or
(
exp
>
0
,
new_mantissa
>
0
)
# Cast b0001 to 0.5 in fp16/bf16.
new_mantissa
=
torch
.
logical_and
(
new_mantissa
,
exp
>
0
)
new_mantissa
=
new_mantissa
.
to
(
torch
.
int32
)
new_exp
=
new_exp
.
to
(
torch
.
int32
)
sign
=
sign
.
to
(
torch
.
int32
)
qdq_val
=
(
sign
<<
15
)
+
(
new_exp
<<
half_mantissa_bits
)
+
(
new_mantissa
<<
(
half_mantissa_bits
-
1
))
assert
qdq_val
.
max
()
<=
65535
assert
qdq_val
.
min
()
>=
0
qdq_val
=
qdq_val
.
to
(
torch
.
uint16
)
result
=
qdq_val
.
view
(
float_dtype
)
return
result
def
dq_mxfp4_torch
(
x
:
torch
.
Tensor
,
scale
:
torch
.
Tensor
,
float_dtype
:
torch
.
dtype
)
->
torch
.
Tensor
:
assert
x
.
dtype
==
torch
.
uint8
assert
scale
.
dtype
==
torch
.
uint8
if
float_dtype
==
torch
.
float16
:
half_exp_bias
=
FLOAT16_EXP_BIAS
half_mantissa_bits
=
FLOAT16_MANTISSA_BITS
elif
float_dtype
==
torch
.
bfloat16
:
half_exp_bias
=
BFLOAT16_EXP_BIAS
half_mantissa_bits
=
BFLOAT16_MANTISSA_BITS
scale_half
=
e8m0_to_half
(
scale
,
half_dtype
=
float_dtype
)
x_half
=
upcast_fp4_to_fp16_or_bf16
(
x
,
float_dtype
=
float_dtype
,
half_exp_bias
=
half_exp_bias
,
half_mantissa_bits
=
half_mantissa_bits
)
x_half
=
x_half
.
reshape
(
*
x_half
.
shape
[:
-
1
],
-
1
,
32
)
x_half
=
x_half
*
scale_half
[...,
None
]
x_half
=
x_half
.
reshape
(
*
x_half
.
shape
[:
-
2
],
-
1
)
return
x_half
def
fp16_to_fp4_simulate
(
val
,
half_mantissa_bits
:
int
,
half_exp_bits
:
int
,
half_exp_bias
:
int
):
# Casts an fp16/bf16 input to the restricted values of float4_e2m1,
# that is to say [0., 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0, -0.0,
# -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0].
float_type
=
val
.
dtype
# "rshift_cuda" not implemented for 'UInt16'
val_view
=
val
.
view
(
torch
.
int16
)
#.to(torch.int32)
exp
=
val_view
>>
half_mantissa_bits
exp
=
exp
&
((
1
<<
half_exp_bits
)
-
1
)
exp
=
exp
.
view
(
torch
.
uint16
).
to
(
torch
.
int32
)
sign
=
(
val_view
>>
(
half_mantissa_bits
+
half_exp_bits
))
&
1
mantissa_last
=
(
val_view
>>
(
half_mantissa_bits
-
1
))
&
1
exp_unbias
=
exp
-
half_exp_bias
new_exp
=
exp_unbias
+
FLOAT4_EXP_BIAS
exp_shift
=
(
new_exp
<=
0
)
*
(
1
-
new_exp
)
# Typically 9.
# Take the min to prevent overflow on `uint16_t half`. This is the case for
# very small values, correctly mapped to `round_close`.
tail_bits
=
half_mantissa_bits
-
FLOAT4_MANTISSA_BITS
+
exp_shift
tail_bits
[
tail_bits
>=
16
]
=
16
mantissa_plus_one
=
val_view
&
((
1
<<
(
half_mantissa_bits
+
1
))
-
1
)
half
=
1
<<
(
tail_bits
-
1
)
tail
=
mantissa_plus_one
&
((
1
<<
tail_bits
)
-
1
)
round_close
=
(
tail
<
half
)
# round towards 0
round_away
=
(
tail
>
half
)
# round away from 0
tie
=
tail
==
half
new_mantissa_close
=
torch
.
zeros
(
val
.
shape
,
device
=
val
.
device
,
dtype
=
torch
.
bool
)
new_exp_close
=
torch
.
zeros
(
val
.
shape
,
device
=
val
.
device
,
dtype
=
torch
.
uint16
)
new_mantissa_away
=
torch
.
zeros
(
val
.
shape
,
device
=
val
.
device
,
dtype
=
torch
.
bool
)
new_exp_away
=
torch
.
zeros
(
val
.
shape
,
device
=
val
.
device
,
dtype
=
torch
.
uint16
)
new_exp_tie
=
torch
.
zeros
(
val
.
shape
,
device
=
val
.
device
,
dtype
=
torch
.
uint16
)
# 1. round down
# if new_exp == 0: # case [0.5, 0.749999]
# new_mantissa = 0
# elif new_exp < 0: # case [0, 0.24999]
# new_mantissa = 0
# else:
# new_mantissa = mantissa_last
new_mantissa_close
=
(
new_exp
>
0
)
*
mantissa_last
new_exp_close
=
exp
# # 2. round up
# if new_exp <= 0: # case [0.250001, 0.499999] and [0.75001, 0.99999]
# new_mantissa = 0
# new_exp += 1
# elif mantissa_last == 0:
# new_mantissa = 1
# else:
# new_mantissa = 0
# new_exp += 1
new_mantissa_away
=
torch
.
logical_and
(
new_exp
>
0
,
mantissa_last
==
0
)
new_exp_away
=
exp
+
torch
.
logical_or
(
new_exp
<=
0
,
mantissa_last
==
1
)
# # 3. tie
# 0.25 -> 0. (handled by `exp > (half_exp_bias - 2)`)
# 0.75 -> 1.
# 1.25 -> 1.
# 1.75 -> 2.
# 2.5 -> 2.
# 3.5 -> 4.
# 5. -> 4.
new_exp_tie
=
(
exp
>
(
half_exp_bias
-
2
))
*
(
exp
+
(
mantissa_last
==
1
))
# Gather round up, round down and tie.
new_exp
=
round_away
*
new_exp_away
\
+
round_close
*
new_exp_close
\
+
tie
*
new_exp_tie
new_mantissa
=
round_away
*
new_mantissa_away
\
+
round_close
*
new_mantissa_close
# if new_exp > 3:
# new_mantissa = 1
new_mantissa
=
new_mantissa
+
(
new_exp
>
(
2
+
half_exp_bias
))
*
(
new_mantissa
==
0
)
# Clamp the exponent to acceptable values.
new_exp
=
(
new_exp
>=
(
half_exp_bias
-
2
))
*
torch
.
clamp
(
new_exp
,
half_exp_bias
-
2
,
half_exp_bias
+
2
)
sign
=
sign
.
to
(
torch
.
int32
)
new_mantissa
=
new_mantissa
.
to
(
torch
.
int32
)
qdq_val
=
(
sign
<<
15
)
+
(
new_exp
<<
half_mantissa_bits
)
+
(
new_mantissa
<<
(
half_mantissa_bits
-
1
))
assert
qdq_val
.
max
()
<=
65535
assert
qdq_val
.
min
()
>=
0
assert
qdq_val
.
dtype
==
torch
.
int32
qdq_val
=
qdq_val
.
to
(
torch
.
uint16
)
result
=
qdq_val
.
view
(
float_type
)
return
result
def
qdq_mxfp4_torch
(
x
:
torch
.
Tensor
,
scale_calculation_mode
:
str
=
"even"
)
->
torch
.
Tensor
:
half_dtype
=
x
.
dtype
if
half_dtype
==
torch
.
float16
:
half_mantissa_bits
=
FLOAT16_MANTISSA_BITS
half_exp_bits
=
FLOAT16_EXP_BITS
half_exp_bias
=
FLOAT16_EXP_BIAS
val_to_add
=
FLOAT16_VAL_TO_ADD
sign_exponent_mask
=
FLOAT16_SIGN_EXPONENT_MASK
elif
half_dtype
==
torch
.
bfloat16
:
half_mantissa_bits
=
BFLOAT16_MANTISSA_BITS
half_exp_bits
=
BFLOAT16_EXP_BITS
half_exp_bias
=
BFLOAT16_EXP_BIAS
val_to_add
=
BFLOAT16_VAL_TO_ADD
sign_exponent_mask
=
BFLOAT16_SIGN_EXPONENT_MASK
else
:
raise
ValueError
(
"not implemented"
)
x
=
x
.
reshape
(
*
x
.
shape
[:
-
1
],
-
1
,
32
)
block_max
=
torch
.
max
(
torch
.
abs
(
x
),
dim
=-
1
).
values
block_max
=
block_max
.
view
(
torch
.
uint16
).
to
(
torch
.
int32
)
block_max_uint
=
torch
.
bitwise_and
(
block_max
+
val_to_add
,
sign_exponent_mask
)
assert
block_max_uint
.
max
()
<=
65535
assert
block_max_uint
.
min
()
>=
0
assert
block_max_uint
.
dtype
==
torch
.
int32
block_max_uint
=
block_max_uint
.
to
(
torch
.
uint16
)
block_max
=
block_max_uint
.
view
(
half_dtype
)
scale_exp
=
FLOAT8_E8M0_MAX_EXP
+
torch
.
floor
(
torch
.
log2
(
block_max
)).
to
(
torch
.
int32
)
-
2
scale_exp
=
torch
.
clamp
(
scale_exp
,
0
,
2
*
FLOAT8_E8M0_MAX_EXP
)
scale
=
2.0
**
(
scale_exp
-
FLOAT8_E8M0_MAX_EXP
)
scale
=
scale
.
to
(
half_dtype
)
x
=
x
/
scale
[...,
None
]
x_fp4
=
fp16_to_fp4_simulate
(
x
,
half_exp_bits
=
half_exp_bits
,
half_mantissa_bits
=
half_mantissa_bits
,
half_exp_bias
=
half_exp_bias
)
x_fp4
=
x_fp4
*
scale
[...,
None
]
return
x_fp4
.
reshape
(
*
x_fp4
.
shape
[:
-
2
],
-
1
)
tests/quantization/test_compressed_tensors.py
View file @
711aa9d5
...
...
@@ -48,7 +48,8 @@ def use_v0_only(monkeypatch):
"""
This module relies on V0 internals, so set VLLM_USE_V1=0.
"""
monkeypatch
.
setenv
(
'VLLM_USE_V1'
,
'0'
)
if
not
current_platform
.
is_cpu
():
monkeypatch
.
setenv
(
'VLLM_USE_V1'
,
'0'
)
@
pytest
.
mark
.
parametrize
(
...
...
tests/quantization/test_gptq_dynamic.py
View file @
711aa9d5
...
...
@@ -41,7 +41,7 @@ def test_gptq_with_dynamic(vllm_runner, model_id: str, use_marlin_kernel: bool,
linear_method_cls
=
GPTQMarlinLinearMethod
if
use_marlin_kernel
else
(
GPTQLinearMethod
)
for
name
,
submodule
in
(
vllm_model
.
model
.
llm_engine
.
model_executor
.
for
name
,
submodule
in
(
vllm_model
.
llm
.
llm_engine
.
model_executor
.
driver_worker
.
model_runner
.
model
.
named_modules
()):
if
name
==
"lm_head"
:
assert
isinstance
(
submodule
.
quant_method
,
linear_method_cls
)
...
...
tests/quantization/test_modelopt.py
0 → 100644
View file @
711aa9d5
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Test ModelOpt quantization method setup and weight loading.
Run `pytest tests/quantization/test_modelopt.py`.
"""
import
os
import
pytest
import
torch
from
tests.quantization.utils
import
is_quant_method_supported
from
vllm.platforms
import
current_platform
@
pytest
.
fixture
(
scope
=
"function"
,
autouse
=
True
)
def
use_v0_only
(
monkeypatch
):
"""
This module relies on V0 internals, so set VLLM_USE_V1=0.
"""
if
not
current_platform
.
is_cpu
():
monkeypatch
.
setenv
(
'VLLM_USE_V1'
,
'0'
)
@
pytest
.
mark
.
skipif
(
not
is_quant_method_supported
(
"modelopt"
),
reason
=
"ModelOpt FP8 is not supported on this GPU type."
)
def
test_modelopt_fp8_checkpoint_setup
(
vllm_runner
):
"""Test ModelOpt FP8 checkpoint loading and structure validation."""
# TODO: provide a small publically available test checkpoint
model_path
=
(
"/home/scratch.omniml_data_1/zhiyu/ckpts/test_ckpts/"
"TinyLlama-1.1B-Chat-v1.0-fp8-0710"
)
# Skip test if checkpoint doesn't exist
if
not
os
.
path
.
exists
(
model_path
):
pytest
.
skip
(
f
"Test checkpoint not found at
{
model_path
}
. "
"This test requires a local ModelOpt FP8 checkpoint."
)
with
vllm_runner
(
model_path
,
quantization
=
"modelopt"
,
enforce_eager
=
True
)
as
llm
:
def
check_model
(
model
):
layer
=
model
.
model
.
layers
[
0
]
qkv_proj
=
layer
.
self_attn
.
qkv_proj
o_proj
=
layer
.
self_attn
.
o_proj
gate_up_proj
=
layer
.
mlp
.
gate_up_proj
down_proj
=
layer
.
mlp
.
down_proj
# Check that ModelOpt quantization method is properly applied
from
vllm.model_executor.layers.quantization.modelopt
import
(
ModelOptFp8LinearMethod
)
assert
isinstance
(
qkv_proj
.
quant_method
,
ModelOptFp8LinearMethod
)
assert
isinstance
(
o_proj
.
quant_method
,
ModelOptFp8LinearMethod
)
assert
isinstance
(
gate_up_proj
.
quant_method
,
ModelOptFp8LinearMethod
)
assert
isinstance
(
down_proj
.
quant_method
,
ModelOptFp8LinearMethod
)
# Check weight dtype is FP8
assert
qkv_proj
.
weight
.
dtype
==
torch
.
float8_e4m3fn
assert
o_proj
.
weight
.
dtype
==
torch
.
float8_e4m3fn
assert
gate_up_proj
.
weight
.
dtype
==
torch
.
float8_e4m3fn
assert
down_proj
.
weight
.
dtype
==
torch
.
float8_e4m3fn
# Check scales are present and have correct dtype
assert
hasattr
(
qkv_proj
,
'weight_scale'
)
assert
hasattr
(
qkv_proj
,
'input_scale'
)
assert
qkv_proj
.
weight_scale
.
dtype
==
torch
.
float32
assert
qkv_proj
.
input_scale
.
dtype
==
torch
.
float32
assert
hasattr
(
o_proj
,
'weight_scale'
)
assert
hasattr
(
o_proj
,
'input_scale'
)
assert
o_proj
.
weight_scale
.
dtype
==
torch
.
float32
assert
o_proj
.
input_scale
.
dtype
==
torch
.
float32
assert
hasattr
(
gate_up_proj
,
'weight_scale'
)
assert
hasattr
(
gate_up_proj
,
'input_scale'
)
assert
gate_up_proj
.
weight_scale
.
dtype
==
torch
.
float32
assert
gate_up_proj
.
input_scale
.
dtype
==
torch
.
float32
assert
hasattr
(
down_proj
,
'weight_scale'
)
assert
hasattr
(
down_proj
,
'input_scale'
)
assert
down_proj
.
weight_scale
.
dtype
==
torch
.
float32
assert
down_proj
.
input_scale
.
dtype
==
torch
.
float32
llm
.
apply_model
(
check_model
)
# Run a simple generation test to ensure the model works
output
=
llm
.
generate_greedy
([
"Hello my name is"
],
max_tokens
=
20
)
assert
output
print
(
f
"ModelOpt FP8 output:
{
output
}
"
)
tests/quantization/test_register_quantization_config.py
View file @
711aa9d5
...
...
@@ -113,7 +113,7 @@ def test_custom_quant(vllm_runner, model, monkeypatch):
quantization
=
"custom_quant"
,
enforce_eager
=
True
)
as
llm
:
model
=
llm
.
model
.
llm_engine
.
model_executor
.
driver_worker
.
model_runner
.
model
# noqa: E501
model
=
llm
.
llm
.
llm_engine
.
model_executor
.
driver_worker
.
model_runner
.
model
# noqa: E501
layer
=
model
.
model
.
layers
[
0
]
qkv_proj
=
layer
.
self_attn
.
qkv_proj
...
...
tests/quantization/untest_quark.py
View file @
711aa9d5
...
...
@@ -3,15 +3,45 @@
"""Test model set-up and weight loading for quark-quantized models.
Run `pytest tests/quantization/test_quark.py`.
See also `tests/kernels/moe/test_mxfp4_moe.py`.
"""
import
importlib
import
importlib.metadata
import
os
from
dataclasses
import
dataclass
import
huggingface_hub
import
lm_eval
import
pytest
import
torch
from
packaging
import
version
from
vllm.model_executor.layers.quantization.quark.quark
import
(
# noqa: E501
QuarkLinearMethod
,
QuarkW8A8Fp8
,
QuarkW8A8Int8
)
from
vllm.platforms
import
current_platform
from
.reference_mxfp4
import
dq_mxfp4_torch
,
qdq_mxfp4_torch
from
..utils
import
models_path_prefix
QUARK_MXFP4_AVAILABLE
=
importlib
.
util
.
find_spec
(
"quark"
)
is
not
None
and
version
.
parse
(
importlib
.
metadata
.
version
(
"amd-quark"
))
>=
version
.
parse
(
'0.8.99'
)
if
QUARK_MXFP4_AVAILABLE
:
from
quark.torch.export.nn.modules.realquantizer
import
(
StaticScaledRealQuantizer
)
from
quark.torch.kernel
import
mx
as
mx_kernel
from
quark.torch.quantization.config.config
import
FP4PerGroupSpec
try
:
huggingface_hub
.
list_repo_refs
(
"amd/Llama-3.3-70B-Instruct-WMXFP4-AMXFP4-KVFP8-Scale-UINT8-SQ"
)
HF_HUB_AMD_ORG_ACCESS
=
True
except
huggingface_hub
.
errors
.
RepositoryNotFoundError
:
HF_HUB_AMD_ORG_ACCESS
=
False
@
pytest
.
fixture
(
scope
=
"function"
,
autouse
=
True
)
def
use_v0_only
(
monkeypatch
):
...
...
@@ -24,7 +54,7 @@ def use_v0_only(monkeypatch):
@
pytest
.
mark
.
parametrize
(
'kv_cache_dtype'
,
[
'auto'
,
'fp8'
])
@
pytest
.
mark
.
parametrize
(
'tp'
,
[
1
])
def
test_quark_fp8_w_per_tensor_a_per_tensor
(
vllm_runner
,
kv_cache_dtype
,
tp
):
model_path
=
"amd/Llama-3.1-8B-Instruct-FP8-KV-Quark-test"
model_path
=
os
.
path
.
join
(
models_path_prefix
,
"amd/Llama-3.1-8B-Instruct-FP8-KV-Quark-test"
)
with
vllm_runner
(
model_path
,
kv_cache_dtype
=
kv_cache_dtype
,
tensor_parallel_size
=
tp
)
as
llm
:
...
...
@@ -68,8 +98,8 @@ def test_quark_int8_w_per_tensor_a_per_tensor(vllm_runner, tp):
def
test_quark_fp8_parity
(
vllm_runner
):
quark_model_id
=
"amd-quark/llama-tiny-fp8-quark-quant-method"
fp8_model_id
=
"amd-quark/llama-tiny-fp8-quant-method"
quark_model_id
=
os
.
path
.
join
(
models_path_prefix
,
"amd-quark/llama-tiny-fp8-quark-quant-method"
)
fp8_model_id
=
os
.
path
.
join
(
models_path_prefix
,
"amd-quark/llama-tiny-fp8-quant-method"
)
llm_kwargs
=
{
"tensor_parallel_size"
:
1
,
...
...
@@ -78,15 +108,157 @@ def test_quark_fp8_parity(vllm_runner):
}
with
(
vllm_runner
(
quark_model_id
,
**
llm_kwargs
)
as
quark_handle
,
vllm_runner
(
fp8_model_id
,
**
llm_kwargs
)
as
fp8_handle
):
quark_model
=
(
quark_handle
.
model
.
llm_engine
.
model_executor
.
quark_model
=
(
quark_handle
.
llm
.
llm_engine
.
model_executor
.
driver_worker
.
model_runner
.
model
)
quark_state_dict
=
quark_model
.
state_dict
()
fp8_model
=
(
fp8_handle
.
model
.
llm_engine
.
model_executor
.
driver_worker
.
fp8_model
=
(
fp8_handle
.
llm
.
llm_engine
.
model_executor
.
driver_worker
.
model_runner
.
model
)
fp8_state_dict
=
fp8_model
.
state_dict
()
assert
fp8_state_dict
.
keys
()
==
quark_state_dict
.
keys
()
for
key
in
fp8_state_dict
:
assert
torch
.
equal
(
fp8_state_dict
[
key
],
quark_state_dict
[
key
])
\ No newline at end of file
assert
torch
.
equal
(
fp8_state_dict
[
key
],
quark_state_dict
[
key
])
@
dataclass
class
ModelCase
:
model_id
:
str
tp
:
int
@
dataclass
class
GSM8KAccuracyTestConfig
:
model_name
:
str
excepted_value
:
float
def
get_model_args
(
self
)
->
str
:
return
(
f
"pretrained=
{
self
.
model_name
}
,"
"dtype=auto,add_bos_token=True,tensor_parallel_size=8,gpu_memory_utilization=0.7,max_model_len=38768"
)
ACCURACY_CONFIGS
=
[
# Private model.
GSM8KAccuracyTestConfig
(
model_name
=
os
.
path
.
join
(
models_path_prefix
,
"amd/DeepSeek-R1-WMXFP4-AMXFP4-Scale-UINT8-MoE-Quant"
),
excepted_value
=
0.96
),
]
@
pytest
.
mark
.
parametrize
(
"config"
,
ACCURACY_CONFIGS
)
@
pytest
.
mark
.
skipif
(
not
QUARK_MXFP4_AVAILABLE
,
reason
=
"amd-quark>=0.9 is not available"
)
@
pytest
.
mark
.
skipif
(
not
HF_HUB_AMD_ORG_ACCESS
,
reason
=
"Read access to huggingface.co/amd is required for this test."
)
def
test_mxfp4_gsm8k_correctness
(
config
:
GSM8KAccuracyTestConfig
):
if
torch
.
cuda
.
device_count
()
<
8
:
pytest
.
skip
(
f
"This test requires >=8 gpus, got only
{
torch
.
cuda
.
device_count
()
}
"
)
task
=
"gsm8k"
rtol
=
0.03
os
.
environ
[
"VLLM_USE_TRITON_FLASH_ATTN"
]
=
"0"
results
=
lm_eval
.
simple_evaluate
(
model
=
"vllm"
,
model_args
=
config
.
get_model_args
(),
tasks
=
task
,
batch_size
=
64
,
num_fewshot
=
8
,
)
EXPECTED_VALUE
=
config
.
excepted_value
measured_value
=
results
[
"results"
][
task
][
"exact_match,strict-match"
]
assert
(
measured_value
-
rtol
<
EXPECTED_VALUE
and
measured_value
+
rtol
>
EXPECTED_VALUE
),
f
"Expected:
{
EXPECTED_VALUE
}
| Measured:
{
measured_value
}
"
del
os
.
environ
[
"VLLM_USE_TRITON_FLASH_ATTN"
]
@
pytest
.
mark
.
skipif
(
not
QUARK_MXFP4_AVAILABLE
,
reason
=
"amd-quark>=0.9 is not available"
)
@
pytest
.
mark
.
parametrize
(
"float_dtype"
,
[
torch
.
bfloat16
,
torch
.
float16
])
@
pytest
.
mark
.
parametrize
(
"scalings"
,
[[
2.3
,
0.03
,
7.3
,
0.1
,
0.004
,
17.3
,
1e4
,
1e-4
]])
def
test_mxfp4_fused_qdq_match_quark
(
float_dtype
:
torch
.
dtype
,
scalings
:
list
[
int
]):
torch
.
manual_seed
(
0
)
hidden_size
=
64
*
32
inp
=
(
torch
.
rand
(
1
,
hidden_size
,
dtype
=
float_dtype
,
device
=
"cuda"
)
-
0.5
)
*
2
for
i
in
range
(
hidden_size
//
32
):
inp
[:,
i
*
32
:(
i
+
1
)
*
32
]
=
inp
[:,
i
*
32
:(
i
+
1
)
*
32
]
*
scalings
[
i
%
len
(
scalings
)]
inp_kernel
=
inp
.
clone
()
inp_kernel_clone
=
inp_kernel
.
clone
()
res_hip
=
mx_kernel
.
qdq_mxfp4_hip
(
inp_kernel_clone
,
"even"
)
res_torch
=
qdq_mxfp4_torch
(
inp_kernel
,
"even"
)
for
i
in
range
(
hidden_size
//
32
):
assert
torch
.
all
(
torch
.
isfinite
(
res_hip
[:,
i
*
32
:(
i
+
1
)
*
32
]))
assert
torch
.
all
(
torch
.
isfinite
(
res_torch
[:,
i
*
32
:(
i
+
1
)
*
32
]))
torch
.
testing
.
assert_close
(
res_hip
[:,
i
*
32
:(
i
+
1
)
*
32
],
res_torch
[:,
i
*
32
:(
i
+
1
)
*
32
])
@
pytest
.
mark
.
skipif
(
not
QUARK_MXFP4_AVAILABLE
,
reason
=
"amd-quark>=0.9 is not available"
)
@
pytest
.
mark
.
parametrize
(
"float_dtype"
,
[
torch
.
bfloat16
,
torch
.
float16
])
@
pytest
.
mark
.
parametrize
(
"scalings"
,
[[
2.3
,
0.03
,
7.3
,
0.1
,
0.004
,
17.3
,
1e4
,
1e-4
]])
def
test_mxfp4_dequant_kernel_match_quark
(
float_dtype
:
torch
.
dtype
,
scalings
:
list
[
int
]):
qspec
=
FP4PerGroupSpec
(
ch_axis
=-
1
,
group_size
=
32
,
scale_format
=
"e8m0"
,
scale_calculation_mode
=
"even"
,
is_dynamic
=
False
,
).
to_quantization_spec
()
weight_quantizer
=
StaticScaledRealQuantizer
(
qspec
=
qspec
,
quantizer
=
None
,
reorder
=
False
,
real_quantized
=
True
,
float_dtype
=
float_dtype
,
device
=
"cuda"
,
)
observer
=
qspec
.
observer_cls
(
qspec
,
device
=
"cuda"
)
hidden_size
=
512
shape
=
(
11008
,
hidden_size
)
w
=
(
torch
.
rand
(
shape
,
device
=
"cuda"
,
dtype
=
float_dtype
)
-
0.5
)
*
2
# Make it so that different groups have different scales.
for
i
in
range
(
hidden_size
//
32
):
w
[:,
i
*
32
:(
i
+
1
)
*
32
]
=
w
[:,
i
*
32
:(
i
+
1
)
*
32
]
*
scalings
[
i
%
len
(
scalings
)]
observer
(
w
)
scale
,
_
=
observer
.
_calculate_qparams
()
weight_quantizer
.
scale
=
scale
w_mxfp4
=
weight_quantizer
.
to_real_quantize_params
(
w
).
to
(
"cuda"
)
weight_quantizer
.
maybe_convert_and_transpose_scale
()
scale
=
weight_quantizer
.
scale
out_hip
=
mx_kernel
.
dq_mxfp4_hip
(
w_mxfp4
,
scale
,
float_dtype
)
out_torch
=
dq_mxfp4_torch
(
w_mxfp4
,
scale
,
float_dtype
)
assert
torch
.
equal
(
out_hip
,
out_torch
)
tests/reasoning/test_hunyuan_reasoning_parser.py
0 → 100644
View file @
711aa9d5
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
pytest
from
transformers
import
AutoTokenizer
from
tests.reasoning.utils
import
run_reasoning_extraction
from
vllm.reasoning
import
ReasoningParser
,
ReasoningParserManager
parser_name
=
"hunyuan_a13b"
START_REASONING
=
"<think>
\n
"
START_RESPONSE
=
"
\n
</think>
\n
<answer>
\n
"
END_RESPONSE
=
"
\n
</answer>"
NO_REASONING_QUICK_THROUGHT
=
{
"output"
:
f
"
{
START_REASONING
}{
START_RESPONSE
}
This is the rest
{
END_RESPONSE
}
"
,
#noqa: E501
"reasoning_content"
:
None
,
"content"
:
"This is the rest"
,
}
SIMPLE_REASONING
=
{
"output"
:
f
"
{
START_REASONING
}
This is a reasoning section
{
START_RESPONSE
}
This is the rest
{
END_RESPONSE
}
"
,
#noqa: E501
"reasoning_content"
:
"This is a reasoning section"
,
"content"
:
"This is the rest"
,
}
COMPLETE_REASONING
=
{
"output"
:
f
"
{
START_REASONING
}
This is a reasoning section
{
START_RESPONSE
}
"
,
"reasoning_content"
:
"This is a reasoning section"
,
"content"
:
None
,
}
COMPLETE_REASONING_WITH_SYMBOL
=
{
"output"
:
f
"
{
START_REASONING
}
This is a reasoning section!
{
START_RESPONSE
}
"
,
"reasoning_content"
:
"This is a reasoning section!"
,
"content"
:
None
,
}
NO_REASONING
=
{
"output"
:
"This is content"
,
"reasoning_content"
:
None
,
"content"
:
"This is content"
,
}
MULTIPLE_LINES
=
{
"output"
:
f
"
{
START_REASONING
}
This
\n
That
{
START_RESPONSE
}
This is the rest
\n
That"
,
"reasoning_content"
:
"This
\n
That"
,
"content"
:
"This is the rest
\n
That"
,
}
REASONING_WITH_THINK
=
{
"output"
:
f
"
{
START_REASONING
}
This is a reasoning section
{
START_RESPONSE
}
This is the rest"
,
#noqa: E501
"reasoning_content"
:
"This is a reasoning section"
,
"content"
:
"This is the rest"
,
}
COMPLETE_REASONING_WITH_THINK
=
{
"output"
:
f
"
{
START_REASONING
}
This is a reasoning section
{
START_RESPONSE
}
"
,
"reasoning_content"
:
"This is a reasoning section"
,
"content"
:
None
,
}
MULTIPLE_LINES_WITH_THINK
=
{
"output"
:
f
"
{
START_REASONING
}
This
\n
That
{
START_RESPONSE
}
This is the rest
\n
That"
,
"reasoning_content"
:
"This
\n
That"
,
"content"
:
"This is the rest
\n
That"
,
}
TEST_CASES
=
[
pytest
.
param
(
False
,
SIMPLE_REASONING
,
id
=
"simple_reasoning"
,
),
pytest
.
param
(
False
,
COMPLETE_REASONING
,
id
=
"complete_reasoning"
,
),
pytest
.
param
(
False
,
COMPLETE_REASONING_WITH_SYMBOL
,
id
=
"complete_reasoning_with_symbol"
,
),
pytest
.
param
(
False
,
NO_REASONING
,
id
=
"no_reasoning"
,
),
pytest
.
param
(
False
,
NO_REASONING_QUICK_THROUGHT
,
id
=
"no_reasoning_quick"
),
pytest
.
param
(
False
,
MULTIPLE_LINES
,
id
=
"multiple_lines"
,
),
pytest
.
param
(
False
,
REASONING_WITH_THINK
,
id
=
"reasoning_with_think"
,
),
pytest
.
param
(
False
,
COMPLETE_REASONING_WITH_THINK
,
id
=
"complete_reasoning_with_think"
,
),
pytest
.
param
(
False
,
MULTIPLE_LINES_WITH_THINK
,
id
=
"multiple_lines_with_think"
,
),
pytest
.
param
(
True
,
SIMPLE_REASONING
,
id
=
"simple_reasoning_streaming"
,
),
pytest
.
param
(
True
,
COMPLETE_REASONING
,
id
=
"complete_reasoning_streaming"
,
),
pytest
.
param
(
True
,
NO_REASONING
,
id
=
"no_reasoning_streaming"
,
),
pytest
.
param
(
True
,
NO_REASONING_QUICK_THROUGHT
,
id
=
"no_reasoning_quick_stream"
),
pytest
.
param
(
True
,
MULTIPLE_LINES
,
id
=
"multiple_lines_streaming"
,
),
pytest
.
param
(
True
,
REASONING_WITH_THINK
,
id
=
"reasoning_with_think_streaming"
,
),
pytest
.
param
(
True
,
COMPLETE_REASONING_WITH_THINK
,
id
=
"complete_reasoning_with_think_streaming"
,
),
pytest
.
param
(
True
,
MULTIPLE_LINES_WITH_THINK
,
id
=
"multiple_lines_with_think_streaming"
,
),
]
# Global tokenizer initialization to avoid repeated loading
tokenizer
=
AutoTokenizer
.
from_pretrained
(
"tencent/Hunyuan-A13B-Instruct"
,
trust_remote_code
=
True
)
@
pytest
.
mark
.
parametrize
(
"streaming, param_dict"
,
TEST_CASES
)
def
test_reasoning
(
streaming
:
bool
,
param_dict
:
dict
,
):
output
=
tokenizer
.
tokenize
(
param_dict
[
"output"
])
# decode everything to tokens
output_tokens
:
list
[
str
]
=
[
tokenizer
.
convert_tokens_to_string
([
token
])
for
token
in
output
]
parser
:
ReasoningParser
=
ReasoningParserManager
.
get_reasoning_parser
(
parser_name
)(
tokenizer
)
reasoning
,
content
=
run_reasoning_extraction
(
parser
,
output_tokens
,
streaming
=
streaming
)
assert
reasoning
==
param_dict
[
"reasoning_content"
]
assert
content
==
param_dict
[
"content"
]
tests/reasoning/test_mistral_reasoning_parser.py
0 → 100644
View file @
711aa9d5
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
pytest
from
mistral_common.tokens.tokenizers.base
import
SpecialTokens
from
mistral_common.tokens.tokenizers.tekken
import
(
SpecialTokenInfo
,
Tekkenizer
)
from
tests.reasoning.utils
import
run_reasoning_extraction_mistral
from
vllm.reasoning
import
ReasoningParser
,
ReasoningParserManager
from
vllm.transformers_utils.tokenizers.mistral
import
MistralTokenizer
parser_name
=
"mistral"
@
pytest
.
fixture
(
scope
=
"module"
)
def
mistral_tokenizer
():
# TODO(Julien): upon model release change to a tokenizer already configured.
# =================================================================
mistral_tokenizer
=
MistralTokenizer
.
from_pretrained
(
"mistralai/Devstral-Small-2507"
)
assert
isinstance
(
mistral_tokenizer
.
tokenizer
,
Tekkenizer
)
# Add think special tokens to the tokenizer
mistral_tokenizer
.
tokenizer
.
_all_special_tokens
[
35
]
=
SpecialTokenInfo
(
rank
=
35
,
is_control
=
True
,
token_str
=
SpecialTokens
.
begin_think
.
value
)
mistral_tokenizer
.
tokenizer
.
_all_special_tokens
[
36
]
=
SpecialTokenInfo
(
rank
=
36
,
is_control
=
True
,
token_str
=
SpecialTokens
.
end_think
.
value
)
mistral_tokenizer
.
tokenizer
.
_special_tokens_reverse_vocab
=
{
k
:
v
for
k
,
v
in
mistral_tokenizer
.
tokenizer
.
_special_tokens_reverse_vocab
.
items
()
if
v
not
in
{
35
,
36
}
}
mistral_tokenizer
.
tokenizer
.
_special_tokens_reverse_vocab
[
SpecialTokens
.
begin_think
.
value
]
=
35
mistral_tokenizer
.
tokenizer
.
_special_tokens_reverse_vocab
[
SpecialTokens
.
end_think
.
value
]
=
36
mistral_tokenizer
.
instruct
.
BEGIN_THINK
=
35
mistral_tokenizer
.
instruct
.
END_THINK
=
36
# =================================================================
return
mistral_tokenizer
SIMPLE_REASONING
=
{
"output"
:
"This is a reasoning section[/THINK]This is the rest"
,
"reasoning_content"
:
"This is a reasoning section"
,
"content"
:
"This is the rest"
,
"is_reasoning_end"
:
True
,
}
COMPLETE_REASONING
=
{
"output"
:
"This is a reasoning section[/THINK]"
,
"reasoning_content"
:
"This is a reasoning section"
,
"content"
:
None
,
"is_reasoning_end"
:
True
,
}
NO_CONTENT
=
{
"output"
:
"This is content"
,
"reasoning_content"
:
"This is content"
,
"content"
:
None
,
"is_reasoning_end"
:
False
,
}
NO_REASONING_STREAMING
=
{
"output"
:
"This is a reasoning section"
,
"reasoning_content"
:
"This is a reasoning section"
,
"content"
:
None
,
"is_reasoning_end"
:
False
,
}
MULTIPLE_LINES
=
{
"output"
:
"This
\n
That[/THINK]This is the rest
\n
That"
,
"reasoning_content"
:
"This
\n
That"
,
"content"
:
"This is the rest
\n
That"
,
"is_reasoning_end"
:
True
,
}
SHORTEST_REASONING_NO_STREAMING
=
{
"output"
:
"[/THINK]This is the rest"
,
"reasoning_content"
:
""
,
"content"
:
"This is the rest"
,
"is_reasoning_end"
:
True
,
}
SHORTEST_REASONING
=
{
"output"
:
"[/THINK]This is the rest"
,
"reasoning_content"
:
None
,
"content"
:
"This is the rest"
,
"is_reasoning_end"
:
True
,
}
REASONING_WITH_THINK
=
{
"output"
:
"[THINK]This is a reasoning section[/THINK]This is the rest"
,
"reasoning_content"
:
"This is a reasoning section"
,
"content"
:
"This is the rest"
,
"is_reasoning_end"
:
True
,
}
COMPLETE_REASONING_WITH_THINK
=
{
"output"
:
"[THINK]This is a reasoning section[/THINK]"
,
"reasoning_content"
:
"This is a reasoning section"
,
"content"
:
None
,
"is_reasoning_end"
:
True
,
}
MULTIPLE_LINES_WITH_THINK
=
{
"output"
:
"[THINK]This
\n
That[/THINK]This is the rest
\n
That"
,
"reasoning_content"
:
"This
\n
That"
,
"content"
:
"This is the rest
\n
That"
,
"is_reasoning_end"
:
True
,
}
SHORTEST_REASONING_NO_STREAMING_WITH_THINK
=
{
"output"
:
"[/THINK]This is the rest"
,
"reasoning_content"
:
""
,
"content"
:
"This is the rest"
,
"is_reasoning_end"
:
True
,
}
SHORTEST_REASONING_WITH_THINK
=
{
"output"
:
"[/THINK]This is the rest"
,
"reasoning_content"
:
None
,
"content"
:
"This is the rest"
,
"is_reasoning_end"
:
True
,
}
THINK_NO_END
=
{
"output"
:
"[THINK]This is a reasoning section"
,
"reasoning_content"
:
"This is a reasoning section"
,
"content"
:
None
,
"is_reasoning_end"
:
False
,
}
EMPTY
=
{
"output"
:
""
,
"reasoning_content"
:
""
,
"content"
:
None
,
"is_reasoning_end"
:
False
,
}
EMPTY_STREAMING
=
{
"output"
:
""
,
"reasoning_content"
:
None
,
"content"
:
None
,
"is_reasoning_end"
:
False
,
}
NEW_LINE
=
{
"output"
:
"
\n
[THINK]This is a reasoning section[/THINK]
\n
This is the rest"
,
"reasoning_content"
:
"This is a reasoning section"
,
"content"
:
"
\n
This is the rest"
,
"is_reasoning_end"
:
True
,
}
# Streaming cannot handle new lines at the beginning of the output
# because we need to support [THINK]...[/THINK] and [/THINK]...
# We cannot know if the text before [THINK] is reasoning content
# or not.
NEW_LINE_STREAMING
=
{
"output"
:
"
\n
[THINK]This is a reasoning section[/THINK]
\n
This is the rest"
,
"reasoning_content"
:
"
\n
This is a reasoning section"
,
"content"
:
"
\n
This is the rest"
,
"is_reasoning_end"
:
True
,
}
TEST_CASES
=
[
pytest
.
param
(
False
,
SIMPLE_REASONING
,
id
=
"simple_reasoning"
,
),
pytest
.
param
(
True
,
SIMPLE_REASONING
,
id
=
"simple_reasoning_streaming"
,
),
pytest
.
param
(
False
,
COMPLETE_REASONING
,
id
=
"complete_reasoning"
,
),
pytest
.
param
(
True
,
COMPLETE_REASONING
,
id
=
"complete_reasoning_streaming"
,
),
pytest
.
param
(
False
,
NO_CONTENT
,
id
=
"no_content_token"
,
),
pytest
.
param
(
True
,
NO_REASONING_STREAMING
,
id
=
"no_reasoning_token_streaming"
,
),
pytest
.
param
(
False
,
MULTIPLE_LINES
,
id
=
"multiple_lines"
,
),
pytest
.
param
(
True
,
MULTIPLE_LINES
,
id
=
"multiple_lines_streaming"
,
),
pytest
.
param
(
True
,
SHORTEST_REASONING
,
id
=
"shortest"
,
),
pytest
.
param
(
False
,
SHORTEST_REASONING_NO_STREAMING
,
id
=
"shortest_streaming"
,
),
pytest
.
param
(
False
,
REASONING_WITH_THINK
,
id
=
"reasoning_with_think"
,
),
pytest
.
param
(
True
,
REASONING_WITH_THINK
,
id
=
"reasoning_with_think_streaming"
,
),
pytest
.
param
(
False
,
COMPLETE_REASONING_WITH_THINK
,
id
=
"complete_reasoning_with_think"
,
),
pytest
.
param
(
True
,
COMPLETE_REASONING_WITH_THINK
,
id
=
"complete_reasoning_with_think_streaming"
,
),
pytest
.
param
(
False
,
MULTIPLE_LINES_WITH_THINK
,
id
=
"multiple_lines_with_think"
,
),
pytest
.
param
(
True
,
MULTIPLE_LINES_WITH_THINK
,
id
=
"multiple_lines_with_think_streaming"
,
),
pytest
.
param
(
False
,
SHORTEST_REASONING_NO_STREAMING_WITH_THINK
,
id
=
"shortest_with_think"
,
),
pytest
.
param
(
True
,
SHORTEST_REASONING_WITH_THINK
,
id
=
"shortest_with_think_streaming"
,
),
pytest
.
param
(
False
,
THINK_NO_END
,
id
=
"think_no_end"
,
),
pytest
.
param
(
True
,
THINK_NO_END
,
id
=
"think_no_end_streaming"
,
),
pytest
.
param
(
False
,
EMPTY
,
id
=
"empty"
,
),
pytest
.
param
(
True
,
EMPTY_STREAMING
,
id
=
"empty_streaming"
,
),
pytest
.
param
(
False
,
NEW_LINE
,
id
=
"new_line"
,
),
pytest
.
param
(
True
,
NEW_LINE_STREAMING
,
id
=
"new_line_streaming"
,
),
]
@
pytest
.
mark
.
parametrize
(
"streaming, param_dict"
,
TEST_CASES
)
def
test_mistral_reasoning
(
streaming
:
bool
,
param_dict
:
dict
,
mistral_tokenizer
:
MistralTokenizer
,
):
output
=
param_dict
[
"output"
]
index_think
=
output
.
find
(
"[THINK]"
)
len_think
=
len
(
"[THINK]"
)
index_end_think
=
output
.
find
(
"[/THINK]"
)
len_end_think
=
len
(
"[/THINK]"
)
# encode everything to tokens ids
output_tokens
=
[]
if
index_think
!=
-
1
:
output_before_think
=
output
[:
index_think
]
output_tokens
+=
mistral_tokenizer
.
tokenizer
.
encode
(
output_before_think
,
False
,
False
)
output_tokens
+=
[
mistral_tokenizer
.
instruct
.
BEGIN_THINK
]
if
index_end_think
!=
-
1
:
output_middle
=
output
[
index_think
+
len_think
:
index_end_think
]
output_after_think
=
output
[
index_end_think
+
len_end_think
:]
output_tokens
+=
mistral_tokenizer
.
tokenizer
.
encode
(
output_middle
,
False
,
False
)
output_tokens
+=
[
mistral_tokenizer
.
instruct
.
END_THINK
]
output_tokens
+=
mistral_tokenizer
.
tokenizer
.
encode
(
output_after_think
,
False
,
False
)
else
:
output_middle
=
output
[
index_think
+
len_think
:]
output_tokens
+=
mistral_tokenizer
.
tokenizer
.
encode
(
output_middle
,
False
,
False
)
elif
index_end_think
!=
-
1
:
output_before_think
=
output
[:
index_end_think
]
output_after_think
=
output
[
index_end_think
+
len_end_think
:]
output_tokens
+=
mistral_tokenizer
.
tokenizer
.
encode
(
output_before_think
,
False
,
False
)
output_tokens
+=
[
mistral_tokenizer
.
instruct
.
END_THINK
]
output_tokens
+=
mistral_tokenizer
.
tokenizer
.
encode
(
output_after_think
,
False
,
False
)
else
:
output_tokens
+=
mistral_tokenizer
.
tokenizer
.
encode
(
output
,
False
,
False
)
parser
:
ReasoningParser
=
ReasoningParserManager
.
get_reasoning_parser
(
parser_name
)(
mistral_tokenizer
)
reasoning
,
content
=
run_reasoning_extraction_mistral
(
parser
,
output_tokens
,
streaming
=
streaming
)
assert
reasoning
==
param_dict
[
"reasoning_content"
]
assert
content
==
param_dict
[
"content"
]
# Test is_reasoning_end
is_reasoning_end
=
parser
.
is_reasoning_end
(
output_tokens
)
assert
is_reasoning_end
==
param_dict
[
"is_reasoning_end"
]
# Test extract_content
if
param_dict
[
"content"
]
is
not
None
:
content
=
parser
.
extract_content_ids
(
output_tokens
)
assert
content
==
mistral_tokenizer
.
tokenizer
.
encode
(
param_dict
[
"content"
],
bos
=
False
,
eos
=
False
)
else
:
content
=
parser
.
extract_content_ids
(
output_tokens
)
assert
content
==
[]
tests/reasoning/utils.py
View file @
711aa9d5
...
...
@@ -6,6 +6,7 @@ from typing import Optional, Union
from
vllm.entrypoints.openai.protocol
import
(
ChatCompletionRequest
,
DeltaMessage
)
from
vllm.reasoning
import
ReasoningParser
from
vllm.transformers_utils.tokenizers.mistral
import
MistralTokenizer
class
StreamingReasoningReconstructor
:
...
...
@@ -54,6 +55,32 @@ def run_reasoning_extraction(
return
reasoning
,
content
def
run_reasoning_extraction_mistral
(
reasoning_parser
:
ReasoningParser
,
model_output
:
list
[
int
],
request
:
Union
[
ChatCompletionRequest
,
None
]
=
None
,
streaming
:
bool
=
False
,
)
->
tuple
[
Optional
[
str
],
Optional
[
str
]]:
assert
isinstance
(
reasoning_parser
.
model_tokenizer
,
MistralTokenizer
),
type
(
reasoning_parser
.
model_tokenizer
)
if
streaming
:
reconstructor
=
run_reasoning_extraction_streaming_mistral
(
reasoning_parser
,
model_output
,
request
,
)
return
(
reconstructor
.
reasoning_content
,
reconstructor
.
other_content
or
None
,
)
else
:
str_output
=
reasoning_parser
.
model_tokenizer
.
convert_ids_to_tokens
(
model_output
)
reasoning
,
content
=
run_reasoning_extraction_nonstreaming
(
reasoning_parser
,
str_output
,
request
)
return
reasoning
,
content
def
run_reasoning_extraction_nonstreaming
(
reasoning_parser
:
ReasoningParser
,
model_output
:
list
[
str
],
...
...
@@ -94,3 +121,35 @@ def run_reasoning_extraction_streaming(
previous_text
=
current_text
previous_tokens
=
current_tokens
return
reconstructor
def
run_reasoning_extraction_streaming_mistral
(
reasoning_parser
:
ReasoningParser
,
model_deltas
:
list
[
int
],
request
:
Union
[
ChatCompletionRequest
,
None
]
=
None
,
)
->
StreamingReasoningReconstructor
:
assert
isinstance
(
reasoning_parser
.
model_tokenizer
,
MistralTokenizer
),
type
(
reasoning_parser
.
model_tokenizer
)
request
=
request
or
ChatCompletionRequest
(
messages
=
[],
model
=
"test-model"
)
reconstructor
=
StreamingReasoningReconstructor
()
previous_text
=
""
previous_tokens
:
list
[
int
]
=
[]
for
model_delta
in
model_deltas
:
token_delta
=
[
model_delta
]
delta
=
reasoning_parser
.
model_tokenizer
.
convert_ids_to_tokens
(
[
model_delta
])[
0
]
current_text
=
previous_text
+
delta
current_tokens
=
previous_tokens
+
token_delta
delta_message
=
reasoning_parser
.
extract_reasoning_content_streaming
(
previous_text
,
current_text
,
delta
,
previous_tokens
,
current_tokens
,
token_delta
,
)
if
delta_message
is
not
None
:
reconstructor
.
append_delta
(
delta_message
)
previous_text
=
current_text
previous_tokens
=
current_tokens
return
reconstructor
tests/samplers/test_ignore_eos.py
View file @
711aa9d5
...
...
@@ -38,7 +38,7 @@ def test_ignore_eos(
ignore_eos
=
True
)
for
prompt
in
example_prompts
:
ignore_eos_output
=
vllm_model
.
model
.
generate
(
ignore_eos_output
=
vllm_model
.
llm
.
generate
(
prompt
,
sampling_params
=
sampling_params
)
output_length
=
len
(
ignore_eos_output
[
0
].
outputs
[
0
].
token_ids
)
assert
output_length
==
max_tokens
tests/samplers/test_logits_processor.py
View file @
711aa9d5
...
...
@@ -28,7 +28,7 @@ def test_logits_processor_force_generate(
dtype
:
str
,
)
->
None
:
with
vllm_runner
(
model
,
dtype
=
dtype
)
as
vllm_model
:
tokenizer
=
vllm_model
.
model
.
get_tokenizer
()
tokenizer
=
vllm_model
.
llm
.
get_tokenizer
()
repeat_times
=
2
enforced_answers
=
" vLLM"
vllm_token_ids
=
tokenizer
.
encode
(
enforced_answers
,
...
...
@@ -47,13 +47,13 @@ def test_logits_processor_force_generate(
)
# test logits_processors when prompt_logprobs is not None
vllm_model
.
model
.
_add_request
(
vllm_model
.
llm
.
_add_request
(
example_prompts
[
0
],
params
=
params_with_logprobs
,
)
# test prompt_logprobs is not None
vllm_model
.
model
.
_add_request
(
vllm_model
.
llm
.
_add_request
(
example_prompts
[
1
],
params
=
SamplingParams
(
prompt_logprobs
=
3
,
...
...
@@ -62,11 +62,11 @@ def test_logits_processor_force_generate(
)
# test grouped requests
vllm_model
.
model
.
_add_request
(
vllm_model
.
llm
.
_add_request
(
example_prompts
[
2
],
params
=
SamplingParams
(
max_tokens
=
max_tokens
),
)
outputs
=
vllm_model
.
model
.
_run_engine
(
use_tqdm
=
False
)
outputs
=
vllm_model
.
llm
.
_run_engine
(
use_tqdm
=
False
)
assert
outputs
[
0
].
outputs
[
0
].
text
==
enforced_answers
*
repeat_times
tests/samplers/test_logprobs.py
View file @
711aa9d5
...
...
@@ -66,7 +66,7 @@ def test_get_prompt_logprobs(
prompt_logprobs
=
num_top_logprobs
,
temperature
=
0.0
,
detokenize
=
detokenize
)
vllm_results
=
vllm_model
.
model
.
generate
(
vllm_results
=
vllm_model
.
llm
.
generate
(
example_prompts
,
sampling_params
=
vllm_sampling_params
)
# Test whether logprobs are included in the results.
...
...
@@ -176,7 +176,7 @@ def test_none_logprobs(vllm_runner, model, chunked_prefill_token_size: int,
logprobs
=
None
,
temperature
=
0.0
,
detokenize
=
detokenize
)
results_logprobs_none
=
vllm_model
.
model
.
generate
(
results_logprobs_none
=
vllm_model
.
llm
.
generate
(
example_prompts
,
sampling_params
=
sampling_params_logprobs_none
)
for
i
in
range
(
len
(
results_logprobs_none
)):
...
...
Prev
1
…
15
16
17
18
19
20
21
22
23
…
26
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment