Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
e661d594
Commit
e661d594
authored
Aug 12, 2024
by
zhuwenwen
Browse files
Merge tag 'v0.5.4' into v0.5.4-dtk24.04.1
parents
6b16ea2e
4db5176d
Changes
374
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
569 additions
and
301 deletions
+569
-301
requirements-cpu.txt
requirements-cpu.txt
+2
-3
requirements-cuda.txt
requirements-cuda.txt
+4
-4
requirements-openvino.txt
requirements-openvino.txt
+27
-3
requirements-test.txt
requirements-test.txt
+4
-1
requirements-tpu.txt
requirements-tpu.txt
+1
-1
setup.py
setup.py
+21
-22
tests/basic_correctness/test_cpu_offload.py
tests/basic_correctness/test_cpu_offload.py
+37
-6
tests/conftest.py
tests/conftest.py
+9
-11
tests/core/block/e2e/test_correctness.py
tests/core/block/e2e/test_correctness.py
+1
-1
tests/core/test_scheduler.py
tests/core/test_scheduler.py
+68
-95
tests/distributed/test_basic_distributed_correctness.py
tests/distributed/test_basic_distributed_correctness.py
+33
-17
tests/distributed/test_chunked_prefill_distributed.py
tests/distributed/test_chunked_prefill_distributed.py
+14
-21
tests/distributed/test_multimodal_broadcast.py
tests/distributed/test_multimodal_broadcast.py
+29
-31
tests/distributed/test_parallel_state.py
tests/distributed/test_parallel_state.py
+0
-57
tests/distributed/test_pipeline_parallel.py
tests/distributed/test_pipeline_parallel.py
+61
-16
tests/distributed/test_pipeline_partition.py
tests/distributed/test_pipeline_partition.py
+34
-0
tests/entrypoints/conftest.py
tests/entrypoints/conftest.py
+21
-1
tests/entrypoints/llm/test_guided_generate.py
tests/entrypoints/llm/test_guided_generate.py
+142
-0
tests/entrypoints/openai/test_chat.py
tests/entrypoints/openai/test_chat.py
+32
-8
tests/entrypoints/openai/test_completion.py
tests/entrypoints/openai/test_completion.py
+29
-3
No files found.
requirements-cpu.txt
View file @
e661d594
...
@@ -2,6 +2,5 @@
...
@@ -2,6 +2,5 @@
-r requirements-common.txt
-r requirements-common.txt
# Dependencies for x86_64 CPUs
# Dependencies for x86_64 CPUs
torch == 2.3.1+cpu; platform_machine != "ppc64le"
torch == 2.4.0+cpu; platform_machine != "ppc64le"
torchvision == 0.18.1+cpu; platform_machine != "ppc64le" # required for the image processor of phi3v, this must be updated alongside torch
torchvision; platform_machine != "ppc64le" # required for the image processor of phi3v, this must be updated alongside torch
triton >= 2.2.0 # FIXME(woosuk): This is a hack to avoid import error.
requirements-cuda.txt
View file @
e661d594
...
@@ -4,8 +4,8 @@
...
@@ -4,8 +4,8 @@
# Dependencies for NVIDIA GPUs
# Dependencies for NVIDIA GPUs
ray >= 2.9
ray >= 2.9
nvidia-ml-py # for pynvml package
nvidia-ml-py # for pynvml package
torch == 2.
3.1
torch == 2.
4.0
# These must be updated alongside torch
# These must be updated alongside torch
torchvision == 0.1
8.1
# Required for phi3v processor. See https://github.com/pytorch/vision?tab=readme-ov-file#installation for corresponding version
torchvision == 0.1
9
# Required for phi3v processor. See https://github.com/pytorch/vision?tab=readme-ov-file#installation for corresponding version
xformers == 0.0.27 # Requires PyTorch 2.
3.1
xformers == 0.0.27
.post2
# Requires PyTorch 2.
4.0
vllm-flash-attn == 2.
5.9.post
1 # Requires PyTorch 2.
3.1
vllm-flash-attn == 2.
6.
1 # Requires PyTorch 2.
4.0
requirements-openvino.txt
View file @
e661d594
# Common dependencies
# Common dependencies
-r requirements-common.txt
# -r requirements-common.txt
# TODO: remove temporary copy of all common dependencies once Optimum Intel will support Transformers >= 4.43.2
cmake >= 3.21
ninja # For faster builds.
psutil
sentencepiece # Required for LLaMA tokenizer.
numpy < 2.0.0
requests
tqdm
py-cpuinfo
transformers < 4.43
tokenizers >= 0.19.1 # Required for Llama 3.
fastapi
aiohttp
openai
uvicorn[standard]
pydantic >= 2.0 # Required for OpenAI server.
pillow # Required for image processing
prometheus_client >= 0.18.0
prometheus-fastapi-instrumentator >= 7.0.0
tiktoken >= 0.6.0 # Required for DBRX tokenizer
lm-format-enforcer == 0.10.3
outlines >= 0.0.43, < 0.1 # Requires torch >= 2.1.0
typing_extensions
filelock >= 3.10.4 # filelock starts to support `mode` argument from 3.10.4
pyzmq
# OpenVINO dependencies
# OpenVINO dependencies
torch >= 2.1.2
torch >= 2.1.2
openvino ~= 2024.3.0.dev
openvino ~= 2024.3.0.dev
openvino-tokenizers[transformers] ~= 2024.3.0.0.dev
optimum-intel[openvino] >= 1.18.1
optimum-intel[openvino] >= 1.18.1
triton >= 2.2.0 # FIXME(woosuk): This is a hack to avoid import error.
requirements-test.txt
View file @
e661d594
# Needed for Ray accelerated DAG tests
-r requirements-adag.txt
# testing
# testing
pytest
pytest
tensorizer>=2.9.0
tensorizer>=2.9.0
...
@@ -14,8 +17,8 @@ peft
...
@@ -14,8 +17,8 @@ peft
requests
requests
ray
ray
sentence-transformers # required for embedding
sentence-transformers # required for embedding
sparseml==1.8.0 # required for compressed-tensors
compressed-tensors==0.4.0 # required for compressed-tensors
compressed-tensors==0.4.0 # required for compressed-tensors
timm # required for internvl test
# Benchmarking
# Benchmarking
aiohttp
aiohttp
...
...
requirements-tpu.txt
View file @
e661d594
...
@@ -4,4 +4,4 @@
...
@@ -4,4 +4,4 @@
# Dependencies for TPU
# Dependencies for TPU
# Currently, the TPU backend uses a nightly version of PyTorch XLA.
# Currently, the TPU backend uses a nightly version of PyTorch XLA.
# You can install the dependencies in Dockerfile.tpu.
# You can install the dependencies in Dockerfile.tpu.
triton # To avoid import errors
ray
setup.py
View file @
e661d594
...
@@ -188,9 +188,6 @@ class cmake_build_ext(build_ext):
...
@@ -188,9 +188,6 @@ class cmake_build_ext(build_ext):
# match.
# match.
cmake_args
+=
[
'-DVLLM_PYTHON_EXECUTABLE={}'
.
format
(
sys
.
executable
)]
cmake_args
+=
[
'-DVLLM_PYTHON_EXECUTABLE={}'
.
format
(
sys
.
executable
)]
if
_install_punica
():
cmake_args
+=
[
'-DVLLM_INSTALL_PUNICA_KERNELS=ON'
]
#
#
# Setup parallelism and build tool
# Setup parallelism and build tool
#
#
...
@@ -281,8 +278,8 @@ def _build_custom_ops() -> bool:
...
@@ -281,8 +278,8 @@ def _build_custom_ops() -> bool:
return
_is_cuda
()
or
_is_hip
()
or
_is_cpu
()
return
_is_cuda
()
or
_is_hip
()
or
_is_cpu
()
def
_
install_punica
()
->
bool
:
def
_
build_core_ext
()
->
bool
:
return
envs
.
VLLM_INSTALL_PUNICA_KERNELS
return
not
_is_neuron
()
and
not
_is_tpu
()
def
get_hipcc_rocm_version
():
def
get_hipcc_rocm_version
():
...
@@ -388,19 +385,20 @@ def get_version_add(sha: Optional[str] = None) -> str:
...
@@ -388,19 +385,20 @@ def get_version_add(sha: Optional[str] = None) -> str:
version
+=
".torch"
+
torch
.
__version__
[:
5
]
version
+=
".torch"
+
torch
.
__version__
[:
5
]
new_version_content
=
f
"""
new_version_content
=
f
"""
import warnings
import warnings
try:
try:
import vllm.commit_id
import vllm.commit_id
__commit__ = vllm.commit_id.__commit__
__commit__ = vllm.commit_id.__commit__
except Exception as e:
except Exception as e:
warnings.warn(f"Failed to read commit hash:
\\
n + str(e)",
warnings.warn(f"Failed to read commit hash:
\n
+ str(e)",
RuntimeWarning,
RuntimeWarning,
stacklevel=2)
stacklevel=2)
__commit__ = "COMMIT_HASH_PLACEHOLDER"
__commit__ = "COMMIT_HASH_PLACEHOLDER"
__version__ = "0.5.3.post1"
__version__ = "0.5.4"
__dcu_version__ = f'0.5.3.post1+
{
version
}
'
__dcu_version__ = f'0.5.4+
{
version
}
'
"""
"""
with
open
(
add_version_path
,
encoding
=
"utf-8"
,
mode
=
"w"
)
as
file
:
with
open
(
add_version_path
,
encoding
=
"utf-8"
,
mode
=
"w"
)
as
file
:
...
@@ -507,15 +505,15 @@ def get_requirements() -> List[str]:
...
@@ -507,15 +505,15 @@ def get_requirements() -> List[str]:
ext_modules
=
[]
ext_modules
=
[]
if
_build_core_ext
():
ext_modules
.
append
(
CMakeExtension
(
name
=
"vllm._core_C"
))
if
_is_cuda
()
or
_is_hip
():
if
_is_cuda
()
or
_is_hip
():
ext_modules
.
append
(
CMakeExtension
(
name
=
"vllm._moe_C"
))
ext_modules
.
append
(
CMakeExtension
(
name
=
"vllm._moe_C"
))
if
_build_custom_ops
():
if
_build_custom_ops
():
ext_modules
.
append
(
CMakeExtension
(
name
=
"vllm._C"
))
ext_modules
.
append
(
CMakeExtension
(
name
=
"vllm._C"
))
if
_install_punica
():
ext_modules
.
append
(
CMakeExtension
(
name
=
"vllm._punica_C"
))
package_data
=
{
package_data
=
{
"vllm"
:
[
"py.typed"
,
"model_executor/layers/fused_moe/configs/*.json"
]
"vllm"
:
[
"py.typed"
,
"model_executor/layers/fused_moe/configs/*.json"
]
}
}
...
@@ -542,6 +540,7 @@ setup(
...
@@ -542,6 +540,7 @@ setup(
"Programming Language :: Python :: 3.9"
,
"Programming Language :: Python :: 3.9"
,
"Programming Language :: Python :: 3.10"
,
"Programming Language :: Python :: 3.10"
,
"Programming Language :: Python :: 3.11"
,
"Programming Language :: Python :: 3.11"
,
"Programming Language :: Python :: 3.12"
,
"License :: OSI Approved :: Apache Software License"
,
"License :: OSI Approved :: Apache Software License"
,
"Topic :: Scientific/Engineering :: Artificial Intelligence"
,
"Topic :: Scientific/Engineering :: Artificial Intelligence"
,
],
],
...
@@ -553,7 +552,7 @@ setup(
...
@@ -553,7 +552,7 @@ setup(
extras_require
=
{
extras_require
=
{
"tensorizer"
:
[
"tensorizer>=2.9.0"
],
"tensorizer"
:
[
"tensorizer>=2.9.0"
],
},
},
cmdclass
=
{
"build_ext"
:
cmake_build_ext
}
if
_build_custom_ops
()
else
{},
cmdclass
=
{
"build_ext"
:
cmake_build_ext
}
if
len
(
ext_modules
)
>
0
else
{},
package_data
=
package_data
,
package_data
=
package_data
,
entry_points
=
{
entry_points
=
{
"console_scripts"
:
[
"console_scripts"
:
[
...
...
tests/basic_correctness/test_cpu_offload.py
View file @
e661d594
from
vllm.utils
import
is_hip
import
pytest
from
tests.quantization.utils
import
is_quant_method_supported
from
..utils
import
compare_two_settings
from
..utils
import
compare_two_settings
...
@@ -6,8 +8,37 @@ from ..utils import compare_two_settings
...
@@ -6,8 +8,37 @@ from ..utils import compare_two_settings
def
test_cpu_offload
():
def
test_cpu_offload
():
compare_two_settings
(
"meta-llama/Llama-2-7b-hf"
,
[],
compare_two_settings
(
"meta-llama/Llama-2-7b-hf"
,
[],
[
"--cpu-offload-gb"
,
"4"
])
[
"--cpu-offload-gb"
,
"4"
])
if
not
is_hip
():
# compressed-tensors quantization is currently not supported in ROCm.
compare_two_settings
(
@
pytest
.
mark
.
skipif
(
not
is_quant_method_supported
(
"fp8"
),
"nm-testing/llama7b-one-shot-2_4-w4a16-marlin24-t"
,
[],
reason
=
"fp8 is not supported on this GPU type."
)
[
"--cpu-offload-gb"
,
"1"
])
def
test_cpu_offload_fp8
():
# Test quantization of an unquantized checkpoint
compare_two_settings
(
"meta-llama/Meta-Llama-3-8B-Instruct"
,
[
"--quantization"
,
"fp8"
],
[
"--quantization"
,
"fp8"
,
"--cpu-offload-gb"
,
"2"
])
# Test loading a quantized checkpoint
compare_two_settings
(
"neuralmagic/Meta-Llama-3-8B-Instruct-FP8"
,
[],
[
"--cpu-offload-gb"
,
"2"
])
@
pytest
.
mark
.
skipif
(
not
is_quant_method_supported
(
"awq"
),
reason
=
"awq is not supported on this GPU type."
)
def
test_cpu_offload_awq
():
compare_two_settings
(
"casperhansen/llama-3-8b-instruct-awq"
,
[],
[
"--cpu-offload-gb"
,
"2"
])
@
pytest
.
mark
.
skipif
(
not
is_quant_method_supported
(
"gptq_marlin"
),
reason
=
"gptq_marlin is not supported on this GPU type."
)
def
test_cpu_offload_compressed_tensors
():
# Test wNa16
compare_two_settings
(
"nm-testing/tinyllama-oneshot-w4a16-channel-v2"
,
[],
[
"--cpu-offload-gb"
,
"1"
])
# Test w4a16_marlin24
compare_two_settings
(
"nm-testing/llama7b-one-shot-2_4-w4a16-marlin24-t"
,
[],
[
"--cpu-offload-gb"
,
"1"
])
# Test w8a8
compare_two_settings
(
"nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change"
,
[],
[
"--cpu-offload-gb"
,
"1"
])
tests/conftest.py
View file @
e661d594
...
@@ -3,7 +3,7 @@ import gc
...
@@ -3,7 +3,7 @@ import gc
import
os
import
os
import
sys
import
sys
from
collections
import
UserList
from
collections
import
UserList
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Tuple
,
TypedDict
,
TypeVar
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Tuple
,
TypedDict
,
TypeVar
,
Union
import
pytest
import
pytest
import
torch
import
torch
...
@@ -11,7 +11,7 @@ import torch.nn as nn
...
@@ -11,7 +11,7 @@ import torch.nn as nn
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
from
PIL
import
Image
from
PIL
import
Image
from
transformers
import
(
AutoModelForCausalLM
,
AutoModelForVision2Seq
,
from
transformers
import
(
AutoModelForCausalLM
,
AutoModelForVision2Seq
,
AutoTokenizer
,
BatchEncoding
)
AutoTokenizer
,
BatchEncoding
,
BatchFeature
)
from
vllm
import
LLM
,
SamplingParams
from
vllm
import
LLM
,
SamplingParams
from
vllm.assets.image
import
ImageAsset
from
vllm.assets.image
import
ImageAsset
...
@@ -133,7 +133,7 @@ def image_assets() -> _ImageAssets:
...
@@ -133,7 +133,7 @@ def image_assets() -> _ImageAssets:
return
IMAGE_ASSETS
return
IMAGE_ASSETS
_T
=
TypeVar
(
"_T"
,
nn
.
Module
,
torch
.
Tensor
,
BatchEncoding
)
_T
=
TypeVar
(
"_T"
,
nn
.
Module
,
torch
.
Tensor
,
BatchEncoding
,
BatchFeature
)
class
HfRunner
:
class
HfRunner
:
...
@@ -152,7 +152,6 @@ class HfRunner:
...
@@ -152,7 +152,6 @@ class HfRunner:
model_kwargs
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
model_kwargs
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
is_embedding_model
:
bool
=
False
,
is_embedding_model
:
bool
=
False
,
is_vision_model
:
bool
=
False
,
is_vision_model
:
bool
=
False
,
is_sparseml_model
:
bool
=
False
,
)
->
None
:
)
->
None
:
torch_dtype
=
STR_DTYPE_TO_TORCH_DTYPE
[
dtype
]
torch_dtype
=
STR_DTYPE_TO_TORCH_DTYPE
[
dtype
]
...
@@ -169,9 +168,6 @@ class HfRunner:
...
@@ -169,9 +168,6 @@ class HfRunner:
else
:
else
:
if
is_vision_model
:
if
is_vision_model
:
auto_cls
=
AutoModelForVision2Seq
auto_cls
=
AutoModelForVision2Seq
elif
is_sparseml_model
:
from
sparseml.transformers
import
SparseAutoModelForCausalLM
auto_cls
=
SparseAutoModelForCausalLM
else
:
else
:
auto_cls
=
AutoModelForCausalLM
auto_cls
=
AutoModelForCausalLM
...
@@ -339,7 +335,6 @@ class HfRunner:
...
@@ -339,7 +335,6 @@ class HfRunner:
processor_kwargs
[
"images"
]
=
images
[
i
]
processor_kwargs
[
"images"
]
=
images
[
i
]
inputs
=
self
.
processor
(
**
processor_kwargs
)
inputs
=
self
.
processor
(
**
processor_kwargs
)
input_ids
=
inputs
.
input_ids
output
=
self
.
model
.
generate
(
output
=
self
.
model
.
generate
(
**
self
.
wrap_device
(
inputs
),
**
self
.
wrap_device
(
inputs
),
...
@@ -381,7 +376,7 @@ class HfRunner:
...
@@ -381,7 +376,7 @@ class HfRunner:
all_logprobs
.
append
(
seq_logprobs_lst
)
all_logprobs
.
append
(
seq_logprobs_lst
)
seq_ids
=
output
.
sequences
[
0
]
seq_ids
=
output
.
sequences
[
0
]
output_len
=
seq_ids
.
shape
[
0
]
-
input_ids
.
shape
[
1
]
output_len
=
len
(
seq_logprobs_lst
)
output_ids
=
seq_ids
[
-
output_len
:]
output_ids
=
seq_ids
[
-
output_len
:]
all_output_ids
.
append
(
output_ids
.
tolist
())
all_output_ids
.
append
(
output_ids
.
tolist
())
all_output_strs
.
append
(
self
.
tokenizer
.
decode
(
output_ids
))
all_output_strs
.
append
(
self
.
tokenizer
.
decode
(
output_ids
))
...
@@ -513,11 +508,14 @@ class VllmRunner:
...
@@ -513,11 +508,14 @@ class VllmRunner:
prompts
:
List
[
str
],
prompts
:
List
[
str
],
max_tokens
:
int
,
max_tokens
:
int
,
num_logprobs
:
int
,
num_logprobs
:
int
,
images
:
Optional
[
List
[
Image
.
Image
]]
=
None
,
images
:
Optional
[
Union
[
List
[
Image
.
Image
],
List
[
List
[
Image
.
Image
]]]]
=
None
,
stop_token_ids
:
Optional
[
List
[
int
]]
=
None
,
)
->
List
[
Tuple
[
List
[
int
],
str
,
Optional
[
SampleLogprobs
]]]:
)
->
List
[
Tuple
[
List
[
int
],
str
,
Optional
[
SampleLogprobs
]]]:
greedy_logprobs_params
=
SamplingParams
(
temperature
=
0.0
,
greedy_logprobs_params
=
SamplingParams
(
temperature
=
0.0
,
max_tokens
=
max_tokens
,
max_tokens
=
max_tokens
,
logprobs
=
num_logprobs
)
logprobs
=
num_logprobs
,
stop_token_ids
=
stop_token_ids
)
outputs
=
self
.
generate_w_logprobs
(
prompts
,
outputs
=
self
.
generate_w_logprobs
(
prompts
,
greedy_logprobs_params
,
greedy_logprobs_params
,
images
=
images
)
images
=
images
)
...
...
tests/core/block/e2e/test_correctness.py
View file @
e661d594
...
@@ -183,7 +183,7 @@ def test_v1_v2_greedy_equality_with_cow(baseline_llm_generator,
...
@@ -183,7 +183,7 @@ def test_v1_v2_greedy_equality_with_cow(baseline_llm_generator,
# Allow only 2 sequences of ~128 tokens in worst case.
# Allow only 2 sequences of ~128 tokens in worst case.
# Note 16 = 128/block_size
# Note 16 = 128/block_size
"num_gpu_blocks_override"
:
2
*
(
16
+
1
),
"num_gpu_blocks_override"
:
2
*
(
16
+
2
),
}
}
])
])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{
...
...
tests/core/test_scheduler.py
View file @
e661d594
import
time
import
time
from
collections
import
deque
from
collections
import
deque
from
typing
import
Deque
,
List
,
Set
,
Tuple
from
typing
import
List
,
Set
,
Tuple
from
unittest.mock
import
MagicMock
from
unittest.mock
import
MagicMock
import
pytest
# noqa
import
pytest
# noqa
from
vllm.config
import
CacheConfig
,
LoRAConfig
,
SchedulerConfig
from
vllm.config
import
CacheConfig
,
LoRAConfig
,
SchedulerConfig
from
vllm.core.interfaces
import
AllocStatus
from
vllm.core.interfaces
import
AllocStatus
from
vllm.core.policy
import
PolicyFactory
from
vllm.core.scheduler
import
Scheduler
,
SchedulingBudget
from
vllm.core.scheduler
import
Scheduler
,
SchedulingBudget
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
from
vllm.sequence
import
Logprob
,
SequenceGroup
,
SequenceStatus
from
vllm.sequence
import
Logprob
,
SequenceGroup
,
SequenceStatus
...
@@ -348,10 +347,10 @@ def test_prefill_schedule_max_prompt_len():
...
@@ -348,10 +347,10 @@ def test_prefill_schedule_max_prompt_len():
"""
"""
scheduler
=
initialize_scheduler
(
max_model_len
=
30
)
scheduler
=
initialize_scheduler
(
max_model_len
=
30
)
_
,
seq_group
=
create_dummy_prompt
(
"0"
,
prompt_length
=
60
)
_
,
seq_group
=
create_dummy_prompt
(
"0"
,
prompt_length
=
60
)
waiting
=
deque
([
seq_group
]
)
scheduler
.
add_seq_group
(
seq_group
)
budget
=
create_token_budget
()
budget
=
create_token_budget
()
remaining_waiting
,
output
=
scheduler
.
_schedule_prefills
(
output
=
scheduler
.
_schedule_prefills
(
budget
,
None
)
waiting
,
budget
,
None
)
remaining_waiting
=
scheduler
.
waiting
assert
len
(
output
.
ignored_seq_groups
)
==
1
assert
len
(
output
.
ignored_seq_groups
)
==
1
assert
len
(
output
.
seq_groups
)
==
0
assert
len
(
output
.
seq_groups
)
==
0
assert
budget
.
num_batched_tokens
==
0
assert
budget
.
num_batched_tokens
==
0
...
@@ -364,15 +363,14 @@ def test_prefill_schedule_token_budget():
...
@@ -364,15 +363,14 @@ def test_prefill_schedule_token_budget():
Test token budget respected.
Test token budget respected.
"""
"""
scheduler
=
initialize_scheduler
()
scheduler
=
initialize_scheduler
()
waiting
:
Deque
[
SequenceGroup
]
=
deque
()
budget
=
create_token_budget
(
token_budget
=
0
)
budget
=
create_token_budget
(
token_budget
=
0
)
for
i
in
range
(
2
):
for
i
in
range
(
2
):
_
,
seq_group
=
create_dummy_prompt
(
str
(
i
),
prompt_length
=
60
)
_
,
seq_group
=
create_dummy_prompt
(
str
(
i
),
prompt_length
=
60
)
waiting
.
append
(
seq_group
)
scheduler
.
add_seq_group
(
seq_group
)
# 0 token budget == nothing is scheduled.
# 0 token budget == nothing is scheduled.
remaining_waiting
,
output
=
scheduler
.
_schedule_prefills
(
output
=
scheduler
.
_schedule_prefills
(
budget
,
None
)
waiting
,
budget
,
None
)
remaining_waiting
=
scheduler
.
waiting
assert
len
(
output
.
ignored_seq_groups
)
==
0
assert
len
(
output
.
ignored_seq_groups
)
==
0
assert
len
(
output
.
seq_groups
)
==
0
assert
len
(
output
.
seq_groups
)
==
0
assert
budget
.
num_batched_tokens
==
0
assert
budget
.
num_batched_tokens
==
0
...
@@ -381,8 +379,8 @@ def test_prefill_schedule_token_budget():
...
@@ -381,8 +379,8 @@ def test_prefill_schedule_token_budget():
# 60 token budget == 1 request scheduled.
# 60 token budget == 1 request scheduled.
budget
=
create_token_budget
(
token_budget
=
60
)
budget
=
create_token_budget
(
token_budget
=
60
)
remaining_waiting
,
output
=
scheduler
.
_schedule_prefills
(
output
=
scheduler
.
_schedule_prefills
(
budget
,
None
)
waiting
,
budget
,
None
)
remaining_waiting
=
scheduler
.
waiting
assert
len
(
output
.
ignored_seq_groups
)
==
0
assert
len
(
output
.
ignored_seq_groups
)
==
0
assert
len
(
output
.
seq_groups
)
==
1
assert
len
(
output
.
seq_groups
)
==
1
assert
budget
.
num_batched_tokens
==
60
assert
budget
.
num_batched_tokens
==
60
...
@@ -391,14 +389,13 @@ def test_prefill_schedule_token_budget():
...
@@ -391,14 +389,13 @@ def test_prefill_schedule_token_budget():
# Test when current_batched_tokens respected.
# Test when current_batched_tokens respected.
scheduler
=
initialize_scheduler
()
scheduler
=
initialize_scheduler
()
waiting
=
deque
()
budget
=
create_token_budget
(
token_budget
=
60
)
budget
=
create_token_budget
(
token_budget
=
60
)
add_token_budget
(
budget
,
30
,
0
)
add_token_budget
(
budget
,
30
,
0
)
_
,
seq_group
=
create_dummy_prompt
(
str
(
i
),
prompt_length
=
60
)
_
,
seq_group
=
create_dummy_prompt
(
str
(
i
),
prompt_length
=
60
)
# Cannot schedule a prompt that doesn't fit the budget.
# Cannot schedule a prompt that doesn't fit the budget.
waiting
.
append
(
seq_group
)
scheduler
.
add_seq_group
(
seq_group
)
remaining_waiting
,
output
=
scheduler
.
_schedule_prefills
(
output
=
scheduler
.
_schedule_prefills
(
budget
,
None
)
waiting
,
budget
,
None
)
remaining_waiting
=
scheduler
.
waiting
assert
len
(
output
.
ignored_seq_groups
)
==
0
assert
len
(
output
.
ignored_seq_groups
)
==
0
assert
len
(
output
.
seq_groups
)
==
0
assert
len
(
output
.
seq_groups
)
==
0
assert
budget
.
num_batched_tokens
==
30
assert
budget
.
num_batched_tokens
==
30
...
@@ -406,8 +403,8 @@ def test_prefill_schedule_token_budget():
...
@@ -406,8 +403,8 @@ def test_prefill_schedule_token_budget():
assert
len
(
remaining_waiting
)
==
1
assert
len
(
remaining_waiting
)
==
1
budget
=
create_token_budget
(
token_budget
=
90
)
budget
=
create_token_budget
(
token_budget
=
90
)
add_token_budget
(
budget
,
30
,
0
)
add_token_budget
(
budget
,
30
,
0
)
remaining_waiting
,
output
=
scheduler
.
_schedule_prefills
(
output
=
scheduler
.
_schedule_prefills
(
budget
,
None
)
waiting
,
budget
,
None
)
remaining_waiting
=
scheduler
.
waiting
assert
len
(
output
.
seq_groups
)
==
1
assert
len
(
output
.
seq_groups
)
==
1
assert
budget
.
num_batched_tokens
==
90
assert
budget
.
num_batched_tokens
==
90
assert
budget
.
num_curr_seqs
==
1
assert
budget
.
num_curr_seqs
==
1
...
@@ -419,13 +416,12 @@ def test_prefill_schedule_max_seqs():
...
@@ -419,13 +416,12 @@ def test_prefill_schedule_max_seqs():
Test max seq respected.
Test max seq respected.
"""
"""
scheduler
=
initialize_scheduler
()
scheduler
=
initialize_scheduler
()
waiting
:
Deque
[
SequenceGroup
]
=
deque
()
budget
=
create_token_budget
(
max_num_seqs
=
2
)
budget
=
create_token_budget
(
max_num_seqs
=
2
)
for
i
in
range
(
3
):
for
i
in
range
(
3
):
_
,
seq_group
=
create_dummy_prompt
(
str
(
i
),
prompt_length
=
60
)
_
,
seq_group
=
create_dummy_prompt
(
str
(
i
),
prompt_length
=
60
)
waiting
.
append
(
seq_group
)
scheduler
.
add_seq_group
(
seq_group
)
remaining_waiting
,
output
=
scheduler
.
_schedule_prefills
(
output
=
scheduler
.
_schedule_prefills
(
budget
,
None
)
waiting
,
budget
,
None
)
remaining_waiting
=
scheduler
.
waiting
assert
len
(
output
.
ignored_seq_groups
)
==
0
assert
len
(
output
.
ignored_seq_groups
)
==
0
assert
len
(
output
.
seq_groups
)
==
2
assert
len
(
output
.
seq_groups
)
==
2
assert
budget
.
num_batched_tokens
==
120
assert
budget
.
num_batched_tokens
==
120
...
@@ -433,13 +429,13 @@ def test_prefill_schedule_max_seqs():
...
@@ -433,13 +429,13 @@ def test_prefill_schedule_max_seqs():
assert
len
(
remaining_waiting
)
==
1
assert
len
(
remaining_waiting
)
==
1
# Verify curr_num_seqs respected.
# Verify curr_num_seqs respected.
waiting
=
deque
()
scheduler
.
waiting
=
deque
()
budget
=
create_token_budget
(
max_num_seqs
=
2
)
budget
=
create_token_budget
(
max_num_seqs
=
2
)
add_token_budget
(
budget
,
0
,
2
)
add_token_budget
(
budget
,
0
,
2
)
_
,
seq_group
=
create_dummy_prompt
(
str
(
i
),
prompt_length
=
60
)
_
,
seq_group
=
create_dummy_prompt
(
str
(
i
),
prompt_length
=
60
)
waiting
.
append
(
seq_group
)
scheduler
.
add_seq_group
(
seq_group
)
remaining_waiting
,
output
=
scheduler
.
_schedule_prefills
(
output
=
scheduler
.
_schedule_prefills
(
budget
,
None
)
waiting
,
budget
,
None
)
remaining_waiting
=
scheduler
.
waiting
assert
len
(
output
.
ignored_seq_groups
)
==
0
assert
len
(
output
.
ignored_seq_groups
)
==
0
assert
len
(
output
.
seq_groups
)
==
0
assert
len
(
output
.
seq_groups
)
==
0
assert
budget
.
num_batched_tokens
==
0
assert
budget
.
num_batched_tokens
==
0
...
@@ -453,7 +449,6 @@ def test_prefill_schedule_max_lora():
...
@@ -453,7 +449,6 @@ def test_prefill_schedule_max_lora():
"""
"""
lora_config
=
LoRAConfig
(
max_lora_rank
=
8
,
max_loras
=
1
)
lora_config
=
LoRAConfig
(
max_lora_rank
=
8
,
max_loras
=
1
)
scheduler
=
initialize_scheduler
(
lora_config
=
lora_config
)
scheduler
=
initialize_scheduler
(
lora_config
=
lora_config
)
waiting
:
Deque
[
SequenceGroup
]
=
deque
()
budget
=
create_token_budget
(
token_budget
=
120
)
budget
=
create_token_budget
(
token_budget
=
120
)
curr_loras
:
Set
[
int
]
=
set
()
curr_loras
:
Set
[
int
]
=
set
()
for
i
in
range
(
2
):
for
i
in
range
(
2
):
...
@@ -463,7 +458,7 @@ def test_prefill_schedule_max_lora():
...
@@ -463,7 +458,7 @@ def test_prefill_schedule_max_lora():
lora_name
=
str
(
i
),
lora_name
=
str
(
i
),
lora_int_id
=
i
+
1
,
lora_int_id
=
i
+
1
,
lora_path
=
"abc"
))
lora_path
=
"abc"
))
waiting
.
append
(
seq_group
)
scheduler
.
add_seq_group
(
seq_group
)
# Add two more requests to verify lora is prioritized.
# Add two more requests to verify lora is prioritized.
# 0: Lora, 1: Lora, 2: regular, 3: regular
# 0: Lora, 1: Lora, 2: regular, 3: regular
# In the first iteration, index 0, 2 is scheduled.
# In the first iteration, index 0, 2 is scheduled.
...
@@ -471,10 +466,10 @@ def test_prefill_schedule_max_lora():
...
@@ -471,10 +466,10 @@ def test_prefill_schedule_max_lora():
# prioritized. Verify that.
# prioritized. Verify that.
for
i
in
range
(
2
,
4
):
for
i
in
range
(
2
,
4
):
_
,
seq_group
=
create_dummy_prompt
(
str
(
i
),
prompt_length
=
60
)
_
,
seq_group
=
create_dummy_prompt
(
str
(
i
),
prompt_length
=
60
)
waiting
.
append
(
seq_group
)
scheduler
.
add_seq_group
(
seq_group
)
# Schedule 2 requests (0 and 2)
# Schedule 2 requests (0 and 2)
remaining_waiting
,
output
=
scheduler
.
_schedule_prefills
(
output
=
scheduler
.
_schedule_prefills
(
budget
,
curr_loras
)
waiting
,
budget
,
curr_loras
)
remaining_waiting
=
scheduler
.
waiting
assert
len
(
output
.
ignored_seq_groups
)
==
0
assert
len
(
output
.
ignored_seq_groups
)
==
0
assert
len
(
output
.
seq_groups
)
==
2
assert
len
(
output
.
seq_groups
)
==
2
assert
budget
.
num_batched_tokens
==
120
assert
budget
.
num_batched_tokens
==
120
...
@@ -485,8 +480,8 @@ def test_prefill_schedule_max_lora():
...
@@ -485,8 +480,8 @@ def test_prefill_schedule_max_lora():
# Reset curr_loras so that it can be scheduled.
# Reset curr_loras so that it can be scheduled.
curr_loras
=
set
()
curr_loras
=
set
()
budget
=
create_token_budget
(
token_budget
=
60
)
budget
=
create_token_budget
(
token_budget
=
60
)
remaining_waiting
,
output
=
scheduler
.
_schedule_prefills
(
output
=
scheduler
.
_schedule_prefills
(
budget
,
curr_loras
)
remaining_waiting
,
budget
,
curr_loras
)
remaining_waiting
=
scheduler
.
waiting
assert
len
(
output
.
seq_groups
)
==
1
assert
len
(
output
.
seq_groups
)
==
1
assert
output
.
seq_groups
[
0
].
seq_group
.
request_id
==
"1"
assert
output
.
seq_groups
[
0
].
seq_group
.
request_id
==
"1"
assert
len
(
remaining_waiting
)
==
1
assert
len
(
remaining_waiting
)
==
1
...
@@ -499,31 +494,29 @@ def test_prefill_schedule_no_block_manager_capacity():
...
@@ -499,31 +494,29 @@ def test_prefill_schedule_no_block_manager_capacity():
Test sequence cannot be scheduled due to block manager has no capacity.
Test sequence cannot be scheduled due to block manager has no capacity.
"""
"""
scheduler
=
initialize_scheduler
()
scheduler
=
initialize_scheduler
()
waiting
:
Deque
[
SequenceGroup
]
=
deque
()
budget
=
create_token_budget
()
budget
=
create_token_budget
()
for
i
in
range
(
3
):
for
i
in
range
(
3
):
_
,
seq_group
=
create_dummy_prompt
(
str
(
i
),
prompt_length
=
60
)
_
,
seq_group
=
create_dummy_prompt
(
str
(
i
),
prompt_length
=
60
)
waiting
.
append
(
seq_group
)
scheduler
.
add_seq_group
(
seq_group
)
scheduler
.
block_manager
.
can_allocate
=
MagicMock
()
scheduler
.
block_manager
.
can_allocate
=
MagicMock
()
scheduler
.
block_manager
.
can_allocate
.
return_value
=
AllocStatus
.
LATER
scheduler
.
block_manager
.
can_allocate
.
return_value
=
AllocStatus
.
LATER
remainig_waiting
,
output
=
scheduler
.
_schedule_prefills
(
output
=
scheduler
.
_schedule_prefills
(
budget
,
None
)
waiting
,
budget
,
None
)
remaining_waiting
=
scheduler
.
waiting
assert
len
(
output
.
ignored_seq_groups
)
==
0
assert
len
(
output
.
ignored_seq_groups
)
==
0
assert
len
(
output
.
seq_groups
)
==
0
assert
len
(
output
.
seq_groups
)
==
0
assert
budget
.
num_batched_tokens
==
0
assert
budget
.
num_batched_tokens
==
0
assert
budget
.
num_curr_seqs
==
0
assert
budget
.
num_curr_seqs
==
0
assert
len
(
remainig_waiting
)
==
3
assert
len
(
remaini
n
g_waiting
)
==
3
scheduler
=
initialize_scheduler
()
scheduler
=
initialize_scheduler
()
waiting
=
deque
()
budget
=
create_token_budget
()
budget
=
create_token_budget
()
for
i
in
range
(
3
):
for
i
in
range
(
3
):
_
,
seq_group
=
create_dummy_prompt
(
str
(
i
),
prompt_length
=
60
)
_
,
seq_group
=
create_dummy_prompt
(
str
(
i
),
prompt_length
=
60
)
waiting
.
append
(
seq_group
)
scheduler
.
add_seq_group
(
seq_group
)
scheduler
.
block_manager
.
can_allocate
=
MagicMock
()
scheduler
.
block_manager
.
can_allocate
=
MagicMock
()
scheduler
.
block_manager
.
can_allocate
.
return_value
=
AllocStatus
.
NEVER
scheduler
.
block_manager
.
can_allocate
.
return_value
=
AllocStatus
.
NEVER
remaining_waiting
,
output
=
scheduler
.
_schedule_prefills
(
output
=
scheduler
.
_schedule_prefills
(
budget
,
None
)
waiting
,
budget
,
None
)
remaining_waiting
=
scheduler
.
waiting
assert
len
(
output
.
ignored_seq_groups
)
==
3
assert
len
(
output
.
ignored_seq_groups
)
==
3
assert
len
(
output
.
seq_groups
)
==
0
assert
len
(
output
.
seq_groups
)
==
0
assert
budget
.
num_batched_tokens
==
0
assert
budget
.
num_batched_tokens
==
0
...
@@ -536,14 +529,12 @@ def test_decode_schedule_preempted():
...
@@ -536,14 +529,12 @@ def test_decode_schedule_preempted():
Test decodes cannot be scheduled and preempted.
Test decodes cannot be scheduled and preempted.
"""
"""
scheduler
=
initialize_scheduler
()
scheduler
=
initialize_scheduler
()
running
:
Deque
[
SequenceGroup
]
=
deque
()
policy
=
PolicyFactory
.
get_policy
(
policy_name
=
"fcfs"
)
curr_loras
=
None
curr_loras
=
None
for
i
in
range
(
3
):
for
i
in
range
(
3
):
_
,
seq_group
=
create_dummy_prompt
(
str
(
i
),
prompt_length
=
60
)
_
,
seq_group
=
create_dummy_prompt
(
str
(
i
),
prompt_length
=
60
)
scheduler
.
_allocate_and_set_running
(
seq_group
)
scheduler
.
_allocate_and_set_running
(
seq_group
)
append_new_token_seq_group
(
60
,
seq_group
,
1
)
append_new_token_seq_group
(
60
,
seq_group
,
1
)
running
.
append
(
seq_group
)
scheduler
.
_add_seq_group_to_
running
(
seq_group
)
scheduler
.
block_manager
.
can_append_slots
=
MagicMock
()
scheduler
.
block_manager
.
can_append_slots
=
MagicMock
()
def
cannot_append_second_group
(
seq_group
,
num_lookahead_slots
):
def
cannot_append_second_group
(
seq_group
,
num_lookahead_slots
):
...
@@ -555,8 +546,8 @@ def test_decode_schedule_preempted():
...
@@ -555,8 +546,8 @@ def test_decode_schedule_preempted():
# 1 cannot be scheduled, and the lowest priority (request 2)
# 1 cannot be scheduled, and the lowest priority (request 2)
# should be preempted. 1 will also be preempted.
# should be preempted. 1 will also be preempted.
budget
=
create_token_budget
()
budget
=
create_token_budget
()
remainig_running
,
output
=
scheduler
.
_schedule_running
(
output
=
scheduler
.
_schedule_running
(
budget
,
curr_loras
)
running
,
budget
,
curr_loras
,
policy
)
remainig_running
=
scheduler
.
running
assert
len
(
remainig_running
)
==
0
assert
len
(
remainig_running
)
==
0
assert
len
(
output
.
decode_seq_groups
)
==
1
assert
len
(
output
.
decode_seq_groups
)
==
1
assert
len
(
output
.
prefill_seq_groups
)
==
0
assert
len
(
output
.
prefill_seq_groups
)
==
0
...
@@ -577,14 +568,12 @@ def test_decode_swap_beam_search():
...
@@ -577,14 +568,12 @@ def test_decode_swap_beam_search():
Test best_of > 1 swap out blocks
Test best_of > 1 swap out blocks
"""
"""
scheduler
=
initialize_scheduler
()
scheduler
=
initialize_scheduler
()
running
:
Deque
[
SequenceGroup
]
=
deque
()
policy
=
PolicyFactory
.
get_policy
(
policy_name
=
"fcfs"
)
curr_loras
=
None
curr_loras
=
None
budget
=
create_token_budget
()
budget
=
create_token_budget
()
for
i
in
range
(
3
):
for
i
in
range
(
3
):
_
,
seq_group
=
create_dummy_prompt
(
str
(
i
),
prompt_length
=
60
,
best_of
=
2
)
_
,
seq_group
=
create_dummy_prompt
(
str
(
i
),
prompt_length
=
60
,
best_of
=
2
)
scheduler
.
_allocate_and_set_running
(
seq_group
)
scheduler
.
_allocate_and_set_running
(
seq_group
)
running
.
append
(
seq_group
)
scheduler
.
_add_seq_group_to_
running
(
seq_group
)
append_new_token_seq_group
(
60
,
seq_group
,
1
)
append_new_token_seq_group
(
60
,
seq_group
,
1
)
budget
.
add_num_seqs
(
seq_group
.
request_id
,
budget
.
add_num_seqs
(
seq_group
.
request_id
,
seq_group
.
get_max_num_running_seqs
())
seq_group
.
get_max_num_running_seqs
())
...
@@ -603,8 +592,8 @@ def test_decode_swap_beam_search():
...
@@ -603,8 +592,8 @@ def test_decode_swap_beam_search():
expected_swap_mapping
=
[(
"5"
,
"7"
)]
expected_swap_mapping
=
[(
"5"
,
"7"
)]
scheduler
.
block_manager
.
swap_out
.
return_value
=
expected_swap_mapping
scheduler
.
block_manager
.
swap_out
.
return_value
=
expected_swap_mapping
remainig_running
,
output
=
scheduler
.
_schedule_running
(
output
=
scheduler
.
_schedule_running
(
budget
,
curr_loras
)
running
,
budget
,
curr_loras
,
policy
)
remainig_running
=
scheduler
.
running
assert
len
(
remainig_running
)
==
0
assert
len
(
remainig_running
)
==
0
assert
len
(
output
.
decode_seq_groups
)
==
2
assert
len
(
output
.
decode_seq_groups
)
==
2
assert
len
(
output
.
prefill_seq_groups
)
==
0
assert
len
(
output
.
prefill_seq_groups
)
==
0
...
@@ -628,20 +617,18 @@ def test_schedule_decode_blocks_to_copy_update():
...
@@ -628,20 +617,18 @@ def test_schedule_decode_blocks_to_copy_update():
"""
"""
scheduler
=
initialize_scheduler
()
scheduler
=
initialize_scheduler
()
_
,
seq_group
=
create_dummy_prompt
(
"1"
,
prompt_length
=
60
,
best_of
=
2
)
_
,
seq_group
=
create_dummy_prompt
(
"1"
,
prompt_length
=
60
,
best_of
=
2
)
running
:
Deque
[
SequenceGroup
]
=
deque
()
policy
=
PolicyFactory
.
get_policy
(
policy_name
=
"fcfs"
)
curr_loras
=
None
curr_loras
=
None
scheduler
.
_allocate_and_set_running
(
seq_group
)
scheduler
.
_allocate_and_set_running
(
seq_group
)
append_new_token_seq_group
(
60
,
seq_group
,
1
)
append_new_token_seq_group
(
60
,
seq_group
,
1
)
running
.
append
(
seq_group
)
scheduler
.
_add_seq_group_to_
running
(
seq_group
)
# The last request should be swapped out.
# The last request should be swapped out.
scheduler
.
block_manager
.
append_slots
=
MagicMock
()
scheduler
.
block_manager
.
append_slots
=
MagicMock
()
scheduler
.
block_manager
.
append_slots
.
return_value
=
[(
2
,
3
)]
scheduler
.
block_manager
.
append_slots
.
return_value
=
[(
2
,
3
)]
budget
=
create_token_budget
()
budget
=
create_token_budget
()
remaining_running
,
output
=
scheduler
.
_schedule_running
(
output
=
scheduler
.
_schedule_running
(
budget
,
curr_loras
)
running
,
budget
,
curr_loras
,
policy
)
remaining_running
=
scheduler
.
running
assert
len
(
remaining_running
)
==
0
assert
len
(
remaining_running
)
==
0
assert
len
(
output
.
decode_seq_groups
)
==
1
assert
len
(
output
.
decode_seq_groups
)
==
1
assert
len
(
output
.
prefill_seq_groups
)
==
0
assert
len
(
output
.
prefill_seq_groups
)
==
0
...
@@ -656,19 +643,17 @@ def test_schedule_decode_blocks_to_copy_update():
...
@@ -656,19 +643,17 @@ def test_schedule_decode_blocks_to_copy_update():
def
test_schedule_swapped_simple
():
def
test_schedule_swapped_simple
():
scheduler
=
initialize_scheduler
()
scheduler
=
initialize_scheduler
()
swapped
:
Deque
[
SequenceGroup
]
=
deque
()
policy
=
PolicyFactory
.
get_policy
(
policy_name
=
"fcfs"
)
curr_loras
=
None
curr_loras
=
None
blocks_to_swap_out
:
List
[
Tuple
[
int
,
int
]]
=
[]
blocks_to_swap_out
:
List
[
Tuple
[
int
,
int
]]
=
[]
_
,
seq_group
=
create_dummy_prompt
(
"1"
,
prompt_length
=
60
,
best_of
=
2
)
_
,
seq_group
=
create_dummy_prompt
(
"1"
,
prompt_length
=
60
,
best_of
=
2
)
scheduler
.
_allocate_and_set_running
(
seq_group
)
scheduler
.
_allocate_and_set_running
(
seq_group
)
append_new_token_seq_group
(
60
,
seq_group
,
1
)
append_new_token_seq_group
(
60
,
seq_group
,
1
)
scheduler
.
_swap_out
(
seq_group
,
blocks_to_swap_out
)
scheduler
.
_swap_out
(
seq_group
,
blocks_to_swap_out
)
s
wapped
.
appe
n
d
(
seq_group
)
s
cheduler
.
_add_seq_group_to_sw
apped
(
seq_group
)
budget
=
create_token_budget
()
budget
=
create_token_budget
()
remaining_swapped
,
output
=
scheduler
.
_schedule_swapped
(
output
=
scheduler
.
_schedule_swapped
(
budget
,
curr_loras
)
swapped
,
budget
,
curr_loras
,
policy
)
remaining_swapped
=
scheduler
.
swapped
assert
len
(
remaining_swapped
)
==
0
assert
len
(
remaining_swapped
)
==
0
assert
budget
.
num_batched_tokens
==
1
assert
budget
.
num_batched_tokens
==
1
assert
budget
.
num_curr_seqs
==
2
assert
budget
.
num_curr_seqs
==
2
...
@@ -683,8 +668,6 @@ def test_schedule_swapped_simple():
...
@@ -683,8 +668,6 @@ def test_schedule_swapped_simple():
def
test_schedule_swapped_max_token_budget
():
def
test_schedule_swapped_max_token_budget
():
scheduler
=
initialize_scheduler
()
scheduler
=
initialize_scheduler
()
swapped
:
Deque
[
SequenceGroup
]
=
deque
()
policy
=
PolicyFactory
.
get_policy
(
policy_name
=
"fcfs"
)
curr_loras
=
None
curr_loras
=
None
blocks_to_swap_out
:
List
[
Tuple
[
int
,
int
]]
=
[]
blocks_to_swap_out
:
List
[
Tuple
[
int
,
int
]]
=
[]
for
_
in
range
(
2
):
for
_
in
range
(
2
):
...
@@ -692,11 +675,11 @@ def test_schedule_swapped_max_token_budget():
...
@@ -692,11 +675,11 @@ def test_schedule_swapped_max_token_budget():
scheduler
.
_allocate_and_set_running
(
seq_group
)
scheduler
.
_allocate_and_set_running
(
seq_group
)
append_new_token_seq_group
(
60
,
seq_group
,
1
)
append_new_token_seq_group
(
60
,
seq_group
,
1
)
scheduler
.
_swap_out
(
seq_group
,
blocks_to_swap_out
)
scheduler
.
_swap_out
(
seq_group
,
blocks_to_swap_out
)
s
wapped
.
appe
n
d
(
seq_group
)
s
cheduler
.
_add_seq_group_to_sw
apped
(
seq_group
)
budget
=
create_token_budget
(
token_budget
=
1
)
budget
=
create_token_budget
(
token_budget
=
1
)
remaining_swapped
,
output
=
scheduler
.
_schedule_swapped
(
output
=
scheduler
.
_schedule_swapped
(
budget
,
curr_loras
)
swapped
,
budget
,
curr_loras
,
policy
)
remaining_swapped
=
scheduler
.
swapped
assert
len
(
remaining_swapped
)
==
1
assert
len
(
remaining_swapped
)
==
1
assert
budget
.
num_batched_tokens
==
1
assert
budget
.
num_batched_tokens
==
1
assert
budget
.
num_curr_seqs
==
2
assert
budget
.
num_curr_seqs
==
2
...
@@ -706,8 +689,8 @@ def test_schedule_swapped_max_token_budget():
...
@@ -706,8 +689,8 @@ def test_schedule_swapped_max_token_budget():
# Verify num_batched_tokens are respected.
# Verify num_batched_tokens are respected.
budget
=
create_token_budget
(
token_budget
=
1
)
budget
=
create_token_budget
(
token_budget
=
1
)
add_token_budget
(
budget
,
1
,
0
)
add_token_budget
(
budget
,
1
,
0
)
remaining_swapped
,
output
=
scheduler
.
_schedule_swapped
(
output
=
scheduler
.
_schedule_swapped
(
budget
,
curr_loras
)
remaining_swapped
,
budget
,
curr_loras
,
policy
)
remaining_swapped
=
scheduler
.
swapped
assert
len
(
remaining_swapped
)
==
1
assert
len
(
remaining_swapped
)
==
1
assert
budget
.
num_batched_tokens
==
1
assert
budget
.
num_batched_tokens
==
1
assert
budget
.
num_curr_seqs
==
0
assert
budget
.
num_curr_seqs
==
0
...
@@ -717,8 +700,6 @@ def test_schedule_swapped_max_token_budget():
...
@@ -717,8 +700,6 @@ def test_schedule_swapped_max_token_budget():
def
test_schedule_swapped_max_seqs
():
def
test_schedule_swapped_max_seqs
():
scheduler
=
initialize_scheduler
()
scheduler
=
initialize_scheduler
()
swapped
:
Deque
[
SequenceGroup
]
=
deque
()
policy
=
PolicyFactory
.
get_policy
(
policy_name
=
"fcfs"
)
curr_loras
=
None
curr_loras
=
None
blocks_to_swap_out
:
List
[
Tuple
[
int
,
int
]]
=
[]
blocks_to_swap_out
:
List
[
Tuple
[
int
,
int
]]
=
[]
for
i
in
range
(
4
):
for
i
in
range
(
4
):
...
@@ -726,11 +707,11 @@ def test_schedule_swapped_max_seqs():
...
@@ -726,11 +707,11 @@ def test_schedule_swapped_max_seqs():
scheduler
.
_allocate_and_set_running
(
seq_group
)
scheduler
.
_allocate_and_set_running
(
seq_group
)
append_new_token_seq_group
(
60
,
seq_group
,
1
)
append_new_token_seq_group
(
60
,
seq_group
,
1
)
scheduler
.
_swap_out
(
seq_group
,
blocks_to_swap_out
)
scheduler
.
_swap_out
(
seq_group
,
blocks_to_swap_out
)
s
wapped
.
appe
n
d
(
seq_group
)
s
cheduler
.
_add_seq_group_to_sw
apped
(
seq_group
)
budget
=
create_token_budget
(
max_num_seqs
=
2
)
budget
=
create_token_budget
(
max_num_seqs
=
2
)
remaining_swapped
,
output
=
scheduler
.
_schedule_swapped
(
output
=
scheduler
.
_schedule_swapped
(
budget
,
curr_loras
)
swapped
,
budget
,
curr_loras
,
policy
)
remaining_swapped
=
scheduler
.
swapped
assert
len
(
remaining_swapped
)
==
2
assert
len
(
remaining_swapped
)
==
2
assert
budget
.
num_batched_tokens
==
2
assert
budget
.
num_batched_tokens
==
2
assert
budget
.
num_curr_seqs
==
2
assert
budget
.
num_curr_seqs
==
2
...
@@ -738,8 +719,8 @@ def test_schedule_swapped_max_seqs():
...
@@ -738,8 +719,8 @@ def test_schedule_swapped_max_seqs():
assert
len
(
output
.
prefill_seq_groups
)
==
0
assert
len
(
output
.
prefill_seq_groups
)
==
0
# Verify num_curr_seqs are respected.
# Verify num_curr_seqs are respected.
remaining_swapped
,
output
=
scheduler
.
_schedule_swapped
(
output
=
scheduler
.
_schedule_swapped
(
budget
,
curr_loras
)
remaining_swapped
,
budget
,
curr_loras
,
policy
)
remaining_swapped
=
scheduler
.
swapped
assert
len
(
remaining_swapped
)
==
2
assert
len
(
remaining_swapped
)
==
2
assert
budget
.
num_batched_tokens
==
2
assert
budget
.
num_batched_tokens
==
2
assert
budget
.
num_curr_seqs
==
2
assert
budget
.
num_curr_seqs
==
2
...
@@ -750,8 +731,6 @@ def test_schedule_swapped_max_seqs():
...
@@ -750,8 +731,6 @@ def test_schedule_swapped_max_seqs():
def
test_schedule_swapped_max_loras
():
def
test_schedule_swapped_max_loras
():
lora_config
=
LoRAConfig
(
max_lora_rank
=
8
,
max_loras
=
1
)
lora_config
=
LoRAConfig
(
max_lora_rank
=
8
,
max_loras
=
1
)
scheduler
=
initialize_scheduler
(
lora_config
=
lora_config
)
scheduler
=
initialize_scheduler
(
lora_config
=
lora_config
)
swapped
:
Deque
[
SequenceGroup
]
=
deque
()
policy
=
PolicyFactory
.
get_policy
(
policy_name
=
"fcfs"
)
curr_loras
:
Set
[
int
]
=
set
()
curr_loras
:
Set
[
int
]
=
set
()
blocks_to_swap_out
:
List
[
Tuple
[
int
,
int
]]
=
[]
blocks_to_swap_out
:
List
[
Tuple
[
int
,
int
]]
=
[]
for
i
in
range
(
2
):
for
i
in
range
(
2
):
...
@@ -764,11 +743,11 @@ def test_schedule_swapped_max_loras():
...
@@ -764,11 +743,11 @@ def test_schedule_swapped_max_loras():
scheduler
.
_allocate_and_set_running
(
seq_group
)
scheduler
.
_allocate_and_set_running
(
seq_group
)
append_new_token_seq_group
(
60
,
seq_group
,
1
)
append_new_token_seq_group
(
60
,
seq_group
,
1
)
scheduler
.
_swap_out
(
seq_group
,
blocks_to_swap_out
)
scheduler
.
_swap_out
(
seq_group
,
blocks_to_swap_out
)
s
wapped
.
appe
n
d
(
seq_group
)
s
cheduler
.
_add_seq_group_to_sw
apped
(
seq_group
)
budget
=
create_token_budget
()
budget
=
create_token_budget
()
remaining_swapped
,
output
=
scheduler
.
_schedule_swapped
(
output
=
scheduler
.
_schedule_swapped
(
budget
,
curr_loras
)
swapped
,
budget
,
curr_loras
,
policy
)
remaining_swapped
=
scheduler
.
swapped
assert
len
(
remaining_swapped
)
==
1
assert
len
(
remaining_swapped
)
==
1
assert
budget
.
num_batched_tokens
==
1
assert
budget
.
num_batched_tokens
==
1
assert
budget
.
num_curr_seqs
==
1
assert
budget
.
num_curr_seqs
==
1
...
@@ -779,8 +758,6 @@ def test_schedule_swapped_max_loras():
...
@@ -779,8 +758,6 @@ def test_schedule_swapped_max_loras():
def
test_schedule_swapped_cannot_swap_in
():
def
test_schedule_swapped_cannot_swap_in
():
scheduler
=
initialize_scheduler
()
scheduler
=
initialize_scheduler
()
swapped
:
Deque
[
SequenceGroup
]
=
deque
()
policy
=
PolicyFactory
.
get_policy
(
policy_name
=
"fcfs"
)
curr_loras
=
None
curr_loras
=
None
blocks_to_swap_out
:
List
[
Tuple
[
int
,
int
]]
=
[]
blocks_to_swap_out
:
List
[
Tuple
[
int
,
int
]]
=
[]
for
_
in
range
(
2
):
for
_
in
range
(
2
):
...
@@ -788,15 +765,15 @@ def test_schedule_swapped_cannot_swap_in():
...
@@ -788,15 +765,15 @@ def test_schedule_swapped_cannot_swap_in():
scheduler
.
_allocate_and_set_running
(
seq_group
)
scheduler
.
_allocate_and_set_running
(
seq_group
)
append_new_token_seq_group
(
60
,
seq_group
,
1
)
append_new_token_seq_group
(
60
,
seq_group
,
1
)
scheduler
.
_swap_out
(
seq_group
,
blocks_to_swap_out
)
scheduler
.
_swap_out
(
seq_group
,
blocks_to_swap_out
)
s
wapped
.
appe
n
d
(
seq_group
)
s
cheduler
.
_add_seq_group_to_sw
apped
(
seq_group
)
# The last request should be swapped out.
# The last request should be swapped out.
scheduler
.
block_manager
.
can_swap_in
=
MagicMock
()
scheduler
.
block_manager
.
can_swap_in
=
MagicMock
()
scheduler
.
block_manager
.
can_swap_in
.
return_value
=
AllocStatus
.
LATER
scheduler
.
block_manager
.
can_swap_in
.
return_value
=
AllocStatus
.
LATER
# Since we cannot swap in, none of the requests are swapped in.
# Since we cannot swap in, none of the requests are swapped in.
budget
=
create_token_budget
()
budget
=
create_token_budget
()
remaining_swapped
,
output
=
scheduler
.
_schedule_swapped
(
output
=
scheduler
.
_schedule_swapped
(
budget
,
curr_loras
)
swapped
,
budget
,
curr_loras
,
policy
)
remaining_swapped
=
scheduler
.
swapped
assert
len
(
remaining_swapped
)
==
2
assert
len
(
remaining_swapped
)
==
2
assert
budget
.
num_batched_tokens
==
0
assert
budget
.
num_batched_tokens
==
0
assert
budget
.
num_curr_seqs
==
0
assert
budget
.
num_curr_seqs
==
0
...
@@ -806,8 +783,6 @@ def test_schedule_swapped_cannot_swap_in():
...
@@ -806,8 +783,6 @@ def test_schedule_swapped_cannot_swap_in():
def
test_infeasible_swap
():
def
test_infeasible_swap
():
scheduler
=
initialize_scheduler
()
scheduler
=
initialize_scheduler
()
swapped
:
Deque
[
SequenceGroup
]
=
deque
()
policy
=
PolicyFactory
.
get_policy
(
policy_name
=
"fcfs"
)
curr_loras
=
None
curr_loras
=
None
blocks_to_swap_out
:
List
[
Tuple
[
int
,
int
]]
=
[]
blocks_to_swap_out
:
List
[
Tuple
[
int
,
int
]]
=
[]
for
_
in
range
(
2
):
for
_
in
range
(
2
):
...
@@ -815,15 +790,15 @@ def test_infeasible_swap():
...
@@ -815,15 +790,15 @@ def test_infeasible_swap():
scheduler
.
_allocate_and_set_running
(
seq_group
)
scheduler
.
_allocate_and_set_running
(
seq_group
)
append_new_token_seq_group
(
60
,
seq_group
,
1
)
append_new_token_seq_group
(
60
,
seq_group
,
1
)
scheduler
.
_swap_out
(
seq_group
,
blocks_to_swap_out
)
scheduler
.
_swap_out
(
seq_group
,
blocks_to_swap_out
)
s
wapped
.
appe
n
d
(
seq_group
)
s
cheduler
.
_add_seq_group_to_sw
apped
(
seq_group
)
# The last request should be swapped out.
# The last request should be swapped out.
scheduler
.
block_manager
.
can_swap_in
=
MagicMock
()
scheduler
.
block_manager
.
can_swap_in
=
MagicMock
()
scheduler
.
block_manager
.
can_swap_in
.
return_value
=
AllocStatus
.
NEVER
scheduler
.
block_manager
.
can_swap_in
.
return_value
=
AllocStatus
.
NEVER
# Since we cannot swap in, none of the requests are swapped in.
# Since we cannot swap in, none of the requests are swapped in.
budget
=
create_token_budget
()
budget
=
create_token_budget
()
remaining_swapped
,
output
=
scheduler
.
_schedule_swapped
(
output
=
scheduler
.
_schedule_swapped
(
budget
,
curr_loras
)
swapped
,
budget
,
curr_loras
,
policy
)
remaining_swapped
=
scheduler
.
swapped
assert
len
(
remaining_swapped
)
==
0
assert
len
(
remaining_swapped
)
==
0
assert
len
(
output
.
infeasible_seq_groups
)
==
2
assert
len
(
output
.
infeasible_seq_groups
)
==
2
assert
budget
.
num_batched_tokens
==
0
assert
budget
.
num_batched_tokens
==
0
...
@@ -834,23 +809,21 @@ def test_infeasible_swap():
...
@@ -834,23 +809,21 @@ def test_infeasible_swap():
def
test_schedule_swapped_blocks_to_copy
():
def
test_schedule_swapped_blocks_to_copy
():
scheduler
=
initialize_scheduler
()
scheduler
=
initialize_scheduler
()
swapped
:
Deque
[
SequenceGroup
]
=
deque
()
policy
=
PolicyFactory
.
get_policy
(
policy_name
=
"fcfs"
)
curr_loras
=
None
curr_loras
=
None
_
,
seq_group
=
create_dummy_prompt
(
"1"
,
prompt_length
=
60
,
best_of
=
2
)
_
,
seq_group
=
create_dummy_prompt
(
"1"
,
prompt_length
=
60
,
best_of
=
2
)
scheduler
.
_allocate_and_set_running
(
seq_group
)
scheduler
.
_allocate_and_set_running
(
seq_group
)
append_new_token_seq_group
(
60
,
seq_group
,
1
)
append_new_token_seq_group
(
60
,
seq_group
,
1
)
blocks_to_swap_out
:
List
[
Tuple
[
int
,
int
]]
=
[]
blocks_to_swap_out
:
List
[
Tuple
[
int
,
int
]]
=
[]
scheduler
.
_swap_out
(
seq_group
,
blocks_to_swap_out
)
scheduler
.
_swap_out
(
seq_group
,
blocks_to_swap_out
)
s
wapped
.
appe
n
d
(
seq_group
)
s
cheduler
.
_add_seq_group_to_sw
apped
(
seq_group
)
# The last request should be swapped out.
# The last request should be swapped out.
scheduler
.
block_manager
.
append_slots
=
MagicMock
()
scheduler
.
block_manager
.
append_slots
=
MagicMock
()
scheduler
.
block_manager
.
append_slots
.
return_value
=
[(
2
,
3
)]
scheduler
.
block_manager
.
append_slots
.
return_value
=
[(
2
,
3
)]
budget
=
create_token_budget
()
budget
=
create_token_budget
()
remaining_swapped
,
output
=
scheduler
.
_schedule_swapped
(
output
=
scheduler
.
_schedule_swapped
(
budget
,
curr_loras
)
swapped
,
budget
,
curr_loras
,
policy
)
remaining_swapped
=
scheduler
.
swapped
assert
len
(
remaining_swapped
)
==
0
assert
len
(
remaining_swapped
)
==
0
assert
len
(
output
.
decode_seq_groups
)
==
1
assert
len
(
output
.
decode_seq_groups
)
==
1
assert
len
(
output
.
prefill_seq_groups
)
==
0
assert
len
(
output
.
prefill_seq_groups
)
==
0
...
...
tests/distributed/test_basic_distributed_correctness.py
View file @
e661d594
"""Compare the outputs of HF and distributed vLLM when using greedy sampling.
"""Compare the outputs of HF and distributed vLLM when using greedy sampling.
vLLM will allocate all the available memory, so we need to run the tests one
by one. The solution is to pass arguments (model name) by environment
variables.
Run:
Run:
```sh
```sh
cd $VLLM_PATH/tests
cd $VLLM_PATH/tests
TEST_DIST_MODEL=facebook/opt-125m pytest
\
pytest distributed/test_basic_distributed_correctness.py
distributed/test_basic_distributed_correctness.py
TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf
\
distributed/test_basic_distributed_correctness.py
```
```
"""
"""
import
os
import
os
...
@@ -19,27 +14,48 @@ import pytest
...
@@ -19,27 +14,48 @@ import pytest
from
vllm.utils
import
cuda_device_count_stateless
from
vllm.utils
import
cuda_device_count_stateless
from
..models.utils
import
check_outputs_equal
from
..models.utils
import
check_outputs_equal
from
..utils
import
fork_new_process_for_each_test
MODELS
=
[
TARGET_TEST_SUITE
=
os
.
environ
.
get
(
"TARGET_TEST_SUITE"
,
"L4"
)
os
.
environ
[
"TEST_DIST_MODEL"
],
]
DISTRIBUTED_EXECUTOR_BACKEND
=
"DISTRIBUTED_EXECUTOR_BACKEND"
@
pytest
.
mark
.
skipif
(
cuda_device_count_stateless
()
<
2
,
@
pytest
.
mark
.
skipif
(
cuda_device_count_stateless
()
<
2
,
reason
=
"Need at least 2 GPUs to run the test."
)
reason
=
"Need at least 2 GPUs to run the test."
)
@
pytest
.
mark
.
parametrize
(
"model"
,
MODELS
)
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"half"
])
"model, distributed_executor_backend, attention_backend, test_suite"
,
[
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
5
])
(
"facebook/opt-125m"
,
"ray"
,
""
,
"L4"
),
(
"facebook/opt-125m"
,
"mp"
,
""
,
"L4"
),
(
"meta-llama/Llama-2-7b-hf"
,
"ray"
,
""
,
"L4"
),
(
"meta-llama/Llama-2-7b-hf"
,
"mp"
,
""
,
"L4"
),
(
"facebook/opt-125m"
,
"ray"
,
""
,
"A100"
),
(
"facebook/opt-125m"
,
"mp"
,
""
,
"A100"
),
(
"facebook/opt-125m"
,
"mp"
,
"FLASHINFER"
,
"A100"
),
(
"meta-llama/Meta-Llama-3-8B"
,
"ray"
,
"FLASHINFER"
,
"A100"
),
])
@
fork_new_process_for_each_test
def
test_models
(
def
test_models
(
hf_runner
,
hf_runner
,
vllm_runner
,
vllm_runner
,
example_prompts
,
example_prompts
,
model
:
str
,
model
:
str
,
dtype
:
str
,
distributed_executor_backend
:
str
,
max_tokens
:
int
,
attention_backend
:
str
,
test_suite
:
str
,
)
->
None
:
)
->
None
:
distributed_executor_backend
=
os
.
getenv
(
DISTRIBUTED_EXECUTOR_BACKEND
)
if
test_suite
!=
TARGET_TEST_SUITE
:
pytest
.
skip
(
f
"Skip test for
{
test_suite
}
"
)
if
model
==
"meta-llama/Llama-2-7b-hf"
and
distributed_executor_backend
==
"ray"
and
attention_backend
==
""
and
test_suite
==
"L4"
:
# noqa
# test ray adag
os
.
environ
[
'VLLM_USE_RAY_SPMD_WORKER'
]
=
"1"
os
.
environ
[
'VLLM_USE_RAY_COMPILED_DAG'
]
=
"1"
if
attention_backend
:
os
.
environ
[
"VLLM_ATTENTION_BACKEND"
]
=
attention_backend
dtype
=
"half"
max_tokens
=
5
# NOTE: take care of the order. run vLLM first, and then run HF.
# NOTE: take care of the order. run vLLM first, and then run HF.
# vLLM needs a fresh new process without cuda initialization.
# vLLM needs a fresh new process without cuda initialization.
...
...
tests/distributed/test_chunked_prefill_distributed.py
View file @
e661d594
"""Compare the outputs of HF and distributed vLLM when using greedy sampling.
"""Compare the outputs of HF and distributed vLLM when using greedy sampling.
vLLM will allocate all the available memory, so we need to run the tests one
by one. The solution is to pass arguments (model name) by environment
variables.
Run:
Run:
```sh
```sh
TEST_DIST_MODEL=facebook/opt-125m pytest
\
pytest test_chunked_prefill_distributed.py
test_chunked_prefill_distributed.py
TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf
\
test_chunked_prefill_distributed.py
```
```
"""
"""
import
os
import
pytest
import
pytest
from
vllm.utils
import
cuda_device_count_stateless
from
vllm.utils
import
cuda_device_count_stateless
from
..models.utils
import
check_outputs_equal
from
..models.utils
import
check_outputs_equal
from
..utils
import
fork_new_process_for_each_test
MODELS
=
[
os
.
environ
[
"TEST_DIST_MODEL"
],
]
DISTRIBUTED_EXECUTOR_BACKEND
=
"DISTRIBUTED_EXECUTOR_BACKEND"
@
pytest
.
mark
.
skipif
(
cuda_device_count_stateless
()
<
2
,
@
pytest
.
mark
.
skipif
(
cuda_device_count_stateless
()
<
2
,
reason
=
"Need at least 2 GPUs to run the test."
)
reason
=
"Need at least 2 GPUs to run the test."
)
@
pytest
.
mark
.
parametrize
(
"model"
,
MODELS
)
@
pytest
.
mark
.
parametrize
(
"model, distributed_executor_backend"
,
[
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"half"
])
(
"facebook/opt-125m"
,
"ray"
),
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
5
])
(
"meta-llama/Llama-2-7b-hf"
,
"ray"
),
@
pytest
.
mark
.
parametrize
(
"chunked_prefill_token_size"
,
[
16
])
(
"facebook/opt-125m"
,
"mp"
),
(
"meta-llama/Llama-2-7b-hf"
,
"mp"
),
])
@
fork_new_process_for_each_test
def
test_models
(
def
test_models
(
hf_runner
,
hf_runner
,
vllm_runner
,
vllm_runner
,
example_prompts
,
example_prompts
,
model
:
str
,
model
:
str
,
dtype
:
str
,
distributed_executor_backend
:
str
,
max_tokens
:
int
,
chunked_prefill_token_size
:
int
,
)
->
None
:
)
->
None
:
distributed_executor_backend
=
os
.
getenv
(
DISTRIBUTED_EXECUTOR_BACKEND
)
dtype
=
"half"
max_tokens
=
5
chunked_prefill_token_size
=
16
# Add a chunked prefill config.
# Add a chunked prefill config.
max_num_seqs
=
min
(
chunked_prefill_token_size
,
256
)
max_num_seqs
=
min
(
chunked_prefill_token_size
,
256
)
...
...
tests/distributed/test_multimodal_broadcast.py
View file @
e661d594
"""Compare the outputs of HF and distributed vLLM when using greedy sampling.
"""Compare the outputs of HF and distributed vLLM when using greedy sampling.
The second test will hang if more than one test is run per command, so we need
to run the tests one by one. The solution is to pass arguments (model name) by
environment variables.
Run:
Run:
```sh
```sh
TEST_DIST_MODEL=llava-hf/llava-1.5-7b-hf
\
pytest -s -v test_multimodal_broadcast.py
test_multimodal_broadcast.py
TEST_DIST_MODEL=microsoft/Phi-3-vision-128k-instruct
\
test_multimodal_broadcast.py
```
```
"""
"""
import
os
import
pytest
import
pytest
from
vllm.utils
import
cuda_device_count_stateless
from
vllm.utils
import
cuda_device_count_stateless
model
=
os
.
environ
[
"TEST_DIST_MODEL"
]
from
..utils
import
fork_new_process_for_each_test
if
model
.
startswith
(
"llava-hf/llava"
):
from
..models.test_llava
import
models
,
run_test
@
pytest
.
mark
.
skipif
(
cuda_device_count_stateless
()
<
2
,
elif
model
.
startswith
(
"microsoft/Phi-3-vision"
):
reason
=
"Need at least 2 GPUs to run the test."
)
from
..models.test_phi3v
import
models
,
run_test
@
pytest
.
mark
.
parametrize
(
"model, distributed_executor_backend"
,
[
else
:
(
"llava-hf/llava-1.5-7b-hf"
,
"ray"
),
raise
NotImplementedError
(
f
"Unsupported model:
{
model
}
"
)
(
"llava-hf/llava-v1.6-mistral-7b-hf"
,
"ray"
),
(
"llava-hf/llava-1.5-7b-hf"
,
"mp"
),
(
"llava-hf/llava-v1.6-mistral-7b-hf"
,
"mp"
),
@
pytest
.
mark
.
parametrize
(
"tensor_parallel_size"
,
[
2
])
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"half"
])
@
fork_new_process_for_each_test
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
128
])
def
test_models
(
hf_runner
,
vllm_runner
,
image_assets
,
model
:
str
,
@
pytest
.
mark
.
parametrize
(
"num_logprobs"
,
[
5
])
distributed_executor_backend
:
str
)
->
None
:
def
test_models
(
hf_runner
,
vllm_runner
,
image_assets
,
tensor_parallel_size
:
int
,
dtype
:
str
,
max_tokens
:
int
,
dtype
=
"half"
num_logprobs
:
int
)
->
None
:
max_tokens
=
5
if
cuda_device_count_stateless
()
<
tensor_parallel_size
:
num_logprobs
=
5
pytest
.
skip
(
tensor_parallel_size
=
2
f
"Need at least
{
tensor_parallel_size
}
GPUs to run the test."
)
if
model
.
startswith
(
"llava-hf/llava-1.5"
):
distributed_executor_backend
=
os
.
getenv
(
"DISTRIBUTED_EXECUTOR_BACKEND"
)
from
..models.test_llava
import
models
,
run_test
elif
model
.
startswith
(
"llava-hf/llava-v1.6"
):
from
..models.test_llava_next
import
models
,
run_test
else
:
raise
NotImplementedError
(
f
"Unsupported model:
{
model
}
"
)
run_test
(
run_test
(
hf_runner
,
hf_runner
,
vllm_runner
,
vllm_runner
,
image_assets
,
image_assets
,
model
=
models
[
0
],
model
=
models
[
0
],
size_factors
=
[
1.0
],
# So that LLaVA-NeXT processor may return nested list
size_factors
=
[
0.25
,
0.5
,
1.0
],
dtype
=
dtype
,
dtype
=
dtype
,
max_tokens
=
max_tokens
,
max_tokens
=
max_tokens
,
num_logprobs
=
num_logprobs
,
num_logprobs
=
num_logprobs
,
...
...
tests/distributed/test_parallel_state.py
deleted
100644 → 0
View file @
6b16ea2e
from
typing
import
Any
,
Dict
import
pytest
import
torch
from
vllm.distributed.parallel_state
import
(
_split_tensor_dict
,
_update_nested_dict
)
def test_split_tensor_dict():
    """Check that `_split_tensor_dict` separates tensors from metadata.

    The input mixes plain values, tensors, a nested dict and an empty dict.
    The test expects six metadata entries, and expects the three tensors to
    come back in the order they appear in the input.
    """
    top_level_tensor = torch.arange(8, dtype=torch.float32)
    nested_tensor = torch.arange(5, dtype=torch.float32)
    empty_tensor = torch.tensor([], dtype=torch.float32)

    test_dict = {
        "key_a": "a",
        "key_b": top_level_tensor,
        "key_c": {
            "key_1": nested_tensor,
            "key_2": empty_tensor,
            "key_3": 123,
        },
        "key_d": {},
    }

    metadata_list, tensor_list = _split_tensor_dict(test_dict)

    assert len(metadata_list) == 6

    # Index explicitly (rather than zip) so a too-short tensor list still
    # fails loudly instead of being silently truncated.
    expected_tensors = (top_level_tensor, nested_tensor, empty_tensor)
    for position, expected in enumerate(expected_tensors):
        assert torch.allclose(tensor_list[position], expected)
def test_split_tensor_dict_invalid_key():
    """Check that a key containing '%' makes `_split_tensor_dict` assert."""
    dict_with_bad_key = {"a%b": "a"}
    with pytest.raises(AssertionError):
        _split_tensor_dict(dict_with_bad_key)
def test_update_nested_dict():
    """Check that '%'-separated flat keys are rebuilt into nested dicts.

    Each (flat_key, value) pair is applied to an initially empty dict via
    `_update_nested_dict`; the final dict must equal the nested structure
    spelled out in `expected`.
    """
    expected: Dict[str, Any] = {
        "key1": {
            "key2": {
                "key3": "value1",
                "key4": "value2",
            },
            "key5": "value3",
        },
        "key6": {
            "key7": "value4",
        },
        "key8": "value5",
    }

    flat_items = [
        ("key1%key2%key3", "value1"),
        ("key1%key2%key4", "value2"),
        ("key1%key5", "value3"),
        ("key6%key7", "value4"),
        ("key8", "value5"),
    ]

    rebuilt: Dict[str, Any] = {}
    for flat_key, value in flat_items:
        _update_nested_dict(rebuilt, flat_key, value)

    assert rebuilt == expected
tests/distributed/test_pipeline_parallel.py
View file @
e661d594
"""
WARNING: This test runs in both single-node (4 GPUs) and multi-node
(2 node with 2 GPUs each) modes. If the test only uses 2 GPUs, it is
important to set the distributed backend to "mp" to avoid Ray scheduling
all workers in a node other than the head node, which can cause the test
to fail.
"""
import
os
import
os
import
pytest
import
pytest
from
..utils
import
compare_two_settings
from
..utils
import
compare_two_settings
,
fork_new_process_for_each_test
VLLM_MULTI_NODE
=
os
.
getenv
(
"VLLM_MULTI_NODE"
,
"0"
)
==
"1"
VLLM_MULTI_NODE
=
os
.
getenv
(
"VLLM_MULTI_NODE"
,
"0"
)
==
"1"
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
(
"TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, "
"TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL,
MODEL_NAME, DIST_BACKEND"
,
"
MODEL_NAME, DIST_BACKEND"
)
,
[
[
(
2
,
2
,
0
,
1
,
"meta-llama/Meta-Llama-3-8B"
,
"ray"
),
(
2
,
2
,
0
,
1
,
"meta-llama/Meta-Llama-3-8B"
,
"ray"
),
(
2
,
2
,
1
,
0
,
"meta-llama/Meta-Llama-3-8B"
,
"ray"
),
(
2
,
2
,
1
,
0
,
"meta-llama/Meta-Llama-3-8B"
,
"ray"
),
(
1
,
3
,
0
,
0
,
"meta-llama/Meta-Llama-3-8B"
,
"ray"
),
(
1
,
3
,
0
,
0
,
"meta-llama/Meta-Llama-3-8B"
,
"ray"
),
(
1
,
4
,
0
,
1
,
"meta-llama/Meta-Llama-3-8B"
,
"ray"
),
(
1
,
4
,
0
,
1
,
"meta-llama/Meta-Llama-3-8B"
,
"ray"
),
(
1
,
4
,
1
,
0
,
"meta-llama/Meta-Llama-3-8B"
,
"ray"
),
(
1
,
4
,
1
,
0
,
"meta-llama/Meta-Llama-3-8B"
,
"ray"
),
(
2
,
2
,
0
,
1
,
"meta-llama/Meta-Llama-3-8B"
,
"mp"
),
(
2
,
2
,
0
,
1
,
"meta-llama/Meta-Llama-3-8B"
,
"mp"
),
(
2
,
2
,
1
,
0
,
"meta-llama/Meta-Llama-3-8B"
,
"mp"
),
(
2
,
2
,
1
,
0
,
"meta-llama/Meta-Llama-3-8B"
,
"mp"
),
(
1
,
3
,
0
,
0
,
"meta-llama/Meta-Llama-3-8B"
,
"mp"
),
(
1
,
3
,
0
,
0
,
"meta-llama/Meta-Llama-3-8B"
,
"mp"
),
(
1
,
4
,
0
,
1
,
"meta-llama/Meta-Llama-3-8B"
,
"mp"
),
(
1
,
4
,
0
,
1
,
"meta-llama/Meta-Llama-3-8B"
,
"mp"
),
(
1
,
4
,
1
,
0
,
"meta-llama/Meta-Llama-3-8B"
,
"mp"
),
(
1
,
4
,
1
,
0
,
"meta-llama/Meta-Llama-3-8B"
,
"mp"
),
])
])
def
test_compare_tp
(
TP_SIZE
,
PP_SIZE
,
EAGER_MODE
,
CHUNKED_PREFILL
,
MODEL_NAME
,
def
test_compare_tp
(
TP_SIZE
,
PP_SIZE
,
EAGER_MODE
,
CHUNKED_PREFILL
,
MODEL_NAME
,
DIST_BACKEND
):
DIST_BACKEND
):
if
VLLM_MULTI_NODE
and
DIST_BACKEND
==
"mp"
:
if
VLLM_MULTI_NODE
and
DIST_BACKEND
==
"mp"
:
pytest
.
skip
(
"Skipping multi-node pipeline parallel test for "
pytest
.
skip
(
"Skipping multi-node pipeline parallel test for "
"multiprocessing distributed backend"
)
"multiprocessing distributed backend"
)
USE_RAY_ADAG_NCCL
=
0
USE_RAY_ADAG
=
0
pp_args
=
[
pp_args
=
[
# use half precision for speed and memory savings in CI environment
# use half precision for speed and memory savings in CI environment
"--dtype"
,
"--dtype"
,
...
@@ -59,5 +69,40 @@ def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME,
...
@@ -59,5 +69,40 @@ def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME,
if
EAGER_MODE
:
if
EAGER_MODE
:
pp_args
.
append
(
"--enforce-eager"
)
pp_args
.
append
(
"--enforce-eager"
)
tp_args
.
append
(
"--enforce-eager"
)
tp_args
.
append
(
"--enforce-eager"
)
pp_env
=
None
if
USE_RAY_ADAG
:
assert
DIST_BACKEND
==
"ray"
,
(
"Ray ADAG is only supported with Ray distributed backend"
)
pp_env
=
{
"VLLM_USE_RAY_COMPILED_DAG"
:
"1"
,
"VLLM_USE_RAY_SPMD_WORKER"
:
"1"
,
"VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL"
:
str
(
int
(
USE_RAY_ADAG_NCCL
)),
}
compare_two_settings
(
MODEL_NAME
,
pp_args
,
tp_args
,
pp_env
)
@
pytest
.
mark
.
parametrize
(
"PP_SIZE, MODEL_NAME"
,
[
(
2
,
"JackFram/llama-160m"
),
])
@
pytest
.
mark
.
parametrize
(
"ATTN_BACKEND"
,
[
"FLASH_ATTN"
,
"FLASHINFER"
,
])
@
fork_new_process_for_each_test
def
test_pp_cudagraph
(
PP_SIZE
,
MODEL_NAME
,
ATTN_BACKEND
):
cudagraph_args
=
[
# use half precision for speed and memory savings in CI environment
"--dtype"
,
"float16"
,
"--pipeline-parallel-size"
,
str
(
PP_SIZE
),
"--distributed-executor-backend"
,
"mp"
,
]
os
.
environ
[
"VLLM_ATTENTION_BACKEND"
]
=
ATTN_BACKEND
eager_args
=
cudagraph_args
+
[
"--enforce-eager"
]
compare_two_settings
(
MODEL_NAME
,
pp
_args
,
tp
_args
)
compare_two_settings
(
MODEL_NAME
,
eager
_args
,
cudagraph
_args
)
tests/distributed/test_pipeline_partition.py
0 → 100644
View file @
e661d594
import
os
import
pytest
from
vllm.distributed.utils
import
get_pp_indices
def test_custom_layer_partition():
    """Exercise ``get_pp_indices`` with ``VLLM_PP_LAYER_PARTITION`` overrides.

    Valid partition strings must produce the expected pair of layer indices
    for every pipeline rank. Malformed strings, a partition count that does
    not match ``pp_size``, and partitions that do not add up to
    ``num_layers`` are all expected to raise ``ValueError``.
    """

    def _verify(partition_str, num_layers, pp_size, goldens):
        """Set the partition env var and compare each rank against goldens.

        The previous value of ``VLLM_PP_LAYER_PARTITION`` is always put back
        (or the variable removed if it was unset), even when the check
        raises, so a failing or intentionally invalid case cannot leak its
        partition string into later checks or other tests in the process.
        """
        bak = os.environ.get("VLLM_PP_LAYER_PARTITION", None)
        os.environ["VLLM_PP_LAYER_PARTITION"] = partition_str
        try:
            for pp_rank, golden in enumerate(goldens):
                assert get_pp_indices(num_layers, pp_rank,
                                      pp_size) == golden
        finally:
            if bak is None:
                os.environ.pop("VLLM_PP_LAYER_PARTITION", None)
            else:
                os.environ["VLLM_PP_LAYER_PARTITION"] = bak

    # Even partition
    _verify("5,5,5,5", 20, 4, [(0, 5), (5, 10), (10, 15), (15, 20)])
    # Balanced partition
    _verify("4,6,6,4", 20, 4, [(0, 4), (4, 10), (10, 16), (16, 20)])
    # Put remainder somewhere
    _verify("5,6,5,6", 22, 4, [(0, 5), (5, 11), (11, 16), (16, 22)])
    # Invalid partition strings
    with pytest.raises(ValueError):
        _verify("5,5,5,5,", 20, 4, [(0, 5), (5, 10), (10, 15), (15, 20)])
    with pytest.raises(ValueError):
        _verify("5,5,5,a", 20, 4, [(0, 5), (5, 10), (10, 15), (15, 20)])
    # Wrong number of partitions
    with pytest.raises(ValueError):
        _verify("5,5,5", 20, 4, [(0, 5), (5, 10), (10, 15), (15, 20)])
    # Wrong number of layers
    with pytest.raises(ValueError):
        _verify("5,5,5,5", 21, 4, [(0, 5), (5, 10), (10, 15), (15, 20)])
tests/entrypoints/
openai/
conftest.py
→
tests/entrypoints/conftest.py
View file @
e661d594
import
pytest
import
pytest
@pytest.fixture
def sample_prompts():
    """Return a fresh list of four short text prompts."""
    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]
    return prompts
@pytest.fixture
def sample_token_ids():
    """Return a fresh list of four token-id sequences of lengths 1 to 4."""
    token_id_prompts = [
        [0],
        [0, 1],
        [0, 2, 1],
        [0, 3, 1, 2],
    ]
    return token_id_prompts
@
pytest
.
fixture
@
pytest
.
fixture
def
sample_regex
():
def
sample_regex
():
return
(
r
"((25[0-5]|(2[0-4]|1\d|[1-9]|)\d)\.){3}"
return
(
r
"((25[0-5]|(2[0-4]|1\d|[1-9]|)\d)\.){3}"
...
@@ -66,4 +86,4 @@ column: "col_1" | "col_2"
...
@@ -66,4 +86,4 @@ column: "col_1" | "col_2"
table: "table_1" | "table_2"
table: "table_1" | "table_2"
condition: column "=" number
condition: column "=" number
number: "1" | "2"
number: "1" | "2"
"""
)
"""
)
\ No newline at end of file
tests/entrypoints/llm/test_guided_generate.py
0 → 100644
View file @
e661d594
import
json
import
re
import
weakref
import
jsonschema
import
pytest
from
vllm.entrypoints.llm
import
LLM
from
vllm.outputs
import
RequestOutput
from
vllm.sampling_params
import
SamplingParams
from
...conftest
import
cleanup
MODEL_NAME
=
"HuggingFaceH4/zephyr-7b-beta"
@pytest.fixture(scope="module")
def llm():
    """Yield a weak proxy to a module-scoped ``LLM`` built from MODEL_NAME.

    The engine is created with ``max_model_len=1024`` and is yielded from
    inside ``deprecate_legacy_api()``; ``cleanup()`` runs on teardown.
    """
    # pytest caches whatever a fixture yields, so only a weakref.proxy is
    # handed out; dropping the local name below then allows the engine to
    # be garbage collected.
    engine = LLM(model=MODEL_NAME, max_model_len=1024)

    with engine.deprecate_legacy_api():
        yield weakref.proxy(engine)

        del engine

    cleanup()
@pytest.mark.skip_global_cleanup
def test_guided_regex(sample_regex, llm):
    """Generate with ``guided_regex`` and check each output fully matches."""
    params = SamplingParams(temperature=0.8, top_p=0.95)
    request_prompt = (
        f"Give an example IPv4 address with this regex: {sample_regex}")

    outputs = llm.generate(
        prompts=[request_prompt] * 2,
        sampling_params=params,
        use_tqdm=True,
        guided_options_request={"guided_regex": sample_regex},
    )

    assert outputs is not None
    for output in outputs:
        assert output is not None
        assert isinstance(output, RequestOutput)

        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(generated_text)

        assert generated_text is not None
        # The whole completion, not just a prefix, has to satisfy the regex.
        matched = re.fullmatch(sample_regex, generated_text)
        assert matched is not None
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
@pytest.mark.skip_global_cleanup
def test_guided_json_completion(sample_json_schema, llm):
    """Generate with ``guided_json`` and validate output against the schema."""
    params = SamplingParams(temperature=1.0, max_tokens=1000)
    request_prompt = (f"Give an example JSON for an employee profile "
                      f"that fits this schema: {sample_json_schema}")

    outputs = llm.generate(
        prompts=[request_prompt] * 2,
        sampling_params=params,
        use_tqdm=True,
        guided_options_request={"guided_json": sample_json_schema},
    )

    assert outputs is not None
    for output in outputs:
        assert output is not None
        assert isinstance(output, RequestOutput)

        prompt = output.prompt
        generated_text = output.outputs[0].text
        assert generated_text is not None
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

        # The completion must parse as JSON and conform to the schema.
        parsed = json.loads(generated_text)
        jsonschema.validate(instance=parsed, schema=sample_json_schema)
@pytest.mark.skip_global_cleanup
def test_guided_choice_completion(sample_guided_choice, llm):
    """Generate with ``guided_choice`` and check the output is a choice."""
    params = SamplingParams(temperature=0.8, top_p=0.95)

    outputs = llm.generate(
        prompts="The best language for type-safe systems programming is ",
        sampling_params=params,
        use_tqdm=True,
        guided_options_request={"guided_choice": sample_guided_choice},
    )

    assert outputs is not None
    for output in outputs:
        assert output is not None
        assert isinstance(output, RequestOutput)

        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(generated_text)

        assert generated_text is not None
        # The completion has to be exactly one of the offered choices.
        assert generated_text in sample_guided_choice
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
@pytest.mark.skip_global_cleanup
def test_guided_grammar(sample_sql_statements, llm):
    """Generate with ``guided_grammar`` and check the SQL text produced."""
    params = SamplingParams(
        temperature=0.8,
        top_p=0.95,
        max_tokens=1000,
    )
    request_prompt = ("Generate a sql state that select col_1 from "
                      "table_1 where it is equals to 1")

    outputs = llm.generate(
        prompts=request_prompt,
        sampling_params=params,
        use_tqdm=True,
        guided_options_request={"guided_grammar": sample_sql_statements},
    )

    assert outputs is not None
    for output in outputs:
        assert output is not None
        assert isinstance(output, RequestOutput)

        prompt = output.prompt
        generated_text = output.outputs[0].text
        assert generated_text is not None

        # use Lark to parse the output, and make sure it's a valid parse tree
        from lark import Lark
        Lark(sample_sql_statements).parse(generated_text)

        # remove spaces for comparison b/c we removed them in the grammar
        expected_sql = "SELECT col_1 from table_1 where col_1 = 1"
        ground_truth = expected_sql.replace(" ", "")

        assert generated_text.strip() == ground_truth

        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
tests/entrypoints/openai/test_chat.py
View file @
e661d594
...
@@ -295,14 +295,19 @@ async def test_chat_completion_stream_options(client: openai.AsyncOpenAI,
...
@@ -295,14 +295,19 @@ async def test_chat_completion_stream_options(client: openai.AsyncOpenAI,
async
for
chunk
in
stream
:
async
for
chunk
in
stream
:
assert
chunk
.
usage
is
None
assert
chunk
.
usage
is
None
# Test stream=True, stream_options={"include_usage": True}
# Test stream=True, stream_options={"include_usage": True,
stream
=
await
client
.
chat
.
completions
.
create
(
# "continuous_usage_stats": False}}
model
=
model_name
,
stream
=
await
client
.
chat
.
completions
.
create
(
model
=
model_name
,
messages
=
messages
,
messages
=
messages
,
max_tokens
=
10
,
max_tokens
=
10
,
temperature
=
0.0
,
temperature
=
0.0
,
stream
=
True
,
stream
=
True
,
stream_options
=
{
"include_usage"
:
True
})
stream_options
=
{
"include_usage"
:
True
,
"continuous_usage_stats"
:
False
})
async
for
chunk
in
stream
:
async
for
chunk
in
stream
:
if
chunk
.
choices
[
0
].
finish_reason
is
None
:
if
chunk
.
choices
[
0
].
finish_reason
is
None
:
...
@@ -338,6 +343,25 @@ async def test_chat_completion_stream_options(client: openai.AsyncOpenAI,
...
@@ -338,6 +343,25 @@ async def test_chat_completion_stream_options(client: openai.AsyncOpenAI,
stream
=
False
,
stream
=
False
,
stream_options
=
{
"include_usage"
:
True
})
stream_options
=
{
"include_usage"
:
True
})
# Test stream=True, stream_options={"include_usage": True,
# "continuous_usage_stats": True}
stream
=
await
client
.
chat
.
completions
.
create
(
model
=
model_name
,
messages
=
messages
,
max_tokens
=
10
,
temperature
=
0.0
,
stream
=
True
,
stream_options
=
{
"include_usage"
:
True
,
"continuous_usage_stats"
:
True
},
)
async
for
chunk
in
stream
:
assert
chunk
.
usage
.
prompt_tokens
>=
0
assert
chunk
.
usage
.
completion_tokens
>=
0
assert
chunk
.
usage
.
total_tokens
==
(
chunk
.
usage
.
prompt_tokens
+
chunk
.
usage
.
completion_tokens
)
# NOTE: Not sure why, but when I place this after `test_guided_regex_chat`
# NOTE: Not sure why, but when I place this after `test_guided_regex_chat`
# (i.e. using the same ordering as in the Completions API tests), the test
# (i.e. using the same ordering as in the Completions API tests), the test
...
...
tests/entrypoints/openai/test_completion.py
View file @
e661d594
...
@@ -55,8 +55,9 @@ def zephyr_pa_files():
...
@@ -55,8 +55,9 @@ def zephyr_pa_files():
@
pytest
.
fixture
(
scope
=
"module"
)
@
pytest
.
fixture
(
scope
=
"module"
)
def
server
(
zephyr_lora_files
,
zephyr_lora_added_tokens_files
,
zephyr_pa_files
):
def
default_server_args
(
zephyr_lora_files
,
zephyr_lora_added_tokens_files
,
args
=
[
zephyr_pa_files
):
return
[
# use half precision for speed and memory savings in CI environment
# use half precision for speed and memory savings in CI environment
"--dtype"
,
"--dtype"
,
"bfloat16"
,
"bfloat16"
,
...
@@ -85,7 +86,10 @@ def server(zephyr_lora_files, zephyr_lora_added_tokens_files, zephyr_pa_files):
...
@@ -85,7 +86,10 @@ def server(zephyr_lora_files, zephyr_lora_added_tokens_files, zephyr_pa_files):
"128"
,
"128"
,
]
]
with
RemoteOpenAIServer
(
MODEL_NAME
,
args
)
as
remote_server
:
@
pytest
.
fixture
(
scope
=
"module"
)
def
server
(
default_server_args
):
with
RemoteOpenAIServer
(
MODEL_NAME
,
default_server_args
)
as
remote_server
:
yield
remote_server
yield
remote_server
...
@@ -537,6 +541,28 @@ async def test_logits_bias(client: openai.AsyncOpenAI):
...
@@ -537,6 +541,28 @@ async def test_logits_bias(client: openai.AsyncOpenAI):
assert
first_response
!=
completion
.
choices
[
0
].
text
assert
first_response
!=
completion
.
choices
[
0
].
text
@pytest.mark.asyncio
async def test_allowed_token_ids(client: openai.AsyncOpenAI):
    """Request one token with ``allowed_token_ids`` and check it is allowed."""
    prompt = "Hello, my name is"
    max_tokens = 1
    tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME)

    # Test exclusive selection
    allowed_ids = [21555, 21557, 21558]
    request_kwargs = {
        "model": MODEL_NAME,
        "prompt": prompt,
        "max_tokens": max_tokens,
        "temperature": 0.0,
        "seed": 42,
        "extra_body": {"allowed_token_ids": allowed_ids},
        "logprobs": 1,
    }
    completion = await client.completions.create(**request_kwargs)

    response_tokens = completion.choices[0].logprobs.tokens
    assert len(response_tokens) == 1
    # Map the returned token string back to its id before comparing.
    first_token_id = tokenizer.convert_tokens_to_ids(response_tokens)[0]
    assert first_token_id in allowed_ids
@
pytest
.
mark
.
asyncio
@
pytest
.
mark
.
asyncio
@
pytest
.
mark
.
parametrize
(
"guided_decoding_backend"
,
@
pytest
.
mark
.
parametrize
(
"guided_decoding_backend"
,
[
"outlines"
,
"lm-format-enforcer"
])
[
"outlines"
,
"lm-format-enforcer"
])
...
...
Prev
1
…
3
4
5
6
7
8
9
10
11
…
19
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment