Unverified commit 2f8844ba, authored by Zhuohan Li and committed by GitHub
Browse files

Re-enable the 80 char line width limit (#3305)

parent 4b59f00e
...@@ -9,6 +9,10 @@ requires = [ ...@@ -9,6 +9,10 @@ requires = [
] ]
build-backend = "setuptools.build_meta" build-backend = "setuptools.build_meta"
[tool.ruff]
# Allow lines to be as long as 80.
line-length = 80
[tool.ruff.lint] [tool.ruff.lint]
select = [ select = [
# pycodestyle # pycodestyle
...@@ -29,8 +33,6 @@ ignore = [ ...@@ -29,8 +33,6 @@ ignore = [
"F405", "F403", "F405", "F403",
# lambda expression assignment # lambda expression assignment
"E731", "E731",
# line too long, handled by black formatting
"E501",
# .strip() with multi-character strings # .strip() with multi-character strings
"B005", "B005",
# Loop control variable not used within loop body # Loop control variable not used within loop body
......
...@@ -142,8 +142,8 @@ def get_pytorch_rocm_arch() -> Set[str]: ...@@ -142,8 +142,8 @@ def get_pytorch_rocm_arch() -> Set[str]:
# If we don't have PYTORCH_ROCM_ARCH specified pull the list from rocm_agent_enumerator # If we don't have PYTORCH_ROCM_ARCH specified pull the list from rocm_agent_enumerator
if env_arch_list is None: if env_arch_list is None:
command = "rocm_agent_enumerator" command = "rocm_agent_enumerator"
env_arch_list = subprocess.check_output([command]).decode('utf-8')\ env_arch_list = (subprocess.check_output(
.strip().replace("\n", ";") [command]).decode('utf-8').strip().replace("\n", ";"))
arch_source_str = "rocm_agent_enumerator" arch_source_str = "rocm_agent_enumerator"
else: else:
arch_source_str = "PYTORCH_ROCM_ARCH env variable" arch_source_str = "PYTORCH_ROCM_ARCH env variable"
......
...@@ -73,7 +73,7 @@ def test_load_chat_template(): ...@@ -73,7 +73,7 @@ def test_load_chat_template():
assert template_content is not None assert template_content is not None
# Hard coded value for template_chatml.jinja # Hard coded value for template_chatml.jinja
assert template_content == """{% for message in messages %}{{'<|im_start|>' + message['role'] + '\\n' + message['content']}}{% if (loop.last and add_generation_prompt) or not loop.last %}{{ '<|im_end|>' + '\\n'}}{% endif %}{% endfor %} assert template_content == """{% for message in messages %}{{'<|im_start|>' + message['role'] + '\\n' + message['content']}}{% if (loop.last and add_generation_prompt) or not loop.last %}{{ '<|im_end|>' + '\\n'}}{% endif %}{% endfor %}
{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}{{ '<|im_start|>assistant\\n' }}{% endif %}""" {% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}{{ '<|im_start|>assistant\\n' }}{% endif %}""" # noqa: E501
def test_no_load_chat_template(): def test_no_load_chat_template():
...@@ -117,4 +117,6 @@ async def test_get_gen_prompt(model, template, add_generation_prompt, ...@@ -117,4 +117,6 @@ async def test_get_gen_prompt(model, template, add_generation_prompt,
add_generation_prompt=mock_request.add_generation_prompt) add_generation_prompt=mock_request.add_generation_prompt)
# Test assertion # Test assertion
assert result == expected_output, f"The generated prompt does not match the expected output for model {model} and template {template}" assert result == expected_output, (
f"The generated prompt does not match the expected output for "
f"model {model} and template {template}")
...@@ -4,7 +4,8 @@ from typing import List ...@@ -4,7 +4,8 @@ from typing import List
from vllm import SamplingParams from vllm import SamplingParams
from vllm.block import PhysicalTokenBlock from vllm.block import PhysicalTokenBlock
from vllm.core.block_manager import BlockAllocator, BlockSpaceManager, AllocStatus from vllm.core.block_manager import (BlockAllocator, BlockSpaceManager,
AllocStatus)
from vllm.utils import Device from vllm.utils import Device
from vllm.sequence import Sequence, SequenceGroup, SequenceStatus, Logprob from vllm.sequence import Sequence, SequenceGroup, SequenceStatus, Logprob
......
...@@ -46,8 +46,8 @@ TEST_SCHEMA = { ...@@ -46,8 +46,8 @@ TEST_SCHEMA = {
"required": ["name", "age", "skills", "work history"] "required": ["name", "age", "skills", "work history"]
} }
TEST_REGEX = r"((25[0-5]|(2[0-4]|1\d|[1-9]|)\d)\.){3}" + \ TEST_REGEX = (r"((25[0-5]|(2[0-4]|1\d|[1-9]|)\d)\.){3}"
r"(25[0-5]|(2[0-4]|1\d|[1-9]|)\d)" r"(25[0-5]|(2[0-4]|1\d|[1-9]|)\d)")
def test_guided_logits_processors(): def test_guided_logits_processors():
......
...@@ -5,9 +5,12 @@ import time ...@@ -5,9 +5,12 @@ import time
import sys import sys
import pytest import pytest
import requests import requests
import ray # using Ray for overall ease of process management, parallel requests, and debugging. # using Ray for overall ease of process management, parallel requests,
# and debugging.
import ray
import openai # use the official client for correctness check import openai # use the official client for correctness check
from huggingface_hub import snapshot_download # downloading lora to test lora requests # downloading lora to test lora requests
from huggingface_hub import snapshot_download
# imports for guided decoding tests # imports for guided decoding tests
import json import json
...@@ -17,8 +20,11 @@ import re ...@@ -17,8 +20,11 @@ import re
from vllm.transformers_utils.tokenizer import get_tokenizer from vllm.transformers_utils.tokenizer import get_tokenizer
MAX_SERVER_START_WAIT_S = 600 # wait for server to start for 600 seconds MAX_SERVER_START_WAIT_S = 600 # wait for server to start for 600 seconds
MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta" # any model with a chat template should work here # any model with a chat template should work here
LORA_NAME = "typeof/zephyr-7b-beta-lora" # technically this needs Mistral-7B-v0.1 as base, but we're not testing generation quality here MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
# technically this needs Mistral-7B-v0.1 as base, but we're not testing
# generation quality here
LORA_NAME = "typeof/zephyr-7b-beta-lora"
TEST_SCHEMA = { TEST_SCHEMA = {
"type": "object", "type": "object",
...@@ -59,8 +65,8 @@ TEST_SCHEMA = { ...@@ -59,8 +65,8 @@ TEST_SCHEMA = {
"required": ["name", "age", "skills", "work history"] "required": ["name", "age", "skills", "work history"]
} }
TEST_REGEX = r"((25[0-5]|(2[0-4]|1\d|[1-9]|)\d)\.){3}" + \ TEST_REGEX = (r"((25[0-5]|(2[0-4]|1\d|[1-9]|)\d)\.){3}"
r"(25[0-5]|(2[0-4]|1\d|[1-9]|)\d)" r"(25[0-5]|(2[0-4]|1\d|[1-9]|)\d)")
TEST_CHOICE = [ TEST_CHOICE = [
"Python", "Java", "JavaScript", "C++", "C#", "PHP", "TypeScript", "Ruby", "Python", "Java", "JavaScript", "C++", "C#", "PHP", "TypeScript", "Ruby",
...@@ -120,8 +126,9 @@ def server(zephyr_lora_files): ...@@ -120,8 +126,9 @@ def server(zephyr_lora_files):
server_runner = ServerRunner.remote([ server_runner = ServerRunner.remote([
"--model", "--model",
MODEL_NAME, MODEL_NAME,
# use half precision for speed and memory savings in CI environment
"--dtype", "--dtype",
"bfloat16", # use half precision for speed and memory savings in CI environment "bfloat16",
"--max-model-len", "--max-model-len",
"8192", "8192",
"--enforce-eager", "--enforce-eager",
...@@ -392,7 +399,8 @@ async def test_batch_completions(server, client: openai.AsyncOpenAI, ...@@ -392,7 +399,8 @@ async def test_batch_completions(server, client: openai.AsyncOpenAI,
max_tokens=5, max_tokens=5,
temperature=0.0, temperature=0.0,
extra_body=dict( extra_body=dict(
# NOTE: this has to be true for n > 1 in vLLM, but not necessary for official client. # NOTE: this has to be true for n > 1 in vLLM, but not necessary
# for official client.
use_beam_search=True), use_beam_search=True),
) )
assert len(batch.choices) == 4 assert len(batch.choices) == 4
...@@ -469,8 +477,8 @@ async def test_logits_bias(server, client: openai.AsyncOpenAI): ...@@ -469,8 +477,8 @@ async def test_logits_bias(server, client: openai.AsyncOpenAI):
async def test_guided_json_completion(server, client: openai.AsyncOpenAI): async def test_guided_json_completion(server, client: openai.AsyncOpenAI):
completion = await client.completions.create( completion = await client.completions.create(
model=MODEL_NAME, model=MODEL_NAME,
prompt= prompt=f"Give an example JSON for an employee profile "
f"Give an example JSON for an employee profile that fits this schema: {TEST_SCHEMA}", f"that fits this schema: {TEST_SCHEMA}",
n=3, n=3,
temperature=1.0, temperature=1.0,
max_tokens=500, max_tokens=500,
...@@ -489,9 +497,11 @@ async def test_guided_json_chat(server, client: openai.AsyncOpenAI): ...@@ -489,9 +497,11 @@ async def test_guided_json_chat(server, client: openai.AsyncOpenAI):
"role": "system", "role": "system",
"content": "you are a helpful assistant" "content": "you are a helpful assistant"
}, { }, {
"role": "user", "role":
"content": "Give an example JSON for an employee profile that " + \ "user",
f"fits this schema: {TEST_SCHEMA}" "content":
f"Give an example JSON for an employee profile that "
f"fits this schema: {TEST_SCHEMA}"
}] }]
chat_completion = await client.chat.completions.create( chat_completion = await client.chat.completions.create(
model=MODEL_NAME, model=MODEL_NAME,
......
...@@ -57,7 +57,8 @@ def test_fused_moe( ...@@ -57,7 +57,8 @@ def test_fused_moe(
[torch.float32, torch.float16, torch.bfloat16]) [torch.float32, torch.float16, torch.bfloat16])
@torch.inference_mode() @torch.inference_mode()
def test_mixtral_moe(dtype: torch.dtype): def test_mixtral_moe(dtype: torch.dtype):
"Make sure our Mixtral MoE implementation agrees with the one from huggingface." """Make sure our Mixtral MoE implementation agrees with the one from
huggingface."""
# Instantiate our and huggingface's MoE blocks # Instantiate our and huggingface's MoE blocks
config = MixtralConfig() config = MixtralConfig()
......
...@@ -114,7 +114,8 @@ def test_contexted_kv_attention( ...@@ -114,7 +114,8 @@ def test_contexted_kv_attention(
v_cache = v_cache.view(-1, block_size, num_kv_heads, v_cache = v_cache.view(-1, block_size, num_kv_heads,
head_size).permute(0, 2, 3, 1).contiguous() head_size).permute(0, 2, 3, 1).contiguous()
# Warm up the Triton kernel by calling it once before actually measuring generation time # Warm up the Triton kernel by calling it once before actually measuring
# generation time
context_attention_fwd(query, k, v, output, k_cache, v_cache, block_table, context_attention_fwd(query, k, v, output, k_cache, v_cache, block_table,
b_start_loc, b_seq_len, b_ctx_len, max_input_len) b_start_loc, b_seq_len, b_ctx_len, max_input_len)
torch.cuda.synchronize() torch.cuda.synchronize()
......
...@@ -11,9 +11,9 @@ from .conftest import cleanup ...@@ -11,9 +11,9 @@ from .conftest import cleanup
MODEL_PATH = "Felladrin/Llama-68M-Chat-v1" MODEL_PATH = "Felladrin/Llama-68M-Chat-v1"
PROMPTS = [ PROMPTS = [
"[system] Given a target sentence construct the underlying meaning representation\nof the input sentence as a single function with attributes and attribute\nvalues. This function should describe the target string accurately and the\nfunction must be one of the following ['inform', 'request', 'give_opinion',\n'confirm', 'verify_attribute', 'suggest', 'request_explanation',\n'recommend', 'request_attribute'].\n\nThe attributes must be one of the following:\n['name', 'exp_release_date', 'release_year', 'developer', 'esrb', 'rating',\n'genres', 'player_perspective', 'has_multiplayer', 'platforms',\n'available_on_steam', 'has_linux_release', 'has_mac_release', 'specifier'] [/system] [user] Here is the target sentence:\nSpellForce 3 is a pretty bad game. The developer Grimlore Games is clearly a bunch of no-talent hacks, and 2017 was a terrible year for games anyway. [/user] [assistant]", "[system] Given a target sentence construct the underlying meaning representation\nof the input sentence as a single function with attributes and attribute\nvalues. This function should describe the target string accurately and the\nfunction must be one of the following ['inform', 'request', 'give_opinion',\n'confirm', 'verify_attribute', 'suggest', 'request_explanation',\n'recommend', 'request_attribute'].\n\nThe attributes must be one of the following:\n['name', 'exp_release_date', 'release_year', 'developer', 'esrb', 'rating',\n'genres', 'player_perspective', 'has_multiplayer', 'platforms',\n'available_on_steam', 'has_linux_release', 'has_mac_release', 'specifier'] [/system] [user] Here is the target sentence:\nSpellForce 3 is a pretty bad game. The developer Grimlore Games is clearly a bunch of no-talent hacks, and 2017 was a terrible year for games anyway. [/user] [assistant]", # noqa: E501
"[system] Given a target sentence construct the underlying meaning representation\nof the input sentence as a single function with attributes and attribute\nvalues. This function should describe the target string accurately and the\nfunction must be one of the following ['inform', 'request', 'give_opinion',\n'confirm', 'verify_attribute', 'suggest', 'request_explanation',\n'recommend', 'request_attribute'].\n\nThe attributes must be one of the following:\n['name', 'exp_release_date', 'release_year', 'developer', 'esrb', 'rating',\n'genres', 'player_perspective', 'has_multiplayer', 'platforms',\n'available_on_steam', 'has_linux_release', 'has_mac_release', 'specifier'] [/system] [user] Here is the target sentence:\nI wanted to like Grimlore Games' 2017 entry, but in SpellForce 3 they just didn't get anything right. [/user] [assistant]", "[system] Given a target sentence construct the underlying meaning representation\nof the input sentence as a single function with attributes and attribute\nvalues. This function should describe the target string accurately and the\nfunction must be one of the following ['inform', 'request', 'give_opinion',\n'confirm', 'verify_attribute', 'suggest', 'request_explanation',\n'recommend', 'request_attribute'].\n\nThe attributes must be one of the following:\n['name', 'exp_release_date', 'release_year', 'developer', 'esrb', 'rating',\n'genres', 'player_perspective', 'has_multiplayer', 'platforms',\n'available_on_steam', 'has_linux_release', 'has_mac_release', 'specifier'] [/system] [user] Here is the target sentence:\nI wanted to like Grimlore Games' 2017 entry, but in SpellForce 3 they just didn't get anything right. [/user] [assistant]", # noqa: E501
"[system] Given a target sentence construct the underlying meaning representation\nof the input sentence as a single function with attributes and attribute\nvalues. This function should describe the target string accurately and the\nfunction must be one of the following ['inform', 'request', 'give_opinion',\n'confirm', 'verify_attribute', 'suggest', 'request_explanation',\n'recommend', 'request_attribute'].\n\nThe attributes must be one of the following:\n['name', 'exp_release_date', 'release_year', 'developer', 'esrb', 'rating',\n'genres', 'player_perspective', 'has_multiplayer', 'platforms',\n'available_on_steam', 'has_linux_release', 'has_mac_release', 'specifier'] [/system] [user] Here is the target sentence:\nBioShock is a good role-playing, action-adventure, shooter that released for PlayStation, Xbox, and PC in 2007. It is available on Steam, and it has a Mac release but not a Linux release. [/user] [assistant]", "[system] Given a target sentence construct the underlying meaning representation\nof the input sentence as a single function with attributes and attribute\nvalues. This function should describe the target string accurately and the\nfunction must be one of the following ['inform', 'request', 'give_opinion',\n'confirm', 'verify_attribute', 'suggest', 'request_explanation',\n'recommend', 'request_attribute'].\n\nThe attributes must be one of the following:\n['name', 'exp_release_date', 'release_year', 'developer', 'esrb', 'rating',\n'genres', 'player_perspective', 'has_multiplayer', 'platforms',\n'available_on_steam', 'has_linux_release', 'has_mac_release', 'specifier'] [/system] [user] Here is the target sentence:\nBioShock is a good role-playing, action-adventure, shooter that released for PlayStation, Xbox, and PC in 2007. It is available on Steam, and it has a Mac release but not a Linux release. [/user] [assistant]", # noqa: E501
] ]
......
...@@ -17,14 +17,16 @@ from vllm.lora.layers import ( ...@@ -17,14 +17,16 @@ from vllm.lora.layers import (
LoRAMapping, LoRAMapping,
BaseLayerWithLoRA, BaseLayerWithLoRA,
) )
from vllm.lora.models import LoRALayerWeights, convert_mapping, PackedLoRALayerWeights from vllm.lora.models import (LoRALayerWeights, convert_mapping,
PackedLoRALayerWeights)
from vllm.config import LoRAConfig from vllm.config import LoRAConfig
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
MergedColumnParallelLinear, MergedColumnParallelLinear,
RowParallelLinear, RowParallelLinear,
QKVParallelLinear) QKVParallelLinear)
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding, ParallelLMHead from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding, ParallelLMHead)
from vllm.model_executor.utils import set_random_seed from vllm.model_executor.utils import set_random_seed
from .utils import DummyLoRAManager from .utils import DummyLoRAManager
...@@ -258,7 +260,8 @@ def test_embeddings(dist_init, num_loras, device) -> None: ...@@ -258,7 +260,8 @@ def test_embeddings(dist_init, num_loras, device) -> None:
@torch.inference_mode() @torch.inference_mode()
# @pytest.mark.skip(reason="Fails when loras are in any slot other than the first.") # @pytest.mark.skip(
# reason="Fails when loras are in any slot other than the first.")
@pytest.mark.parametrize("num_loras", [1, 2, 4, 8]) @pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
@pytest.mark.parametrize("device", CUDA_DEVICES) @pytest.mark.parametrize("device", CUDA_DEVICES)
def test_embeddings_with_new_embeddings(dist_init, num_loras, device) -> None: def test_embeddings_with_new_embeddings(dist_init, num_loras, device) -> None:
...@@ -674,9 +677,9 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, device) -> None: ...@@ -674,9 +677,9 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, device) -> None:
result = linear(input_)[0] result = linear(input_)[0]
subloras = sublora_dict[lora_id] subloras = sublora_dict[lora_id]
for i, sublora in enumerate(subloras): for i, sublora in enumerate(subloras):
result[:, sublora.lora_b.shape[1] * i:sublora.lora_b.shape[1] * ( result[:, sublora.lora_b.shape[1] * i:sublora.lora_b.shape[1] *
i + 1 (i + 1)] += (input_ @ sublora.lora_a @ sublora.lora_b *
)] += input_ @ sublora.lora_a @ sublora.lora_b * sublora.scaling sublora.scaling)
expected_results.append(result) expected_results.append(result)
expected_result = torch.cat(expected_results) expected_result = torch.cat(expected_results)
......
...@@ -10,12 +10,12 @@ MODEL_PATH = "meta-llama/Llama-2-7b-hf" ...@@ -10,12 +10,12 @@ MODEL_PATH = "meta-llama/Llama-2-7b-hf"
def do_sample(llm, lora_path: str, lora_id: int): def do_sample(llm, lora_path: str, lora_id: int):
prompts = [ prompts = [
"[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]", "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]", # noqa: E501
"[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_11 (nationality VARCHAR, elector VARCHAR)\n\n question: When Anchero Pantaleone was the elector what is under nationality? [/user] [assistant]", "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_11 (nationality VARCHAR, elector VARCHAR)\n\n question: When Anchero Pantaleone was the elector what is under nationality? [/user] [assistant]", # noqa: E501
"[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_95 (one_mora VARCHAR, gloss VARCHAR, accented_mora VARCHAR)\n\n question: What is the one mora for a low tone mora with a gloss of /˩okiru/ [òkìɽɯ́]? [/user] [assistant]", "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_95 (one_mora VARCHAR, gloss VARCHAR, accented_mora VARCHAR)\n\n question: What is the one mora for a low tone mora with a gloss of /˩okiru/ [òkìɽɯ́]? [/user] [assistant]", # noqa: E501
"[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE candidate (people_id VARCHAR, unsure_rate INTEGER); CREATE TABLE people (sex VARCHAR, people_id VARCHAR)\n\n question: which gender got the highest average uncertain ratio. [/user] [assistant]", "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE candidate (people_id VARCHAR, unsure_rate INTEGER); CREATE TABLE people (sex VARCHAR, people_id VARCHAR)\n\n question: which gender got the highest average uncertain ratio. [/user] [assistant]", # noqa: E501
"[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_60 (pick INTEGER, former_wnba_team VARCHAR)\n\n question: What pick was a player that previously played for the Minnesota Lynx? [/user] [assistant]", "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_60 (pick INTEGER, former_wnba_team VARCHAR)\n\n question: What pick was a player that previously played for the Minnesota Lynx? [/user] [assistant]", # noqa: E501
"[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_28138035_4 (womens_doubles VARCHAR, mens_singles VARCHAR)\n\n question: Name the women's doubles for werner schlager [/user] [assistant]" "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_28138035_4 (womens_doubles VARCHAR, mens_singles VARCHAR)\n\n question: Name the women's doubles for werner schlager [/user] [assistant]" # noqa: E501
] ]
sampling_params = vllm.SamplingParams(temperature=0, sampling_params = vllm.SamplingParams(temperature=0,
max_tokens=256, max_tokens=256,
...@@ -48,20 +48,20 @@ def test_llama_lora(sql_lora_files, tp_size): ...@@ -48,20 +48,20 @@ def test_llama_lora(sql_lora_files, tp_size):
tensor_parallel_size=tp_size) tensor_parallel_size=tp_size)
expected_no_lora_output = [ expected_no_lora_output = [
"\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_75 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_76 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_77 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_78 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user]", "\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_75 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_76 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_77 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_78 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user]", # noqa: E501
" Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_11 (nationality VARCHAR, elector VARCHAR)\n\n question: When Anchero Pantaleone was the elector what is under nationality? ", " Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_11 (nationality VARCHAR, elector VARCHAR)\n\n question: When Anchero Pantaleone was the elector what is under nationality? ", # noqa: E501
"\n\n answer: 1\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_96 (one_mora VARCHAR, gloss VARCHAR, accented_mora VARCHAR)\n\n question: What is the one mora for a high tone mora with a gloss of /˧kot/ [kòt]? [/user] [assistant]\n\n answer: 2\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_97 (one_mora VARCHAR, gloss VARCHAR, accented_mora VARCHAR)\n\n question: What is the one mora for a high tone mora with a gloss of /˧kot/ [kòt]? [/user] [assistant]\n\n answer: 2\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_98 (one_mora VARCHAR, gloss VARCHAR, accented_mora VARCHAR)\n\n question: What is the one m", "\n\n answer: 1\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_96 (one_mora VARCHAR, gloss VARCHAR, accented_mora VARCHAR)\n\n question: What is the one mora for a high tone mora with a gloss of /˧kot/ [kòt]? [/user] [assistant]\n\n answer: 2\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_97 (one_mora VARCHAR, gloss VARCHAR, accented_mora VARCHAR)\n\n question: What is the one mora for a high tone mora with a gloss of /˧kot/ [kòt]? [/user] [assistant]\n\n answer: 2\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_98 (one_mora VARCHAR, gloss VARCHAR, accented_mora VARCHAR)\n\n question: What is the one m", # noqa: E501
" Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE candidate (people_id VARCHAR, unsure_rate INTEGER); CREATE TABLE people (sex VARCHAR, people_id VARCHAR)\n\n question: which gender got the highest average uncertain ratio. ", " Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE candidate (people_id VARCHAR, unsure_rate INTEGER); CREATE TABLE people (sex VARCHAR, people_id VARCHAR)\n\n question: which gender got the highest average uncertain ratio. ", # noqa: E501
" Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_60 (pick INTEGER, former_wnba_team VARCHAR)\n\n question: What pick was a player that previously played for the Minnesota Lynx? ", " Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_60 (pick INTEGER, former_wnba_team VARCHAR)\n\n question: What pick was a player that previously played for the Minnesota Lynx? ", # noqa: E501
"\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_28138035_4 (womens_doubles VARCHAR, mens_singles VARCHAR)\n\n question: Name the women's doubles for werner schlager [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_28138035_4 (womens_doubles VARCHAR, mens_singles VARCHAR)\n\n question: Name the women's doubles for werner schlager [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_28138035_4 (womens_doubles VARCHAR, mens_singles VARCHAR)\n\n question: Name the women's doubles for werner schlager [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE", "\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_28138035_4 (womens_doubles VARCHAR, mens_singles VARCHAR)\n\n question: Name the women's doubles for werner schlager [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_28138035_4 (womens_doubles VARCHAR, mens_singles VARCHAR)\n\n question: Name the women's doubles for werner schlager [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_28138035_4 (womens_doubles VARCHAR, mens_singles VARCHAR)\n\n question: Name the women's doubles for werner schlager [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE", # noqa: E501
] ]
expected_lora_output = [ expected_lora_output = [
" SELECT icao FROM table_name_74 WHERE airport = 'lilongwe international airport' ", " SELECT icao FROM table_name_74 WHERE airport = 'lilongwe international airport' ", # noqa: E501
" SELECT nationality FROM table_name_11 WHERE elector = 'anchero pantaleone' ", " SELECT nationality FROM table_name_11 WHERE elector = 'anchero pantaleone' ", # noqa: E501
" SELECT one_mora FROM table_name_95 WHERE gloss = 'low tone mora with a gloss of /˩okiru/' [òkìɽɯ́] AND accented_mora = 'low tone mora with a gloss of /˩okiru/' [òkìɽɯ́] ", " SELECT one_mora FROM table_name_95 WHERE gloss = 'low tone mora with a gloss of /˩okiru/' [òkìɽɯ́] AND accented_mora = 'low tone mora with a gloss of /˩okiru/' [òkìɽɯ́] ", # noqa: E501
" SELECT sex FROM people WHERE people_id IN (SELECT people_id FROM candidate GROUP BY sex ORDER BY COUNT(people_id) DESC LIMIT 1) ", " SELECT sex FROM people WHERE people_id IN (SELECT people_id FROM candidate GROUP BY sex ORDER BY COUNT(people_id) DESC LIMIT 1) ", # noqa: E501
" SELECT pick FROM table_name_60 WHERE former_wnba_team = 'Minnesota Lynx' ", " SELECT pick FROM table_name_60 WHERE former_wnba_team = 'Minnesota Lynx' ", # noqa: E501
" SELECT womens_doubles FROM table_28138035_4 WHERE mens_singles = 'Werner Schlager' " " SELECT womens_doubles FROM table_28138035_4 WHERE mens_singles = 'Werner Schlager' " # noqa: E501
] ]
print("lora adapter created") print("lora adapter created")
...@@ -121,7 +121,8 @@ def test_llama_tensor_parallel_equality(sql_lora_files): ...@@ -121,7 +121,8 @@ def test_llama_tensor_parallel_equality(sql_lora_files):
def test_llama_lora_warmup(sql_lora_files): def test_llama_lora_warmup(sql_lora_files):
"""Test that the LLM initialization works with a warmup LORA path and is more conservative""" """Test that the LLM initialization works with a warmup LORA path and
is more conservative"""
@ray.remote(num_gpus=1) @ray.remote(num_gpus=1)
def get_num_gpu_blocks_lora(): def get_num_gpu_blocks_lora():
...@@ -132,13 +133,15 @@ def test_llama_lora_warmup(sql_lora_files): ...@@ -132,13 +133,15 @@ def test_llama_lora_warmup(sql_lora_files):
@ray.remote(num_gpus=1) @ray.remote(num_gpus=1)
def get_num_gpu_blocks_no_lora(): def get_num_gpu_blocks_no_lora():
llm = vllm.LLM(MODEL_PATH, max_num_seqs=16) llm = vllm.LLM(MODEL_PATH, max_num_seqs=16)
num_gpu_blocks_no_lora_warmup = llm.llm_engine.cache_config.num_gpu_blocks num_gpu_blocks_no_lora_warmup = (
llm.llm_engine.cache_config.num_gpu_blocks)
return num_gpu_blocks_no_lora_warmup return num_gpu_blocks_no_lora_warmup
num_gpu_blocks_lora_warmup = ray.get(get_num_gpu_blocks_lora.remote()) num_gpu_blocks_lora_warmup = ray.get(get_num_gpu_blocks_lora.remote())
num_gpu_blocks_no_lora_warmup = ray.get( num_gpu_blocks_no_lora_warmup = ray.get(
get_num_gpu_blocks_no_lora.remote()) get_num_gpu_blocks_no_lora.remote())
assert num_gpu_blocks_lora_warmup < num_gpu_blocks_no_lora_warmup, ( assert num_gpu_blocks_lora_warmup < num_gpu_blocks_no_lora_warmup, (
"The warmup with lora should be more" "The warmup with lora should be more "
" conservative than without lora, therefore the number of memory blocks for the KV cache should be " "conservative than without lora, therefore the number of "
"memory blocks for the KV cache should be "
"less when using lora than when not using lora") "less when using lora than when not using lora")
...@@ -9,9 +9,9 @@ MODEL_PATH = "mistralai/Mixtral-8x7B-Instruct-v0.1" ...@@ -9,9 +9,9 @@ MODEL_PATH = "mistralai/Mixtral-8x7B-Instruct-v0.1"
def do_sample(llm, lora_path: str, lora_id: int): def do_sample(llm, lora_path: str, lora_id: int):
prompts = [ prompts = [
"[system] Given a target sentence construct the underlying meaning representation\nof the input sentence as a single function with attributes and attribute\nvalues. This function should describe the target string accurately and the\nfunction must be one of the following ['inform', 'request', 'give_opinion',\n'confirm', 'verify_attribute', 'suggest', 'request_explanation',\n'recommend', 'request_attribute'].\n\nThe attributes must be one of the following:\n['name', 'exp_release_date', 'release_year', 'developer', 'esrb', 'rating',\n'genres', 'player_perspective', 'has_multiplayer', 'platforms',\n'available_on_steam', 'has_linux_release', 'has_mac_release', 'specifier'] [/system] [user] Here is the target sentence:\nSpellForce 3 is a pretty bad game. The developer Grimlore Games is clearly a bunch of no-talent hacks, and 2017 was a terrible year for games anyway. [/user] [assistant]", "[system] Given a target sentence construct the underlying meaning representation\nof the input sentence as a single function with attributes and attribute\nvalues. This function should describe the target string accurately and the\nfunction must be one of the following ['inform', 'request', 'give_opinion',\n'confirm', 'verify_attribute', 'suggest', 'request_explanation',\n'recommend', 'request_attribute'].\n\nThe attributes must be one of the following:\n['name', 'exp_release_date', 'release_year', 'developer', 'esrb', 'rating',\n'genres', 'player_perspective', 'has_multiplayer', 'platforms',\n'available_on_steam', 'has_linux_release', 'has_mac_release', 'specifier'] [/system] [user] Here is the target sentence:\nSpellForce 3 is a pretty bad game. The developer Grimlore Games is clearly a bunch of no-talent hacks, and 2017 was a terrible year for games anyway. [/user] [assistant]", # noqa: E501
"[system] Given a target sentence construct the underlying meaning representation\nof the input sentence as a single function with attributes and attribute\nvalues. This function should describe the target string accurately and the\nfunction must be one of the following ['inform', 'request', 'give_opinion',\n'confirm', 'verify_attribute', 'suggest', 'request_explanation',\n'recommend', 'request_attribute'].\n\nThe attributes must be one of the following:\n['name', 'exp_release_date', 'release_year', 'developer', 'esrb', 'rating',\n'genres', 'player_perspective', 'has_multiplayer', 'platforms',\n'available_on_steam', 'has_linux_release', 'has_mac_release', 'specifier'] [/system] [user] Here is the target sentence:\nI wanted to like Grimlore Games' 2017 entry, but in SpellForce 3 they just didn't get anything right. [/user] [assistant]", "[system] Given a target sentence construct the underlying meaning representation\nof the input sentence as a single function with attributes and attribute\nvalues. This function should describe the target string accurately and the\nfunction must be one of the following ['inform', 'request', 'give_opinion',\n'confirm', 'verify_attribute', 'suggest', 'request_explanation',\n'recommend', 'request_attribute'].\n\nThe attributes must be one of the following:\n['name', 'exp_release_date', 'release_year', 'developer', 'esrb', 'rating',\n'genres', 'player_perspective', 'has_multiplayer', 'platforms',\n'available_on_steam', 'has_linux_release', 'has_mac_release', 'specifier'] [/system] [user] Here is the target sentence:\nI wanted to like Grimlore Games' 2017 entry, but in SpellForce 3 they just didn't get anything right. [/user] [assistant]", # noqa: E501
"[system] Given a target sentence construct the underlying meaning representation\nof the input sentence as a single function with attributes and attribute\nvalues. This function should describe the target string accurately and the\nfunction must be one of the following ['inform', 'request', 'give_opinion',\n'confirm', 'verify_attribute', 'suggest', 'request_explanation',\n'recommend', 'request_attribute'].\n\nThe attributes must be one of the following:\n['name', 'exp_release_date', 'release_year', 'developer', 'esrb', 'rating',\n'genres', 'player_perspective', 'has_multiplayer', 'platforms',\n'available_on_steam', 'has_linux_release', 'has_mac_release', 'specifier'] [/system] [user] Here is the target sentence:\nBioShock is a good role-playing, action-adventure, shooter that released for PlayStation, Xbox, and PC in 2007. It is available on Steam, and it has a Mac release but not a Linux release. [/user] [assistant]", "[system] Given a target sentence construct the underlying meaning representation\nof the input sentence as a single function with attributes and attribute\nvalues. This function should describe the target string accurately and the\nfunction must be one of the following ['inform', 'request', 'give_opinion',\n'confirm', 'verify_attribute', 'suggest', 'request_explanation',\n'recommend', 'request_attribute'].\n\nThe attributes must be one of the following:\n['name', 'exp_release_date', 'release_year', 'developer', 'esrb', 'rating',\n'genres', 'player_perspective', 'has_multiplayer', 'platforms',\n'available_on_steam', 'has_linux_release', 'has_mac_release', 'specifier'] [/system] [user] Here is the target sentence:\nBioShock is a good role-playing, action-adventure, shooter that released for PlayStation, Xbox, and PC in 2007. It is available on Steam, and it has a Mac release but not a Linux release. [/user] [assistant]", # noqa: E501
] ]
sampling_params = vllm.SamplingParams(temperature=0, max_tokens=256) sampling_params = vllm.SamplingParams(temperature=0, max_tokens=256)
outputs = llm.generate( outputs = llm.generate(
...@@ -42,9 +42,9 @@ def test_mixtral_lora(mixtral_lora_files, tp_size): ...@@ -42,9 +42,9 @@ def test_mixtral_lora(mixtral_lora_files, tp_size):
worker_use_ray=True) worker_use_ray=True)
expected_lora_output = [ expected_lora_output = [
"give_opinion(name[SpellForce 3], release_year[2017], developer[Grimlore Games], rating[poor])", "give_opinion(name[SpellForce 3], release_year[2017], developer[Grimlore Games], rating[poor])", # noqa: E501
"give_opinion(name[SpellForce 3], release_year[2017], developer[Grimlore Games], rating[poor])", "give_opinion(name[SpellForce 3], release_year[2017], developer[Grimlore Games], rating[poor])", # noqa: E501
"inform(name[BioShock], release_year[2007], rating[good], genres[action-adventure, role-playing, shooter], platforms[PlayStation, Xbox, PC], available_on_steam[yes], has_linux_release[no], has_mac_release[yes])", "inform(name[BioShock], release_year[2007], rating[good], genres[action-adventure, role-playing, shooter], platforms[PlayStation, Xbox, PC], available_on_steam[yes], has_linux_release[no], has_mac_release[yes])", # noqa: E501
] ]
assert do_sample(llm, mixtral_lora_files, assert do_sample(llm, mixtral_lora_files,
......
...@@ -21,7 +21,8 @@ def test_metric_counter_prompt_tokens( ...@@ -21,7 +21,8 @@ def test_metric_counter_prompt_tokens(
gpu_memory_utilization=0.4) gpu_memory_utilization=0.4)
tokenizer = vllm_model.model.get_tokenizer() tokenizer = vllm_model.model.get_tokenizer()
prompt_token_counts = [len(tokenizer.encode(p)) for p in example_prompts] prompt_token_counts = [len(tokenizer.encode(p)) for p in example_prompts]
# This test needs at least 2 prompts in a batch of different lengths to verify their token count is correct despite padding. # This test needs at least 2 prompts in a batch of different lengths to
# verify their token count is correct despite padding.
assert len(example_prompts) > 1, "at least 2 prompts are required" assert len(example_prompts) > 1, "at least 2 prompts are required"
assert prompt_token_counts[0] != prompt_token_counts[1], ( assert prompt_token_counts[0] != prompt_token_counts[1], (
"prompts of different lengths are required") "prompts of different lengths are required")
...@@ -33,8 +34,8 @@ def test_metric_counter_prompt_tokens( ...@@ -33,8 +34,8 @@ def test_metric_counter_prompt_tokens(
**stat_logger.labels)._value.get() **stat_logger.labels)._value.get()
assert vllm_prompt_token_count == metric_count, ( assert vllm_prompt_token_count == metric_count, (
f"prompt token count: {vllm_prompt_token_count!r}\nmetric: {metric_count!r}" f"prompt token count: {vllm_prompt_token_count!r}\n"
) f"metric: {metric_count!r}")
@pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("model", MODELS)
...@@ -60,9 +61,10 @@ def test_metric_counter_generation_tokens( ...@@ -60,9 +61,10 @@ def test_metric_counter_generation_tokens(
for i in range(len(example_prompts)): for i in range(len(example_prompts)):
vllm_output_ids, vllm_output_str = vllm_outputs[i] vllm_output_ids, vllm_output_str = vllm_outputs[i]
prompt_ids = tokenizer.encode(example_prompts[i]) prompt_ids = tokenizer.encode(example_prompts[i])
# vllm_output_ids contains both prompt tokens and generation tokens. We're interested only in the count of the generation tokens. # vllm_output_ids contains both prompt tokens and generation tokens.
# We're interested only in the count of the generation tokens.
vllm_generation_count += len(vllm_output_ids) - len(prompt_ids) vllm_generation_count += len(vllm_output_ids) - len(prompt_ids)
assert vllm_generation_count == metric_count, ( assert vllm_generation_count == metric_count, (
f"generation token count: {vllm_generation_count!r}\nmetric: {metric_count!r}" f"generation token count: {vllm_generation_count!r}\n"
) f"metric: {metric_count!r}")
"""Compare the outputs of a GPTQ model to a Marlin model. """Compare the outputs of a GPTQ model to a Marlin model.
Note: GPTQ and Marlin do not have bitwise correctness. Note: GPTQ and Marlin do not have bitwise correctness.
As a result, in this test, we just confirm that the top selected tokens of the As a result, in this test, we just confirm that the top selected tokens of the
Marlin/GPTQ models are in the top 3 selections of each other. Marlin/GPTQ models are in the top 3 selections of each other.
Note: Marlin internally uses locks to synchronize the threads. This can Note: Marlin internally uses locks to synchronize the threads. This can
...@@ -14,7 +14,8 @@ Run `pytest tests/models/test_marlin.py --forked`. ...@@ -14,7 +14,8 @@ Run `pytest tests/models/test_marlin.py --forked`.
import pytest import pytest
import torch import torch
from dataclasses import dataclass from dataclasses import dataclass
from vllm.model_executor.layers.quantization import _QUANTIZATION_CONFIG_REGISTRY from vllm.model_executor.layers.quantization import (
_QUANTIZATION_CONFIG_REGISTRY)
capability = torch.cuda.get_device_capability() capability = torch.cuda.get_device_capability()
capability = capability[0] * 10 + capability[1] capability = capability[0] * 10 + capability[1]
...@@ -87,11 +88,11 @@ def test_models( ...@@ -87,11 +88,11 @@ def test_models(
if marlin_output_id != gptq_output_id: if marlin_output_id != gptq_output_id:
# Each predicted token must be in top 5 of the other's # Each predicted token must be in top 5 of the other's
assert gptq_output_id in marlin_logprobs[idx], ( assert gptq_output_id in marlin_logprobs[idx], (
f"Test{prompt_idx}:\nGPTQ:\t{gptq_output_str!r}\nMarlin:\t{marlin_output_str!r}" f"Test{prompt_idx}:\nGPTQ:\t{gptq_output_str!r}\n"
) f"Marlin:\t{marlin_output_str!r}")
assert marlin_output_id in gptq_logprobs[idx], ( assert marlin_output_id in gptq_logprobs[idx], (
f"Test{prompt_idx}:\nGPTQ:\t{gptq_output_str!r}\nMarlin:\t{marlin_output_str!r}" f"Test{prompt_idx}:\nGPTQ:\t{gptq_output_str!r}\n"
) f"Marlin:\t{marlin_output_str!r}")
# Break out since sequences will now diverge. # Break out since sequences will now diverge.
break break
...@@ -20,20 +20,23 @@ def test_block_allocator( ...@@ -20,20 +20,23 @@ def test_block_allocator(
num_blocks, num_blocks,
enable_caching=True) enable_caching=True)
# Allocate two PhysicalTokenBlocks with the same hash and check that they are the same PhysicalTokenBlock # Allocate two PhysicalTokenBlocks with the same hash and check
# that they are the same PhysicalTokenBlock
first_block = block_allocator.allocate(block_hash, 0) first_block = block_allocator.allocate(block_hash, 0)
second_block = block_allocator.allocate(block_hash, 0) second_block = block_allocator.allocate(block_hash, 0)
assert (first_block == second_block) assert (first_block == second_block)
assert (second_block.ref_count == 2) assert (second_block.ref_count == 2)
# Free the first_block and confirm that the ref_count is correctly decremented on the second block # Free the first_block and confirm that the ref_count is correctly
# decremented on the second block
block_allocator.free(first_block) block_allocator.free(first_block)
assert (second_block.ref_count == 1) assert (second_block.ref_count == 1)
# Free the second block # Free the second block
block_allocator.free(second_block) block_allocator.free(second_block)
# Reallocate the first block and confirm that, even after the block had its ref_count go to 0, we still get the same block back # Reallocate the first block and confirm that, even after the block
# had its ref_count go to 0, we still get the same block back
first_block = block_allocator.allocate(block_hash, 0) first_block = block_allocator.allocate(block_hash, 0)
assert (first_block == second_block) assert (first_block == second_block)
assert (first_block.block_hash == block_hash) assert (first_block.block_hash == block_hash)
...@@ -56,7 +59,8 @@ def test_eviction(num_blocks: int, ): ...@@ -56,7 +59,8 @@ def test_eviction(num_blocks: int, ):
for block in blocks: for block in blocks:
block_allocator.free(block) block_allocator.free(block)
# Allocate a new block and confirm that it's the first block freed, i.e. the Least Recently Used (LRU) block # Allocate a new block and confirm that it's the first block freed,
# i.e. the Least Recently Used (LRU) block
new_block_hash = block_size new_block_hash = block_size
new_block = block_allocator.allocate(new_block_hash, 0) new_block = block_allocator.allocate(new_block_hash, 0)
assert (new_block == blocks[0]) assert (new_block == blocks[0])
...@@ -68,7 +72,8 @@ def test_eviction(num_blocks: int, ): ...@@ -68,7 +72,8 @@ def test_eviction(num_blocks: int, ):
assert (realloc_block == blocks[realloc_block_hash]) assert (realloc_block == blocks[realloc_block_hash])
assert (realloc_block.block_hash == realloc_block_hash) assert (realloc_block.block_hash == realloc_block_hash)
# Allocate a new block and confirm that it's not the realloc_block, since the realloc_block shouldn't be in the free list # Allocate a new block and confirm that it's not the realloc_block,
# since the realloc_block shouldn't be in the free list
new_block_hash = block_size + 1 new_block_hash = block_size + 1
new_block = block_allocator.allocate(new_block_hash, 0) new_block = block_allocator.allocate(new_block_hash, 0)
assert (realloc_block != new_block) assert (realloc_block != new_block)
......
...@@ -70,8 +70,8 @@ def test_get_prompt_logprobs( ...@@ -70,8 +70,8 @@ def test_get_prompt_logprobs(
hf_logprob[i][-1][token_id].item(), hf_logprob[i][-1][token_id].item(),
atol=1e-2, atol=1e-2,
rtol=1e-2) rtol=1e-2)
assert isinstance(sample_logprob.decoded_token, str), \ assert isinstance(sample_logprob.decoded_token, str), (
("The token should be decoded by the time it is returned " "The token should be decoded by the time it is returned "
" to the user.") " to the user.")
......
...@@ -255,9 +255,10 @@ def test_sampler_mixed(seed: int, device: str): ...@@ -255,9 +255,10 @@ def test_sampler_mixed(seed: int, device: str):
if metadata.sampling_params.use_beam_search: if metadata.sampling_params.use_beam_search:
continue continue
if metadata.sampling_params.seed is not None \ if (metadata.sampling_params.seed is not None
and expected_tokens[i] is None: and expected_tokens[i] is None):
# Record seeded random result to compare with results of second invocation # Record seeded random result to compare with results of
# second invocation
expected_tokens[i] = [ expected_tokens[i] = [
nth_output.output_token nth_output.output_token
for nth_output in sequence_output.samples for nth_output in sequence_output.samples
...@@ -265,11 +266,13 @@ def test_sampler_mixed(seed: int, device: str): ...@@ -265,11 +266,13 @@ def test_sampler_mixed(seed: int, device: str):
continue continue
for n, nth_output in enumerate(sequence_output.samples): for n, nth_output in enumerate(sequence_output.samples):
if metadata.sampling_params.temperature == 0 or metadata.sampling_params.seed is not None: if (metadata.sampling_params.temperature == 0
or metadata.sampling_params.seed is not None):
# Ensure exact matches for greedy or random with seed # Ensure exact matches for greedy or random with seed
assert nth_output.output_token == expected_tokens[i][n] assert nth_output.output_token == expected_tokens[i][n]
else: else:
# For non-seeded random, check that one of the high-logit tokens was chosen # For non-seeded random, check that one of the high-logit
# tokens was chosen
assert nth_output.output_token in expected_tokens[i] assert nth_output.output_token in expected_tokens[i]
# Test batch # Test batch
...@@ -284,8 +287,8 @@ def test_sampler_mixed(seed: int, device: str): ...@@ -284,8 +287,8 @@ def test_sampler_mixed(seed: int, device: str):
input_tensor.data = input_tensor.index_select(0, target_index) input_tensor.data = input_tensor.index_select(0, target_index)
fake_logits.data = fake_logits.index_select(0, target_index) fake_logits.data = fake_logits.index_select(0, target_index)
# This time, results of seeded random samples will be compared with the corresponding # This time, results of seeded random samples will be compared with
# sample in the pre-shuffled batch # the corresponding sample in the pre-shuffled batch
test_sampling(model_runner) test_sampling(model_runner)
del model_runner del model_runner
......
...@@ -150,8 +150,10 @@ def test_initial_metrics_has_correct_values(has_data: bool): ...@@ -150,8 +150,10 @@ def test_initial_metrics_has_correct_values(has_data: bool):
assert metrics.emitted_tokens == num_emitted_tokens assert metrics.emitted_tokens == num_emitted_tokens
if has_data: if has_data:
assert metrics.draft_acceptance_rate == num_accepted_tokens / num_draft_tokens assert (metrics.draft_acceptance_rate == num_accepted_tokens /
assert metrics.system_efficiency == num_emitted_tokens / num_possible_tokens num_draft_tokens)
assert (metrics.system_efficiency == num_emitted_tokens /
num_possible_tokens)
else: else:
assert math.isnan(metrics.draft_acceptance_rate) assert math.isnan(metrics.draft_acceptance_rate)
assert math.isnan(metrics.system_efficiency) assert math.isnan(metrics.system_efficiency)
...@@ -3,7 +3,8 @@ import random ...@@ -3,7 +3,8 @@ import random
import pytest import pytest
from unittest.mock import MagicMock from unittest.mock import MagicMock
from vllm.spec_decode.multi_step_worker import MultiStepWorker, DraftModelTop1Proposer from vllm.spec_decode.multi_step_worker import (MultiStepWorker,
DraftModelTop1Proposer)
from vllm.worker.worker import Worker from vllm.worker.worker import Worker
from vllm.model_executor.utils import set_random_seed from vllm.model_executor.utils import set_random_seed
from vllm.sequence import SamplerOutput from vllm.sequence import SamplerOutput
......
...@@ -4,12 +4,15 @@ import pytest ...@@ -4,12 +4,15 @@ import pytest
from unittest.mock import MagicMock from unittest.mock import MagicMock
from vllm.spec_decode.multi_step_worker import MultiStepWorker from vllm.spec_decode.multi_step_worker import MultiStepWorker
from vllm.spec_decode.spec_decode_worker import SpecDecodeWorker, split_num_cache_blocks_evenly from vllm.spec_decode.spec_decode_worker import (SpecDecodeWorker,
split_num_cache_blocks_evenly)
from vllm.spec_decode.interfaces import SpeculativeProposals from vllm.spec_decode.interfaces import SpeculativeProposals
from vllm.model_executor.utils import set_random_seed from vllm.model_executor.utils import set_random_seed
from vllm.model_executor.layers.rejection_sampler import RejectionSampler from vllm.model_executor.layers.rejection_sampler import RejectionSampler
from .utils import mock_worker, create_batch, ExecuteModelData, create_sampler_output_list from .utils import (mock_worker, create_batch, ExecuteModelData,
from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics, AsyncMetricsCollector create_sampler_output_list)
from vllm.spec_decode.metrics import (SpecDecodeWorkerMetrics,
AsyncMetricsCollector)
@pytest.mark.parametrize('k', [1, 2, 6]) @pytest.mark.parametrize('k', [1, 2, 6])
...@@ -391,13 +394,15 @@ def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool): ...@@ -391,13 +394,15 @@ def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool):
mock_rejsample_metrics = MagicMock( mock_rejsample_metrics = MagicMock(
spec=SpecDecodeWorkerMetrics) if returns_metrics else None spec=SpecDecodeWorkerMetrics) if returns_metrics else None
metrics_collector.maybe_collect_rejsample_metrics.return_value = mock_rejsample_metrics metrics_collector.maybe_collect_rejsample_metrics.return_value = (
mock_rejsample_metrics)
output = worker.execute_model(**execute_model_data.to_dict(), output = worker.execute_model(**execute_model_data.to_dict(),
num_spec_tokens=k) num_spec_tokens=k)
assert output[0].spec_decode_worker_metrics == mock_rejsample_metrics assert output[0].spec_decode_worker_metrics == mock_rejsample_metrics
call_args_list = metrics_collector.maybe_collect_rejsample_metrics.call_args_list call_args_list = (
metrics_collector.maybe_collect_rejsample_metrics.call_args_list)
assert len(call_args_list) == 1 assert len(call_args_list) == 1
args, kwargs = call_args_list[0] args, kwargs = call_args_list[0]
assert args[0] == k or kwargs.get('k', -1) == k assert args[0] == k or kwargs.get('k', -1) == k
...@@ -547,7 +552,8 @@ def test_profile_num_available_blocks(available_gpu_blocks: int, ...@@ -547,7 +552,8 @@ def test_profile_num_available_blocks(available_gpu_blocks: int,
target_worker.profile_num_available_blocks.return_value = ( target_worker.profile_num_available_blocks.return_value = (
available_gpu_blocks, available_cpu_blocks) available_gpu_blocks, available_cpu_blocks)
target_worker.get_cache_block_size_bytes.return_value = target_cache_block_size_bytes target_worker.get_cache_block_size_bytes.return_value = (
target_cache_block_size_bytes)
draft_worker.get_cache_block_size_bytes.return_value = draft_kv_size_bytes draft_worker.get_cache_block_size_bytes.return_value = draft_kv_size_bytes
worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler, worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment