Commit cc7f22a8 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.9.1' into v0.9.1-ori

parents b9ea0c09 b6553be1
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for the SamplingParams class. """Tests for the SamplingParams class.
""" """
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest import pytest
import torch import torch
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import random import random
import numpy as np import numpy as np
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest import pytest
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import multiprocessing as mp import multiprocessing as mp
import os import os
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import sys import sys
import types import types
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# ruff: noqa # ruff: noqa
import asyncio import asyncio
...@@ -17,7 +18,8 @@ from vllm_test_utils.monitor import monitor ...@@ -17,7 +18,8 @@ from vllm_test_utils.monitor import monitor
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
from vllm.utils import (CacheInfo, FlexibleArgumentParser, LRUCache, from vllm.utils import (CacheInfo, FlexibleArgumentParser, LRUCache,
MemorySnapshot, PlaceholderModule, StoreBoolean, MemorySnapshot, PlaceholderModule, StoreBoolean,
bind_kv_cache, deprecate_kwargs, get_open_port, bind_kv_cache, common_broadcastable_dtype,
deprecate_kwargs, get_open_port, is_lossless_cast,
make_zmq_path, make_zmq_socket, memory_profiling, make_zmq_path, make_zmq_socket, memory_profiling,
merge_async_iterators, sha256, split_zmq_path, merge_async_iterators, sha256, split_zmq_path,
supports_kw, swap_dict_values) supports_kw, swap_dict_values)
...@@ -258,11 +260,18 @@ def test_dict_args(parser): ...@@ -258,11 +260,18 @@ def test_dict_args(parser):
"--model-name=something.something", "--model-name=something.something",
"--hf-overrides.key1", "--hf-overrides.key1",
"val1", "val1",
# Test nesting
"--hf-overrides.key2.key3", "--hf-overrides.key2.key3",
"val2", "val2",
"--hf-overrides.key2.key4", "--hf-overrides.key2.key4",
"val3", "val3",
# Test = sign
"--hf-overrides.key5=val4", "--hf-overrides.key5=val4",
# Test underscore to dash conversion
"--hf_overrides.key_6",
"val5",
"--hf_overrides.key-7.key_8",
"val6",
] ]
parsed_args = parser.parse_args(args) parsed_args = parser.parse_args(args)
assert parsed_args.model_name == "something.something" assert parsed_args.model_name == "something.something"
...@@ -273,6 +282,10 @@ def test_dict_args(parser): ...@@ -273,6 +282,10 @@ def test_dict_args(parser):
"key4": "val3", "key4": "val3",
}, },
"key5": "val4", "key5": "val4",
"key_6": "val5",
"key-7": {
"key_8": "val6",
},
} }
...@@ -567,12 +580,65 @@ def test_lru_cache(): ...@@ -567,12 +580,65 @@ def test_lru_cache():
assert 6 in cache assert 6 in cache
# yapf: disable
@pytest.mark.parametrize(
("src_dtype", "tgt_dtype", "expected_result"),
[
# Different precision_levels
(torch.bool, torch.int8, True),
(torch.bool, torch.float16, True),
(torch.bool, torch.complex32, True),
(torch.int64, torch.bool, False),
(torch.int64, torch.float16, True),
(torch.int64, torch.complex32, True),
(torch.float64, torch.bool, False),
(torch.float64, torch.int8, False),
(torch.float64, torch.complex32, True),
(torch.complex128, torch.bool, False),
(torch.complex128, torch.int8, False),
(torch.complex128, torch.float16, False),
# precision_level=0
(torch.bool, torch.bool, True),
# precision_level=1
(torch.int8, torch.int16, True),
(torch.int16, torch.int8, False),
(torch.uint8, torch.int8, False),
(torch.int8, torch.uint8, False),
# precision_level=2
(torch.float16, torch.float32, True),
(torch.float32, torch.float16, False),
(torch.bfloat16, torch.float32, True),
(torch.float32, torch.bfloat16, False),
# precision_level=3
(torch.complex32, torch.complex64, True),
(torch.complex64, torch.complex32, False),
],
)
# yapf: enable
def test_is_lossless_cast(src_dtype, tgt_dtype, expected_result):
assert is_lossless_cast(src_dtype, tgt_dtype) == expected_result
# yapf: disable
@pytest.mark.parametrize(
("dtypes", "expected_result"),
[
([torch.bool], torch.bool),
([torch.bool, torch.int8], torch.int8),
([torch.bool, torch.int8, torch.float16], torch.float16),
([torch.bool, torch.int8, torch.float16, torch.complex32], torch.complex32), # noqa: E501
],
)
# yapf: enable
def test_common_broadcastable_dtype(dtypes, expected_result):
assert common_broadcastable_dtype(dtypes) == expected_result
def test_placeholder_module_error_handling(): def test_placeholder_module_error_handling():
placeholder = PlaceholderModule("placeholder_1234") placeholder = PlaceholderModule("placeholder_1234")
def build_ctx(): def build_ctx():
return pytest.raises(ModuleNotFoundError, return pytest.raises(ModuleNotFoundError, match="No module named")
match="No module named")
with build_ctx(): with build_ctx():
int(placeholder) int(placeholder)
...@@ -608,6 +674,7 @@ def test_placeholder_module_error_handling(): ...@@ -608,6 +674,7 @@ def test_placeholder_module_error_handling():
_ = placeholder_attr.module _ = placeholder_attr.module
# yapf: disable
@pytest.mark.parametrize( @pytest.mark.parametrize(
"obj,key1,key2", "obj,key1,key2",
[ [
...@@ -618,6 +685,7 @@ def test_placeholder_module_error_handling(): ...@@ -618,6 +685,7 @@ def test_placeholder_module_error_handling():
# Tests for both keys do not exist # Tests for both keys do not exist
({1: "a", 2: "b"}, 3, 4), ({1: "a", 2: "b"}, 3, 4),
]) ])
# yapf: enable
def test_swap_dict_values(obj, key1, key2): def test_swap_dict_values(obj, key1, key2):
original_obj = obj.copy() original_obj = obj.copy()
swap_dict_values(obj, key1, key2) swap_dict_values(obj, key1, key2)
...@@ -631,19 +699,19 @@ def test_swap_dict_values(obj, key1, key2): ...@@ -631,19 +699,19 @@ def test_swap_dict_values(obj, key1, key2):
assert key1 not in obj assert key1 not in obj
def test_model_specification(parser_with_config, def test_model_specification(parser_with_config, cli_config_file,
cli_config_file,
cli_config_file_with_model): cli_config_file_with_model):
# Test model in CLI takes precedence over config # Test model in CLI takes precedence over config
args = parser_with_config.parse_args([ args = parser_with_config.parse_args(
'serve', 'cli-model', '--config', cli_config_file_with_model ['serve', 'cli-model', '--config', cli_config_file_with_model])
])
assert args.model_tag == 'cli-model' assert args.model_tag == 'cli-model'
assert args.served_model_name == 'mymodel' assert args.served_model_name == 'mymodel'
# Test model from config file works # Test model from config file works
args = parser_with_config.parse_args([ args = parser_with_config.parse_args([
'serve', '--config', cli_config_file_with_model, 'serve',
'--config',
cli_config_file_with_model,
]) ])
assert args.model == 'config-model' assert args.model == 'config-model'
assert args.served_model_name == 'mymodel' assert args.served_model_name == 'mymodel'
...@@ -655,16 +723,18 @@ def test_model_specification(parser_with_config, ...@@ -655,16 +723,18 @@ def test_model_specification(parser_with_config,
# Test using --model option raises error # Test using --model option raises error
with pytest.raises( with pytest.raises(
ValueError, ValueError,
match=( match=
"With `vllm serve`, you should provide the model as a positional " ("With `vllm serve`, you should provide the model as a positional "
"argument or in a config file instead of via the `--model` option." "argument or in a config file instead of via the `--model` option."),
),
): ):
parser_with_config.parse_args(['serve', '--model', 'my-model']) parser_with_config.parse_args(['serve', '--model', 'my-model'])
# Test other config values are preserved # Test other config values are preserved
args = parser_with_config.parse_args([ args = parser_with_config.parse_args([
'serve', 'cli-model', '--config', cli_config_file_with_model, 'serve',
'cli-model',
'--config',
cli_config_file_with_model,
]) ])
assert args.tensor_parallel_size == 2 assert args.tensor_parallel_size == 2
assert args.trust_remote_code is True assert args.trust_remote_code is True
...@@ -682,7 +752,8 @@ def test_sha256(input: tuple, output: int): ...@@ -682,7 +752,8 @@ def test_sha256(input: tuple, output: int):
assert hash != 0 assert hash != 0
bytes = pickle.dumps(input, protocol=pickle.HIGHEST_PROTOCOL) bytes = pickle.dumps(input, protocol=pickle.HIGHEST_PROTOCOL)
assert hash == int.from_bytes(hashlib.sha256(bytes).digest(), byteorder="big") assert hash == int.from_bytes(hashlib.sha256(bytes).digest(),
byteorder="big")
# hashing again, returns the same value # hashing again, returns the same value
assert hash == sha256(input) assert hash == sha256(input)
...@@ -698,8 +769,7 @@ def test_sha256(input: tuple, output: int): ...@@ -698,8 +769,7 @@ def test_sha256(input: tuple, output: int):
("tcp://127.0.0.1:5555", ("tcp", "127.0.0.1", "5555")), ("tcp://127.0.0.1:5555", ("tcp", "127.0.0.1", "5555")),
("tcp://[::1]:5555", ("tcp", "::1", "5555")), # IPv6 address ("tcp://[::1]:5555", ("tcp", "::1", "5555")), # IPv6 address
("inproc://some_identifier", ("inproc", "some_identifier", "")), ("inproc://some_identifier", ("inproc", "some_identifier", "")),
] ])
)
def test_split_zmq_path(path, expected): def test_split_zmq_path(path, expected):
assert split_zmq_path(path) == expected assert split_zmq_path(path) == expected
...@@ -711,8 +781,7 @@ def test_split_zmq_path(path, expected): ...@@ -711,8 +781,7 @@ def test_split_zmq_path(path, expected):
"tcp://127.0.0.1", # Missing port "tcp://127.0.0.1", # Missing port
"tcp://[::1]", # Missing port for IPv6 "tcp://[::1]", # Missing port for IPv6
"tcp://:5555", # Missing host "tcp://:5555", # Missing host
] ])
)
def test_split_zmq_path_invalid(invalid_path): def test_split_zmq_path_invalid(invalid_path):
with pytest.raises(ValueError): with pytest.raises(ValueError):
split_zmq_path(invalid_path) split_zmq_path(invalid_path)
...@@ -734,7 +803,8 @@ def test_make_zmq_socket_ipv6(): ...@@ -734,7 +803,8 @@ def test_make_zmq_socket_ipv6():
zsock: zmq.Socket = make_zmq_socket(ctx, ipv6_path, socket_type) zsock: zmq.Socket = make_zmq_socket(ctx, ipv6_path, socket_type)
# Verify that the IPV6 option is set # Verify that the IPV6 option is set
assert zsock.getsockopt(zmq.IPV6) == 1, "IPV6 option should be enabled for IPv6 addresses" assert zsock.getsockopt(
zmq.IPV6) == 1, "IPV6 option should be enabled for IPv6 addresses"
# Clean up # Clean up
zsock.close() zsock.close()
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from unittest.mock import patch from unittest.mock import patch
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os import os
from unittest.mock import patch from unittest.mock import patch
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pickle import pickle
from copy import deepcopy from copy import deepcopy
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Generator from collections.abc import Generator
from typing import Any, Optional from typing import Any, Optional
...@@ -69,7 +70,8 @@ def _run_incremental_decode(tokenizer, ...@@ -69,7 +70,8 @@ def _run_incremental_decode(tokenizer,
None, None,
0.0, 0.0,
None, None,
cache_salt=None) cache_salt=None,
data_parallel_rank=None)
if fast is None: if fast is None:
detokenizer = IncrementalDetokenizer.from_new_request( detokenizer = IncrementalDetokenizer.from_new_request(
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
""" """
This test file includes some cases where it is inappropriate to This test file includes some cases where it is inappropriate to
only get the `eos_token_id` from the tokenizer as defined by only get the `eos_token_id` from the tokenizer as defined by
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest import pytest
from mistral_common.protocol.instruct.messages import (AssistantMessage, from mistral_common.protocol.instruct.messages import (AssistantMessage,
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest import pytest
from transformers import PreTrainedTokenizerBase from transformers import PreTrainedTokenizerBase
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest import pytest
from transformers import AutoTokenizer, PreTrainedTokenizerBase from transformers import AutoTokenizer, PreTrainedTokenizerBase
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import TYPE_CHECKING, Any, Optional, Union from typing import TYPE_CHECKING, Any, Optional, Union
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest import pytest
import pytest_asyncio import pytest_asyncio
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest import pytest
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import openai import openai
import pytest import pytest
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json import json
from collections.abc import Generator from collections.abc import Generator
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment