Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
469e903b
Commit
469e903b
authored
Mar 28, 2025
by
zhuwenwen
Browse files
Merge tag 'v0.8.2' into v0.8.2-dev
parents
389ebcf7
25f560a6
Changes
535
Hide whitespace changes
Inline
Side-by-side
Showing
15 changed files
with
147 additions
and
59 deletions
+147
-59
tests/quantization/test_lm_head.py
tests/quantization/test_lm_head.py
+3
-0
tests/quantization/test_quark.py
tests/quantization/test_quark.py
+3
-1
tests/quantization/test_register_quantization_config.py
tests/quantization/test_register_quantization_config.py
+7
-5
tests/samplers/test_beam_search.py
tests/samplers/test_beam_search.py
+8
-0
tests/samplers/test_ignore_eos.py
tests/samplers/test_ignore_eos.py
+7
-0
tests/samplers/test_logits_processor.py
tests/samplers/test_logits_processor.py
+8
-0
tests/samplers/test_logprobs.py
tests/samplers/test_logprobs.py
+10
-3
tests/samplers/test_no_bad_words.py
tests/samplers/test_no_bad_words.py
+15
-8
tests/samplers/test_ranks.py
tests/samplers/test_ranks.py
+6
-0
tests/samplers/test_rejection_sampler.py
tests/samplers/test_rejection_sampler.py
+22
-14
tests/samplers/test_sampler.py
tests/samplers/test_sampler.py
+30
-22
tests/samplers/test_seeded_generate.py
tests/samplers/test_seeded_generate.py
+3
-1
tests/samplers/test_typical_acceptance_sampler.py
tests/samplers/test_typical_acceptance_sampler.py
+8
-0
tests/spec_decode/conftest.py
tests/spec_decode/conftest.py
+11
-0
tests/spec_decode/e2e/conftest.py
tests/spec_decode/e2e/conftest.py
+6
-5
No files found.
Too many changes to show.
To preserve performance only
535 of 535+
files are displayed.
Plain diff
Email patch
tests/quantization/test_lm_head.py
View file @
469e903b
...
@@ -31,7 +31,10 @@ def test_lm_head(
...
@@ -31,7 +31,10 @@ def test_lm_head(
vllm_runner
,
vllm_runner
,
model_id
:
str
,
model_id
:
str
,
lm_head_quantized
:
bool
,
lm_head_quantized
:
bool
,
monkeypatch
,
)
->
None
:
)
->
None
:
# vllm_runner.apply_model() relies on V0 internals.
monkeypatch
.
setenv
(
"VLLM_USE_V1"
,
"0"
)
with
vllm_runner
(
model_id
,
dtype
=
torch
.
float16
,
with
vllm_runner
(
model_id
,
dtype
=
torch
.
float16
,
max_model_len
=
2048
)
as
vllm_model
:
max_model_len
=
2048
)
as
vllm_model
:
...
...
tests/quantization/test_quark.py
View file @
469e903b
...
@@ -10,7 +10,9 @@ from vllm.model_executor.layers.quantization.quark.quark import ( # noqa: E501
...
@@ -10,7 +10,9 @@ from vllm.model_executor.layers.quantization.quark.quark import ( # noqa: E501
QuarkLinearMethod
,
QuarkW8A8Fp8
)
QuarkLinearMethod
,
QuarkW8A8Fp8
)
def
test_quark_fp8
(
vllm_runner
):
def
test_quark_fp8
(
vllm_runner
,
monkeypatch
):
# vllm_runner.apply_model() relies on V0 internals.
monkeypatch
.
setenv
(
"VLLM_USE_V1"
,
"0"
)
model_path
=
"amd/Llama-3.1-8B-Instruct-FP8-KV-Quark-test"
model_path
=
"amd/Llama-3.1-8B-Instruct-FP8-KV-Quark-test"
with
vllm_runner
(
model_path
)
as
llm
:
with
vllm_runner
(
model_path
)
as
llm
:
...
...
tests/quantization/test_register_quantization_config.py
View file @
469e903b
...
@@ -5,7 +5,7 @@ See https://github.com/vllm-project/vllm/issues/11926 for more details.
...
@@ -5,7 +5,7 @@ See https://github.com/vllm-project/vllm/issues/11926 for more details.
Run `pytest tests/quantization/test_register_quantization_config.py`.
Run `pytest tests/quantization/test_register_quantization_config.py`.
"""
"""
from
typing
import
Any
,
Dict
,
List
,
Optional
from
typing
import
Any
,
Optional
import
pytest
import
pytest
import
torch
import
torch
...
@@ -58,7 +58,7 @@ class CustomQuantConfig(QuantizationConfig):
...
@@ -58,7 +58,7 @@ class CustomQuantConfig(QuantizationConfig):
"""Name of the quantization method."""
"""Name of the quantization method."""
return
"custom_quant"
return
"custom_quant"
def
get_supported_act_dtypes
(
self
)
->
L
ist
[
"torch.dtype"
]:
def
get_supported_act_dtypes
(
self
)
->
l
ist
[
"torch.dtype"
]:
"""List of supported activation dtypes."""
"""List of supported activation dtypes."""
return
[
torch
.
float16
,
torch
.
bfloat16
]
return
[
torch
.
float16
,
torch
.
bfloat16
]
...
@@ -68,12 +68,12 @@ class CustomQuantConfig(QuantizationConfig):
...
@@ -68,12 +68,12 @@ class CustomQuantConfig(QuantizationConfig):
return
-
1
return
-
1
@
staticmethod
@
staticmethod
def
get_config_filenames
()
->
L
ist
[
str
]:
def
get_config_filenames
()
->
l
ist
[
str
]:
"""List of filenames to search for in the model directory."""
"""List of filenames to search for in the model directory."""
return
[]
return
[]
@
classmethod
@
classmethod
def
from_config
(
cls
,
config
:
D
ict
[
str
,
Any
])
->
"CustomQuantConfig"
:
def
from_config
(
cls
,
config
:
d
ict
[
str
,
Any
])
->
"CustomQuantConfig"
:
"""Create a config class from the model's quantization config."""
"""Create a config class from the model's quantization config."""
return
CustomQuantConfig
(
num_bits
=
config
.
get
(
"num_bits"
,
8
))
return
CustomQuantConfig
(
num_bits
=
config
.
get
(
"num_bits"
,
8
))
...
@@ -101,8 +101,10 @@ def test_register_quantization_config():
...
@@ -101,8 +101,10 @@ def test_register_quantization_config():
argvalues
=
[
argvalues
=
[
"meta-llama/Llama-3.2-1B-Instruct"
,
"meta-llama/Llama-3.2-1B-Instruct"
,
])
])
def
test_custom_quant
(
vllm_runner
,
model
):
def
test_custom_quant
(
vllm_runner
,
model
,
monkeypatch
):
"""Test infer with the custom quantization method."""
"""Test infer with the custom quantization method."""
# vllm_runner.apply_model() relies on V0 internals.
monkeypatch
.
setenv
(
"VLLM_USE_V1"
,
"0"
)
with
vllm_runner
(
model_name
=
model
,
with
vllm_runner
(
model_name
=
model
,
quantization
=
"custom_quant"
,
quantization
=
"custom_quant"
,
enforce_eager
=
True
)
as
llm
:
enforce_eager
=
True
)
as
llm
:
...
...
tests/samplers/test_beam_search.py
View file @
469e903b
...
@@ -8,6 +8,13 @@ import pytest
...
@@ -8,6 +8,13 @@ import pytest
import
os
import
os
from
..utils
import
models_path_prefix
from
..utils
import
models_path_prefix
@
pytest
.
fixture
(
autouse
=
True
)
def
v1
(
run_with_both_engines
):
"""We can run both engines for this test."""
pass
# FIXME(zhuohan): The test can not pass if we:
# FIXME(zhuohan): The test can not pass if we:
# 1. Increase max_tokens to 256.
# 1. Increase max_tokens to 256.
# 2. Increase beam_width to 8.
# 2. Increase beam_width to 8.
...
@@ -17,6 +24,7 @@ BEAM_WIDTHS = [4]
...
@@ -17,6 +24,7 @@ BEAM_WIDTHS = [4]
MODELS
=
[
os
.
path
.
join
(
models_path_prefix
,
"TinyLlama/TinyLlama-1.1B-Chat-v1.0"
)]
MODELS
=
[
os
.
path
.
join
(
models_path_prefix
,
"TinyLlama/TinyLlama-1.1B-Chat-v1.0"
)]
@
pytest
.
mark
.
skip_v1
# FIXME: This fails on V1 right now.
@
pytest
.
mark
.
parametrize
(
"model"
,
MODELS
)
@
pytest
.
mark
.
parametrize
(
"model"
,
MODELS
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"half"
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"half"
])
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
MAX_TOKENS
)
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
MAX_TOKENS
)
...
...
tests/samplers/test_ignore_eos.py
View file @
469e903b
...
@@ -10,6 +10,13 @@ import os
...
@@ -10,6 +10,13 @@ import os
from
vllm
import
SamplingParams
from
vllm
import
SamplingParams
from
..utils
import
models_path_prefix
from
..utils
import
models_path_prefix
@
pytest
.
fixture
(
autouse
=
True
)
def
v1
(
run_with_both_engines
):
"""We can run both engines for this test."""
pass
# We also test with llama because it has generation_config to specify EOS
# We also test with llama because it has generation_config to specify EOS
# (past regression).
# (past regression).
MODELS
=
[
os
.
path
.
join
(
models_path_prefix
,
"distilbert/distilgpt2"
),
os
.
path
.
join
(
models_path_prefix
,
"meta-llama/Llama-3.2-1B"
)]
MODELS
=
[
os
.
path
.
join
(
models_path_prefix
,
"distilbert/distilgpt2"
),
os
.
path
.
join
(
models_path_prefix
,
"meta-llama/Llama-3.2-1B"
)]
...
...
tests/samplers/test_logits_processor.py
View file @
469e903b
...
@@ -10,6 +10,14 @@ from ..utils import models_path_prefix
...
@@ -10,6 +10,14 @@ from ..utils import models_path_prefix
MODELS
=
[
os
.
path
.
join
(
models_path_prefix
,
"distilbert/distilgpt2"
)]
MODELS
=
[
os
.
path
.
join
(
models_path_prefix
,
"distilbert/distilgpt2"
)]
@
pytest
.
fixture
(
scope
=
"function"
,
autouse
=
True
)
def
use_v0_only
(
monkeypatch
):
"""
This file tests V0 internals, so set VLLM_USE_V1=0.
"""
monkeypatch
.
setenv
(
'VLLM_USE_V1'
,
'0'
)
@
pytest
.
mark
.
parametrize
(
"model"
,
MODELS
)
@
pytest
.
mark
.
parametrize
(
"model"
,
MODELS
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"half"
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"half"
])
def
test_logits_processor_force_generate
(
def
test_logits_processor_force_generate
(
...
...
tests/samplers/test_logprobs.py
View file @
469e903b
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
from
typing
import
List
import
pytest
import
pytest
import
torch
import
torch
import
os
import
os
...
@@ -14,6 +12,15 @@ from ..utils import models_path_prefix
...
@@ -14,6 +12,15 @@ from ..utils import models_path_prefix
MODELS
=
[
os
.
path
.
join
(
models_path_prefix
,
"distilbert/distilgpt2"
)]
MODELS
=
[
os
.
path
.
join
(
models_path_prefix
,
"distilbert/distilgpt2"
)]
@
pytest
.
fixture
(
scope
=
"function"
,
autouse
=
True
)
def
use_v0_only
(
monkeypatch
):
"""
This module is V0 only since it uses dtype=float, so
set VLLM_USE_V1=0 for all tests in the module.
"""
monkeypatch
.
setenv
(
'VLLM_USE_V1'
,
'0'
)
@
pytest
.
mark
.
parametrize
(
"model"
,
MODELS
)
@
pytest
.
mark
.
parametrize
(
"model"
,
MODELS
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"half"
])
# needed for comparing logprobs with HF
[
"half"
])
# needed for comparing logprobs with HF
...
@@ -72,7 +79,7 @@ def test_get_prompt_logprobs(
...
@@ -72,7 +79,7 @@ def test_get_prompt_logprobs(
assert
(
len
(
logprobs
)
==
num_top_logprobs
assert
(
len
(
logprobs
)
==
num_top_logprobs
or
len
(
logprobs
)
==
num_top_logprobs
+
1
)
or
len
(
logprobs
)
==
num_top_logprobs
+
1
)
output_text
=
result
.
outputs
[
0
].
text
output_text
=
result
.
outputs
[
0
].
text
output_string_from_most_likely_tokens_lst
:
L
ist
[
str
]
=
[]
output_string_from_most_likely_tokens_lst
:
l
ist
[
str
]
=
[]
for
top_logprobs
in
result
.
outputs
[
0
].
logprobs
:
for
top_logprobs
in
result
.
outputs
[
0
].
logprobs
:
top_logprob
=
next
(
iter
(
top_logprobs
.
values
()))
top_logprob
=
next
(
iter
(
top_logprobs
.
values
()))
output_string_from_most_likely_tokens_lst
.
append
(
output_string_from_most_likely_tokens_lst
.
append
(
...
...
tests/samplers/test_no_bad_words.py
View file @
469e903b
...
@@ -5,20 +5,27 @@ Run `pytest tests/samplers/test_no_bad_words.py`.
...
@@ -5,20 +5,27 @@ Run `pytest tests/samplers/test_no_bad_words.py`.
"""
"""
import
os
import
os
from
typing
import
List
,
Optional
from
typing
import
Optional
import
pytest
from
transformers
import
AutoTokenizer
from
transformers
import
AutoTokenizer
from
vllm
import
LLM
,
SamplingParams
from
vllm
import
LLM
,
SamplingParams
from
..utils
import
models_path_prefix
from
..utils
import
models_path_prefix
@
pytest
.
fixture
(
autouse
=
True
)
def
v1
(
run_with_both_engines
):
"""We can run both engines for this test."""
pass
def
_generate
(
def
_generate
(
model
:
LLM
,
model
:
LLM
,
prompt
:
str
,
prompt
:
str
,
num_prompt_tokens
:
int
,
num_prompt_tokens
:
int
,
temperature
:
float
=
0
,
temperature
:
float
=
0
,
bad_words
:
Optional
[
L
ist
[
str
]]
=
None
,
bad_words
:
Optional
[
l
ist
[
str
]]
=
None
,
)
->
L
ist
[
int
]:
)
->
l
ist
[
int
]:
sampling_params
=
SamplingParams
(
sampling_params
=
SamplingParams
(
temperature
=
temperature
,
temperature
=
temperature
,
bad_words
=
bad_words
,
bad_words
=
bad_words
,
...
@@ -60,7 +67,7 @@ class TestOneTokenBadWord:
...
@@ -60,7 +67,7 @@ class TestOneTokenBadWord:
def
_generate
(
self
,
def
_generate
(
self
,
model
:
LLM
,
model
:
LLM
,
bad_words
:
Optional
[
L
ist
[
str
]]
=
None
)
->
L
ist
[
int
]:
bad_words
:
Optional
[
l
ist
[
str
]]
=
None
)
->
l
ist
[
int
]:
return
_generate
(
return
_generate
(
model
=
model
,
model
=
model
,
prompt
=
self
.
PROMPT
,
prompt
=
self
.
PROMPT
,
...
@@ -70,7 +77,7 @@ class TestOneTokenBadWord:
...
@@ -70,7 +77,7 @@ class TestOneTokenBadWord:
def
_encode
(
self
,
def
_encode
(
self
,
prompt
:
str
,
prompt
:
str
,
add_special_tokens
:
bool
=
True
)
->
L
ist
[
int
]:
add_special_tokens
:
bool
=
True
)
->
l
ist
[
int
]:
return
self
.
tokenizer
(
prompt
,
return
self
.
tokenizer
(
prompt
,
add_special_tokens
=
add_special_tokens
).
input_ids
add_special_tokens
=
add_special_tokens
).
input_ids
...
@@ -150,7 +157,7 @@ class TestTwoTokenBadWord:
...
@@ -150,7 +157,7 @@ class TestTwoTokenBadWord:
def
_generate
(
self
,
def
_generate
(
self
,
model
:
LLM
,
model
:
LLM
,
bad_words
:
Optional
[
L
ist
[
str
]]
=
None
)
->
L
ist
[
int
]:
bad_words
:
Optional
[
l
ist
[
str
]]
=
None
)
->
l
ist
[
int
]:
return
_generate
(
return
_generate
(
model
=
model
,
model
=
model
,
prompt
=
self
.
PROMPT
,
prompt
=
self
.
PROMPT
,
...
@@ -159,7 +166,7 @@ class TestTwoTokenBadWord:
...
@@ -159,7 +166,7 @@ class TestTwoTokenBadWord:
)
)
@
staticmethod
@
staticmethod
def
_contains
(
sequence
:
L
ist
[
int
],
subsequence
:
L
ist
[
int
])
->
bool
:
def
_contains
(
sequence
:
l
ist
[
int
],
subsequence
:
l
ist
[
int
])
->
bool
:
searched
=
False
searched
=
False
for
start
in
range
(
len
(
sequence
)):
for
start
in
range
(
len
(
sequence
)):
...
@@ -182,6 +189,6 @@ class TestTwoTokenBadWord:
...
@@ -182,6 +189,6 @@ class TestTwoTokenBadWord:
def
_encode
(
self
,
def
_encode
(
self
,
prompt
:
str
,
prompt
:
str
,
add_special_tokens
:
bool
=
True
)
->
L
ist
[
int
]:
add_special_tokens
:
bool
=
True
)
->
l
ist
[
int
]:
return
self
.
tokenizer
(
prompt
,
return
self
.
tokenizer
(
prompt
,
add_special_tokens
=
add_special_tokens
).
input_ids
add_special_tokens
=
add_special_tokens
).
input_ids
tests/samplers/test_ranks.py
View file @
469e903b
...
@@ -9,6 +9,12 @@ from ..utils import models_path_prefix
...
@@ -9,6 +9,12 @@ from ..utils import models_path_prefix
MODELS
=
[
os
.
path
.
join
(
models_path_prefix
,
"distilbert/distilgpt2"
)]
MODELS
=
[
os
.
path
.
join
(
models_path_prefix
,
"distilbert/distilgpt2"
)]
@
pytest
.
fixture
(
autouse
=
True
)
def
v1
(
run_with_both_engines
):
"""We can run both engines for this test."""
pass
@
pytest
.
mark
.
parametrize
(
"model"
,
MODELS
)
@
pytest
.
mark
.
parametrize
(
"model"
,
MODELS
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"half"
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"half"
])
def
test_ranks
(
def
test_ranks
(
...
...
tests/samplers/test_rejection_sampler.py
View file @
469e903b
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
"""Tests for rejection sampling."""
"""Tests for rejection sampling."""
from
typing
import
List
,
Tuple
import
pytest
import
pytest
import
torch
import
torch
...
@@ -8,7 +7,16 @@ import torch.nn.functional as F
...
@@ -8,7 +7,16 @@ import torch.nn.functional as F
from
vllm.model_executor.layers.rejection_sampler
import
RejectionSampler
from
vllm.model_executor.layers.rejection_sampler
import
RejectionSampler
from
vllm.model_executor.utils
import
set_random_seed
from
vllm.model_executor.utils
import
set_random_seed
from
vllm.utils
import
is_hip
from
vllm.platforms
import
current_platform
@
pytest
.
fixture
(
scope
=
"function"
,
autouse
=
True
)
def
use_v0_only
(
monkeypatch
):
"""
This file tests V0 internals, so set VLLM_USE_V1=0.
"""
monkeypatch
.
setenv
(
'VLLM_USE_V1'
,
'0'
)
CUDA_DEVICES
=
[
CUDA_DEVICES
=
[
f
"cuda:
{
i
}
"
for
i
in
range
(
1
if
torch
.
cuda
.
device_count
()
==
1
else
2
)
f
"cuda:
{
i
}
"
for
i
in
range
(
1
if
torch
.
cuda
.
device_count
()
==
1
else
2
)
...
@@ -46,7 +54,7 @@ def mock_causal_accepted_tensor(
...
@@ -46,7 +54,7 @@ def mock_causal_accepted_tensor(
"which_tokens_accepted"
,
"which_tokens_accepted"
,
[
"all_tokens_accepted"
,
"no_tokens_accepted"
,
"some_tokens_accepted"
])
[
"all_tokens_accepted"
,
"no_tokens_accepted"
,
"some_tokens_accepted"
])
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
pytest
.
mark
.
parametrize
(
"use_flashinfer"
,
[
True
,
False
]
if
not
is_hip
()
else
[
False
])
@
pytest
.
mark
.
parametrize
(
"use_flashinfer"
,
[
True
,
False
]
if
not
current_platform
.
is_rocm
()
else
[
False
])
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
def
test_correct_output_format
(
which_tokens_accepted
:
str
,
seed
:
int
,
def
test_correct_output_format
(
which_tokens_accepted
:
str
,
seed
:
int
,
device
:
str
,
use_flashinfer
:
bool
):
device
:
str
,
use_flashinfer
:
bool
):
...
@@ -130,7 +138,7 @@ def test_correct_output_format(which_tokens_accepted: str, seed: int,
...
@@ -130,7 +138,7 @@ def test_correct_output_format(which_tokens_accepted: str, seed: int,
@
pytest
.
mark
.
parametrize
(
"vocab_size"
,
[
30_000
,
50_000
])
@
pytest
.
mark
.
parametrize
(
"vocab_size"
,
[
30_000
,
50_000
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
list
(
range
(
1
,
32
)))
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
list
(
range
(
1
,
32
)))
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
pytest
.
mark
.
parametrize
(
"use_flashinfer"
,
[
True
,
False
]
if
not
is_hip
()
else
[
False
])
@
pytest
.
mark
.
parametrize
(
"use_flashinfer"
,
[
True
,
False
]
if
not
current_platform
.
is_rocm
()
else
[
False
])
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
def
test_no_crash_with_varying_dims
(
k
:
int
,
vocab_size
:
int
,
batch_size
:
int
,
def
test_no_crash_with_varying_dims
(
k
:
int
,
vocab_size
:
int
,
batch_size
:
int
,
device
:
str
,
use_flashinfer
:
bool
):
device
:
str
,
use_flashinfer
:
bool
):
...
@@ -162,7 +170,7 @@ def test_no_crash_with_varying_dims(k: int, vocab_size: int, batch_size: int,
...
@@ -162,7 +170,7 @@ def test_no_crash_with_varying_dims(k: int, vocab_size: int, batch_size: int,
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
1
,
8
,
32
,
128
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
1
,
8
,
32
,
128
])
@
pytest
.
mark
.
parametrize
(
"n_rep"
,
[
100
])
@
pytest
.
mark
.
parametrize
(
"n_rep"
,
[
100
])
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
pytest
.
mark
.
parametrize
(
"use_flashinfer"
,
[
True
,
False
]
if
not
is_hip
()
else
[
False
])
@
pytest
.
mark
.
parametrize
(
"use_flashinfer"
,
[
True
,
False
]
if
not
current_platform
.
is_rocm
()
else
[
False
])
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
def
test_deterministic_when_seeded
(
k
:
int
,
vocab_size
:
int
,
batch_size
:
int
,
def
test_deterministic_when_seeded
(
k
:
int
,
vocab_size
:
int
,
batch_size
:
int
,
frac_seeded
:
float
,
n_rep
:
int
,
device
:
str
,
frac_seeded
:
float
,
n_rep
:
int
,
device
:
str
,
...
@@ -203,7 +211,7 @@ def test_deterministic_when_seeded(k: int, vocab_size: int, batch_size: int,
...
@@ -203,7 +211,7 @@ def test_deterministic_when_seeded(k: int, vocab_size: int, batch_size: int,
assert
torch
.
equal
(
results
[
j
][
i
],
results
[
0
][
i
])
assert
torch
.
equal
(
results
[
j
][
i
],
results
[
0
][
i
])
@
pytest
.
mark
.
skipif
(
is_hip
(),
@
pytest
.
mark
.
skipif
(
current_platform
.
is_rocm
(),
reason
=
"Consistent with NV."
)
reason
=
"Consistent with NV."
)
@
pytest
.
mark
.
parametrize
(
"k"
,
[
1
,
3
,
6
])
@
pytest
.
mark
.
parametrize
(
"k"
,
[
1
,
3
,
6
])
@
pytest
.
mark
.
parametrize
(
"vocab_size"
,
[
30_000
,
50_000
])
@
pytest
.
mark
.
parametrize
(
"vocab_size"
,
[
30_000
,
50_000
])
...
@@ -305,7 +313,7 @@ def test_compare_nonflashinfer_backend(k: int, vocab_size: int,
...
@@ -305,7 +313,7 @@ def test_compare_nonflashinfer_backend(k: int, vocab_size: int,
for
i
in
range
(
batch_size
)
for
i
in
range
(
batch_size
)
}
}
for
use_flashinfer
in
[
True
,
False
]
if
not
is_hip
()
else
[
False
]:
for
use_flashinfer
in
[
True
,
False
]
if
not
current_platform
.
is_rocm
()
else
[
False
]:
rejection_sampler
=
RejectionSampler
(
use_flashinfer
=
use_flashinfer
)
rejection_sampler
=
RejectionSampler
(
use_flashinfer
=
use_flashinfer
)
rejection_sampler
.
init_gpu_tensors
(
device
=
device
)
rejection_sampler
.
init_gpu_tensors
(
device
=
device
)
# We use seeded sequences to ensure the same tokens are accepted
# We use seeded sequences to ensure the same tokens are accepted
...
@@ -326,7 +334,7 @@ def test_compare_nonflashinfer_backend(k: int, vocab_size: int,
...
@@ -326,7 +334,7 @@ def test_compare_nonflashinfer_backend(k: int, vocab_size: int,
@
pytest
.
mark
.
parametrize
(
"which_token_ids"
,
@
pytest
.
mark
.
parametrize
(
"which_token_ids"
,
[
"bonus_token_ids"
,
"draft_token_ids"
])
[
"bonus_token_ids"
,
"draft_token_ids"
])
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
pytest
.
mark
.
parametrize
(
"use_flashinfer"
,
[
True
,
False
]
if
not
is_hip
()
else
[
False
])
@
pytest
.
mark
.
parametrize
(
"use_flashinfer"
,
[
True
,
False
]
if
not
current_platform
.
is_rocm
()
else
[
False
])
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
def
test_raises_when_vocab_oob
(
above_or_below_vocab_range
:
str
,
def
test_raises_when_vocab_oob
(
above_or_below_vocab_range
:
str
,
which_token_ids
:
str
,
device
:
str
,
which_token_ids
:
str
,
device
:
str
,
...
@@ -378,7 +386,7 @@ def test_raises_when_vocab_oob(above_or_below_vocab_range: str,
...
@@ -378,7 +386,7 @@ def test_raises_when_vocab_oob(above_or_below_vocab_range: str,
@
pytest
.
mark
.
parametrize
(
"draft_and_target_probs_equal"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"draft_and_target_probs_equal"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
list
(
range
(
5
)))
@
pytest
.
mark
.
parametrize
(
"seed"
,
list
(
range
(
5
)))
@
pytest
.
mark
.
parametrize
(
"use_flashinfer"
,
[
True
,
False
]
if
not
is_hip
()
else
[
False
])
@
pytest
.
mark
.
parametrize
(
"use_flashinfer"
,
[
True
,
False
]
if
not
current_platform
.
is_rocm
()
else
[
False
])
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
def
test_rejection_sampling_approximates_target_distribution
(
def
test_rejection_sampling_approximates_target_distribution
(
seed
:
int
,
draft_and_target_probs_equal
:
bool
,
use_flashinfer
:
bool
):
seed
:
int
,
draft_and_target_probs_equal
:
bool
,
use_flashinfer
:
bool
):
...
@@ -419,8 +427,8 @@ def test_rejection_sampling_approximates_target_distribution(
...
@@ -419,8 +427,8 @@ def test_rejection_sampling_approximates_target_distribution(
draft_and_target_probs_equal
)
draft_and_target_probs_equal
)
sample_sizes
=
[
10
,
100
,
1_000
,
10_000
,
100_000
]
sample_sizes
=
[
10
,
100
,
1_000
,
10_000
,
100_000
]
distance_wrt_reference
:
L
ist
[
float
]
=
[]
distance_wrt_reference
:
l
ist
[
float
]
=
[]
distance_wrt_target
:
L
ist
[
float
]
=
[]
distance_wrt_target
:
l
ist
[
float
]
=
[]
for
num_samples
in
sample_sizes
:
for
num_samples
in
sample_sizes
:
(
reference_vs_rejsample_dist
,
(
reference_vs_rejsample_dist
,
...
@@ -455,7 +463,7 @@ def test_rejection_sampling_approximates_target_distribution(
...
@@ -455,7 +463,7 @@ def test_rejection_sampling_approximates_target_distribution(
expected_improvement_multiplier
)
expected_improvement_multiplier
)
def
get_ratio_first_to_last
(
elements
:
L
ist
[
float
])
->
float
:
def
get_ratio_first_to_last
(
elements
:
l
ist
[
float
])
->
float
:
return
elements
[
0
]
/
elements
[
-
1
]
return
elements
[
0
]
/
elements
[
-
1
]
...
@@ -480,7 +488,7 @@ class _CorrectnessTestHelper:
...
@@ -480,7 +488,7 @@ class _CorrectnessTestHelper:
def
generate_probs_for_test
(
def
generate_probs_for_test
(
self
,
draft_and_target_probs_equal
:
bool
self
,
draft_and_target_probs_equal
:
bool
)
->
T
uple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
t
uple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
draft_probs
,
target_probs
=
(
F
.
softmax
(
draft_probs
,
target_probs
=
(
F
.
softmax
(
torch
.
rand
(
self
.
vocab_size
,
dtype
=
torch
.
float32
),
torch
.
rand
(
self
.
vocab_size
,
dtype
=
torch
.
float32
),
dim
=-
1
,
dim
=-
1
,
...
@@ -502,7 +510,7 @@ class _CorrectnessTestHelper:
...
@@ -502,7 +510,7 @@ class _CorrectnessTestHelper:
def
run_and_compare_distributions
(
self
,
draft_probs
:
torch
.
Tensor
,
def
run_and_compare_distributions
(
self
,
draft_probs
:
torch
.
Tensor
,
target_probs
:
torch
.
Tensor
,
target_probs
:
torch
.
Tensor
,
reference_probs
:
torch
.
Tensor
,
reference_probs
:
torch
.
Tensor
,
num_samples
:
int
)
->
T
uple
[
float
,
float
]:
num_samples
:
int
)
->
t
uple
[
float
,
float
]:
# Sample using rejection sampling.
# Sample using rejection sampling.
rej_sample_probs
=
self
.
_estimate_rejection_sampling_pdf
(
rej_sample_probs
=
self
.
_estimate_rejection_sampling_pdf
(
draft_probs
,
target_probs
,
num_samples
)
draft_probs
,
target_probs
,
num_samples
)
...
...
tests/samplers/test_sampler.py
View file @
469e903b
...
@@ -3,7 +3,7 @@
...
@@ -3,7 +3,7 @@
import
itertools
import
itertools
import
random
import
random
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
Dict
,
List
,
Optional
,
Tuple
from
typing
import
Optional
from
unittest.mock
import
Mock
,
patch
from
unittest.mock
import
Mock
,
patch
import
pytest
import
pytest
...
@@ -18,6 +18,14 @@ from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
...
@@ -18,6 +18,14 @@ from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
from
vllm.utils
import
Counter
,
is_pin_memory_available
from
vllm.utils
import
Counter
,
is_pin_memory_available
@
pytest
.
fixture
(
scope
=
"function"
,
autouse
=
True
)
def
use_v0_only
(
monkeypatch
):
"""
This file tests V0 internals, so set VLLM_USE_V1=0.
"""
monkeypatch
.
setenv
(
'VLLM_USE_V1'
,
'0'
)
class
MockLogitsSampler
(
Sampler
):
class
MockLogitsSampler
(
Sampler
):
def
__init__
(
self
,
fake_logits
:
torch
.
Tensor
):
def
__init__
(
self
,
fake_logits
:
torch
.
Tensor
):
...
@@ -30,7 +38,7 @@ class MockLogitsSampler(Sampler):
...
@@ -30,7 +38,7 @@ class MockLogitsSampler(Sampler):
def
_prepare_test
(
def
_prepare_test
(
batch_size
:
int
batch_size
:
int
)
->
T
uple
[
torch
.
Tensor
,
torch
.
Tensor
,
MockLogitsSampler
]:
)
->
t
uple
[
torch
.
Tensor
,
torch
.
Tensor
,
MockLogitsSampler
]:
input_tensor
=
torch
.
rand
((
batch_size
,
1024
),
dtype
=
torch
.
float16
)
input_tensor
=
torch
.
rand
((
batch_size
,
1024
),
dtype
=
torch
.
float16
)
fake_logits
=
torch
.
full
((
batch_size
,
VOCAB_SIZE
),
fake_logits
=
torch
.
full
((
batch_size
,
VOCAB_SIZE
),
1e-2
,
1e-2
,
...
@@ -53,8 +61,8 @@ def _do_sample(
...
@@ -53,8 +61,8 @@ def _do_sample(
sampling_params
:
SamplingParams
,
sampling_params
:
SamplingParams
,
device
:
str
,
device
:
str
,
):
):
seq_group_metadata_list
:
L
ist
[
SequenceGroupMetadata
]
=
[]
seq_group_metadata_list
:
l
ist
[
SequenceGroupMetadata
]
=
[]
seq_lens
:
L
ist
[
int
]
=
[]
seq_lens
:
l
ist
[
int
]
=
[]
for
i
in
range
(
batch_size
):
for
i
in
range
(
batch_size
):
seq_group_metadata_list
.
append
(
seq_group_metadata_list
.
append
(
SequenceGroupMetadata
(
SequenceGroupMetadata
(
...
@@ -171,7 +179,7 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
...
@@ -171,7 +179,7 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
def
create_sampling_params
(
min_tokens
,
def
create_sampling_params
(
min_tokens
,
eos_token_id
=
0
,
eos_token_id
=
0
,
*
,
*
,
stop_token_ids
:
Optional
[
L
ist
[
int
]]
=
None
,
stop_token_ids
:
Optional
[
l
ist
[
int
]]
=
None
,
prompt_logprobs
:
Optional
[
int
]
=
None
):
prompt_logprobs
:
Optional
[
int
]
=
None
):
sampling_params
=
SamplingParams
(
sampling_params
=
SamplingParams
(
min_tokens
=
min_tokens
,
min_tokens
=
min_tokens
,
...
@@ -196,7 +204,7 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
...
@@ -196,7 +204,7 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
batch_size
=
random
.
randint
(
1
,
128
)
batch_size
=
random
.
randint
(
1
,
128
)
expected_penalization
=
[]
expected_penalization
=
[]
sequence_metadata_list
:
L
ist
[
SequenceGroupMetadata
]
=
[]
sequence_metadata_list
:
l
ist
[
SequenceGroupMetadata
]
=
[]
# 20% chance to generate seq group metadata list with all prompts
# 20% chance to generate seq group metadata list with all prompts
is_prompt
=
random
.
random
()
<
0.2
is_prompt
=
random
.
random
()
<
0.2
while
batch_size
>
0
:
while
batch_size
>
0
:
...
@@ -216,8 +224,8 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
...
@@ -216,8 +224,8 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
eos_token_id
=
eos_token_id
,
eos_token_id
=
eos_token_id
,
stop_token_ids
=
stop_token_ids
)
stop_token_ids
=
stop_token_ids
)
seq_data
:
D
ict
[
int
,
SequenceData
]
=
{}
seq_data
:
d
ict
[
int
,
SequenceData
]
=
{}
seq_group_penalization
:
L
ist
[
bool
]
=
[]
seq_group_penalization
:
l
ist
[
bool
]
=
[]
for
_
in
range
(
num_seqs
):
for
_
in
range
(
num_seqs
):
num_input
=
random
.
randint
(
1
,
100
)
num_input
=
random
.
randint
(
1
,
100
)
num_generated
=
0
if
is_prompt
else
random
.
randint
(
1
,
100
)
num_generated
=
0
if
is_prompt
else
random
.
randint
(
1
,
100
)
...
@@ -376,16 +384,16 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
...
@@ -376,16 +384,16 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
else
:
else
:
test_cases
=
[
generate_test_case
()]
test_cases
=
[
generate_test_case
()]
def
run_test_case
(
*
,
expected_penalization
:
L
ist
[
bool
],
def
run_test_case
(
*
,
expected_penalization
:
l
ist
[
bool
],
seq_group_metadata_list
:
L
ist
[
SequenceGroupMetadata
]):
seq_group_metadata_list
:
l
ist
[
SequenceGroupMetadata
]):
assert
expected_penalization
,
\
assert
expected_penalization
,
\
"Invalid test case, need expected_penalization"
"Invalid test case, need expected_penalization"
assert
seq_group_metadata_list
,
\
assert
seq_group_metadata_list
,
\
"Invalid test case, need seq_group_metadata_list"
"Invalid test case, need seq_group_metadata_list"
batch_size
=
0
batch_size
=
0
seq_lens
:
L
ist
[
int
]
=
[]
seq_lens
:
l
ist
[
int
]
=
[]
sampling_params_per_row
:
L
ist
[
SamplingParams
]
=
[]
sampling_params_per_row
:
l
ist
[
SamplingParams
]
=
[]
for
sgm
in
seq_group_metadata_list
:
for
sgm
in
seq_group_metadata_list
:
sampling_params
=
sgm
.
sampling_params
sampling_params
=
sgm
.
sampling_params
...
@@ -456,11 +464,11 @@ def test_sampler_mixed(seed: int, device: str):
...
@@ -456,11 +464,11 @@ def test_sampler_mixed(seed: int, device: str):
batch_size
=
random
.
randint
(
1
,
256
)
batch_size
=
random
.
randint
(
1
,
256
)
input_tensor
,
fake_logits
,
sampler
=
_prepare_test
(
batch_size
)
input_tensor
,
fake_logits
,
sampler
=
_prepare_test
(
batch_size
)
seq_group_metadata_list
:
L
ist
[
SequenceGroupMetadata
]
=
[]
seq_group_metadata_list
:
l
ist
[
SequenceGroupMetadata
]
=
[]
expected_tokens
:
L
ist
[
Optional
[
L
ist
[
int
]]]
=
[]
expected_tokens
:
l
ist
[
Optional
[
l
ist
[
int
]]]
=
[]
seq_lens
:
L
ist
[
int
]
=
[]
seq_lens
:
l
ist
[
int
]
=
[]
for
i
in
range
(
batch_size
):
for
i
in
range
(
batch_size
):
expected
:
Optional
[
L
ist
[
int
]]
=
None
expected
:
Optional
[
l
ist
[
int
]]
=
None
sampling_type
=
random
.
randint
(
0
,
2
)
sampling_type
=
random
.
randint
(
0
,
2
)
if
sampling_type
==
0
:
if
sampling_type
==
0
:
sampling_params
=
SamplingParams
(
temperature
=
0
)
sampling_params
=
SamplingParams
(
temperature
=
0
)
...
@@ -492,7 +500,7 @@ def test_sampler_mixed(seed: int, device: str):
...
@@ -492,7 +500,7 @@ def test_sampler_mixed(seed: int, device: str):
))
))
seq_lens
.
append
(
seq_group_metadata_list
[
-
1
].
seq_data
[
0
].
get_len
())
seq_lens
.
append
(
seq_group_metadata_list
[
-
1
].
seq_data
[
0
].
get_len
())
generators
:
D
ict
[
str
,
torch
.
Generator
]
=
{}
generators
:
d
ict
[
str
,
torch
.
Generator
]
=
{}
def
test_sampling
():
def
test_sampling
():
sampling_metadata
=
SamplingMetadata
.
prepare
(
sampling_metadata
=
SamplingMetadata
.
prepare
(
...
@@ -587,8 +595,8 @@ def test_sampler_top_k_top_p(seed: int, device: str):
...
@@ -587,8 +595,8 @@ def test_sampler_top_k_top_p(seed: int, device: str):
device
=
device
)
device
=
device
)
assert
len
(
processors
)
==
2
# top_p and top_k
assert
len
(
processors
)
==
2
# top_p and top_k
seq_group_metadata_list
:
L
ist
[
SequenceGroupMetadata
]
=
[]
seq_group_metadata_list
:
l
ist
[
SequenceGroupMetadata
]
=
[]
seq_lens
:
L
ist
[
int
]
=
[]
seq_lens
:
l
ist
[
int
]
=
[]
for
i
in
range
(
batch_size
):
for
i
in
range
(
batch_size
):
seq_group_metadata_list
.
append
(
seq_group_metadata_list
.
append
(
SequenceGroupMetadata
(
SequenceGroupMetadata
(
...
@@ -669,10 +677,10 @@ def test_sampler_repetition_penalty_mixed(device: str):
...
@@ -669,10 +677,10 @@ def test_sampler_repetition_penalty_mixed(device: str):
vocab_size
=
8
vocab_size
=
8
def
test_sampling_params
(
sampling_params
:
L
ist
[
SamplingParams
]):
def
test_sampling_params
(
sampling_params
:
l
ist
[
SamplingParams
]):
seq_group_metadata_list
:
L
ist
[
SequenceGroupMetadata
]
=
[]
seq_group_metadata_list
:
l
ist
[
SequenceGroupMetadata
]
=
[]
seq_lens
:
L
ist
[
int
]
=
[]
seq_lens
:
l
ist
[
int
]
=
[]
for
i
in
range
(
2
):
for
i
in
range
(
2
):
seq_group_metadata_list
.
append
(
seq_group_metadata_list
.
append
(
SequenceGroupMetadata
(
SequenceGroupMetadata
(
...
...
tests/samplers/test_seeded_generate.py
View file @
469e903b
...
@@ -19,7 +19,9 @@ RANDOM_SEEDS = list(range(5))
...
@@ -19,7 +19,9 @@ RANDOM_SEEDS = list(range(5))
@
pytest
.
fixture
@
pytest
.
fixture
def
vllm_model
(
vllm_runner
):
def
vllm_model
(
vllm_runner
,
monkeypatch
):
# This file relies on V0 internals.
monkeypatch
.
setenv
(
"VLLM_USE_V1"
,
"0"
)
with
vllm_runner
(
MODEL
,
dtype
=
"half"
)
as
vllm_model
:
with
vllm_runner
(
MODEL
,
dtype
=
"half"
)
as
vllm_model
:
yield
vllm_model
yield
vllm_model
...
...
tests/samplers/test_typical_acceptance_sampler.py
View file @
469e903b
...
@@ -11,6 +11,14 @@ from vllm.model_executor.utils import set_random_seed
...
@@ -11,6 +11,14 @@ from vllm.model_executor.utils import set_random_seed
CUDA_DEVICES
=
[
f
"cuda:
{
i
}
"
for
i
in
range
(
1
)]
CUDA_DEVICES
=
[
f
"cuda:
{
i
}
"
for
i
in
range
(
1
)]
@
pytest
.
fixture
(
scope
=
"function"
,
autouse
=
True
)
def
use_v0_only
(
monkeypatch
):
"""
This file tests V0 internals, so set VLLM_USE_V1=0.
"""
monkeypatch
.
setenv
(
'VLLM_USE_V1'
,
'0'
)
def
get_zero_temperature_prob_dist
(
batch_size
,
k
,
vocab_size
):
def
get_zero_temperature_prob_dist
(
batch_size
,
k
,
vocab_size
):
"""
"""
Generates a fake temperature zero probability distribution.
Generates a fake temperature zero probability distribution.
...
...
tests/spec_decode/conftest.py
0 → 100644
View file @
469e903b
# SPDX-License-Identifier: Apache-2.0
import
pytest
@
pytest
.
fixture
(
scope
=
"function"
,
autouse
=
True
)
def
use_v0_only
(
monkeypatch
):
"""
Since this module is V0 only, set VLLM_USE_V1=0 for
all tests in the module.
"""
monkeypatch
.
setenv
(
'VLLM_USE_V1'
,
'0'
)
tests/spec_decode/e2e/conftest.py
View file @
469e903b
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
from
collections.abc
import
Sequence
from
itertools
import
cycle
from
itertools
import
cycle
from
typing
import
List
,
Optional
,
Sequence
,
Tuple
,
Union
from
typing
import
Optional
,
Union
import
pytest
import
pytest
import
torch
import
torch
...
@@ -55,7 +56,7 @@ def test_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs,
...
@@ -55,7 +56,7 @@ def test_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs,
def
maybe_assert_ngram_worker
(
llm
):
def
maybe_assert_ngram_worker
(
llm
):
# Verify the proposer worker is ngram if ngram is specified.
# Verify the proposer worker is ngram if ngram is specified.
if
(
llm
.
llm_engine
.
speculative_config
is
not
None
if
(
llm
.
llm_engine
.
speculative_config
is
not
None
and
llm
.
llm_engine
.
speculative_config
.
ngram_prompt_lookup_max
>
0
):
and
llm
.
llm_engine
.
speculative_config
.
method
==
"ngram"
):
from
vllm.spec_decode.ngram_worker
import
NGramWorker
from
vllm.spec_decode.ngram_worker
import
NGramWorker
assert
isinstance
(
assert
isinstance
(
llm
.
llm_engine
.
model_executor
.
driver_worker
.
proposer_worker
,
llm
.
llm_engine
.
model_executor
.
driver_worker
.
proposer_worker
,
...
@@ -64,9 +65,9 @@ def maybe_assert_ngram_worker(llm):
...
@@ -64,9 +65,9 @@ def maybe_assert_ngram_worker(llm):
def
get_output_from_llm_generator
(
def
get_output_from_llm_generator
(
llm_generator
,
prompts
,
llm_generator
,
prompts
,
sampling_params
)
->
T
uple
[
L
ist
[
str
],
L
ist
[
L
ist
[
int
]],
float
]:
sampling_params
)
->
t
uple
[
l
ist
[
str
],
l
ist
[
l
ist
[
int
]],
float
]:
tokens
:
L
ist
[
str
]
=
[]
tokens
:
l
ist
[
str
]
=
[]
token_ids
:
L
ist
[
L
ist
[
int
]]
=
[]
token_ids
:
l
ist
[
l
ist
[
int
]]
=
[]
acceptance_rate
:
float
=
-
1.0
acceptance_rate
:
float
=
-
1.0
for
llm
in
llm_generator
():
for
llm
in
llm_generator
():
maybe_assert_ngram_worker
(
llm
)
maybe_assert_ngram_worker
(
llm
)
...
...
Prev
1
…
23
24
25
26
27
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment