Commit a68aef25 authored by zhuwenwen's avatar zhuwenwen
Browse files

[tests] fix v1, tokenization and runai_model_streamer_test

parent d36deb1a
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import os
import glob import glob
import tempfile import tempfile
...@@ -9,6 +10,7 @@ import torch ...@@ -9,6 +10,7 @@ import torch
from vllm.model_executor.model_loader.weight_utils import ( from vllm.model_executor.model_loader.weight_utils import (
download_weights_from_hf, runai_safetensors_weights_iterator, download_weights_from_hf, runai_safetensors_weights_iterator,
safetensors_weights_iterator) safetensors_weights_iterator)
from ..utils import models_path_prefix
def test_runai_model_loader(): def test_runai_model_loader():
...@@ -23,10 +25,10 @@ def test_runai_model_loader(): ...@@ -23,10 +25,10 @@ def test_runai_model_loader():
runai_model_streamer_tensors = {} runai_model_streamer_tensors = {}
hf_safetensors_tensors = {} hf_safetensors_tensors = {}
for name, tensor in runai_safetensors_weights_iterator(safetensors): for name, tensor in runai_safetensors_weights_iterator(safetensors, False):
runai_model_streamer_tensors[name] = tensor runai_model_streamer_tensors[name] = tensor
for name, tensor in safetensors_weights_iterator(safetensors): for name, tensor in safetensors_weights_iterator(safetensors, False):
hf_safetensors_tensors[name] = tensor hf_safetensors_tensors[name] = tensor
assert len(runai_model_streamer_tensors) == len(hf_safetensors_tensors) assert len(runai_model_streamer_tensors) == len(hf_safetensors_tensors)
......
...@@ -43,7 +43,8 @@ def _generate( ...@@ -43,7 +43,8 @@ def _generate(
class TestOneTokenBadWord: class TestOneTokenBadWord:
MODEL = os.path.join(models_path_prefix, "TheBloke/Llama-2-7B-fp16") # MODEL = os.path.join(models_path_prefix, "TheBloke/Llama-2-7B-fp16")
MODEL = "TheBloke/Llama-2-7B-fp16"
PROMPT = "Hi! How are" PROMPT = "Hi! How are"
TARGET_TOKEN = "you" TARGET_TOKEN = "you"
......
...@@ -7,16 +7,15 @@ import pathlib ...@@ -7,16 +7,15 @@ import pathlib
import subprocess import subprocess
from functools import partial from functools import partial
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
from typing import List, Tuple, Optional
import openai import openai
import pytest import pytest
import torch import torch
from huggingface_hub import snapshot_download from huggingface_hub import snapshot_download
from typing import List, Tuple, Optional
from vllm import EngineArgs, LLMEngine, RequestOutput, SamplingParams from vllm import EngineArgs, LLMEngine, RequestOutput, SamplingParams
from vllm.engine.arg_utils import EngineArgs from vllm.engine.arg_utils import EngineArgs
from vllm.lora.request import LoRARequest
# yapf conflicts with isort for this docstring # yapf conflicts with isort for this docstring
# yapf: disable # yapf: disable
from vllm.model_executor.model_loader.tensorizer import (TensorizerConfig, from vllm.model_executor.model_loader.tensorizer import (TensorizerConfig,
...@@ -26,6 +25,8 @@ from vllm.model_executor.model_loader.tensorizer import (TensorizerConfig, ...@@ -26,6 +25,8 @@ from vllm.model_executor.model_loader.tensorizer import (TensorizerConfig,
open_stream, open_stream,
serialize_vllm_model, serialize_vllm_model,
tensorize_vllm_model) tensorize_vllm_model)
from vllm.lora.request import LoRARequest
# yapf: enable # yapf: enable
from vllm.utils import PlaceholderModule, import_from_path from vllm.utils import PlaceholderModule, import_from_path
......
...@@ -89,7 +89,7 @@ def tokenizer(tokenizer_name): ...@@ -89,7 +89,7 @@ def tokenizer(tokenizer_name):
AutoTokenizer.from_pretrained(tokenizer_name)) AutoTokenizer.from_pretrained(tokenizer_name))
@pytest.mark.parametrize("tokenizer_name", ["mistralai/Pixtral-12B-2409"]) @pytest.mark.parametrize("tokenizer_name", [os.path.join(models_path_prefix, "mistralai/Pixtral-12B-2409")])
@pytest.mark.parametrize( @pytest.mark.parametrize(
"truth", "truth",
[ [
......
...@@ -8,11 +8,13 @@ from ..utils import models_path_prefix ...@@ -8,11 +8,13 @@ from ..utils import models_path_prefix
from vllm.transformers_utils.tokenizer_group import TokenizerGroup from vllm.transformers_utils.tokenizer_group import TokenizerGroup
# export HF_ENDPOINT=https://hf-mirror.com
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_tokenizer_group(): async def test_tokenizer_group():
reference_tokenizer = AutoTokenizer.from_pretrained(os.path.join(models_path_prefix, "gpt2")) # reference_tokenizer = AutoTokenizer.from_pretrained(os.path.join(models_path_prefix, "gpt2"))
reference_tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer_group = TokenizerGroup( tokenizer_group = TokenizerGroup(
tokenizer_id=os.path.join(models_path_prefix, "gpt2"), # tokenizer_id=os.path.join(models_path_prefix, "gpt2"),
enable_lora=False, enable_lora=False,
max_num_seqs=1, max_num_seqs=1,
max_input_length=None, max_input_length=None,
......
...@@ -435,195 +435,195 @@ def test_schedule_concurrent_partial_requests(enable_prefix_caching: bool): ...@@ -435,195 +435,195 @@ def test_schedule_concurrent_partial_requests(enable_prefix_caching: bool):
requests[2].request_id] == 800 - 224 - 224 requests[2].request_id] == 800 - 224 - 224
def test_stop_via_update_from_output(): # def test_stop_via_update_from_output():
"""Test stopping behavior through update_from_output""" # """Test stopping behavior through update_from_output"""
scheduler = create_scheduler(num_speculative_tokens=1) # scheduler = create_scheduler(num_speculative_tokens=1)
# Test case 1: Stop on EOS token # # Test case 1: Stop on EOS token
requests = create_requests(num_requests=2, max_tokens=10) # requests = create_requests(num_requests=2, max_tokens=10)
for req in requests: # for req in requests:
req.num_computed_tokens = req.num_tokens # req.num_computed_tokens = req.num_tokens
scheduler.requests[req.request_id] = req # scheduler.requests[req.request_id] = req
scheduler.running.append(req) # scheduler.running.append(req)
scheduler_output = SchedulerOutput(scheduled_new_reqs=[], # scheduler_output = SchedulerOutput(scheduled_new_reqs=[],
scheduled_cached_reqs=[], # scheduled_cached_reqs=[],
num_scheduled_tokens={ # num_scheduled_tokens={
requests[0].request_id: 1, # requests[0].request_id: 1,
requests[1].request_id: 2 # requests[1].request_id: 2
}, # },
total_num_scheduled_tokens=3, # total_num_scheduled_tokens=3,
scheduled_encoder_inputs={}, # scheduled_encoder_inputs={},
scheduled_spec_decode_tokens={ # scheduled_spec_decode_tokens={
requests[0].request_id: [], # requests[0].request_id: [],
requests[1].request_id: [10] # requests[1].request_id: [10]
}, # },
num_common_prefix_blocks=0, # num_common_prefix_blocks=0,
finished_req_ids=set(), # finished_req_ids=set(),
free_encoder_input_ids=[], # free_encoder_input_ids=[],
structured_output_request_ids={}, # structured_output_request_ids={},
grammar_bitmask=None) # grammar_bitmask=None)
model_output = ModelRunnerOutput( # model_output = ModelRunnerOutput(
req_ids=[req.request_id for req in requests], # req_ids=[req.request_id for req in requests],
req_id_to_index={ # req_id_to_index={
req.request_id: i # req.request_id: i
for i, req in enumerate(requests) # for i, req in enumerate(requests)
}, # },
sampled_token_ids=[[EOS_TOKEN_ID], # sampled_token_ids=[[EOS_TOKEN_ID],
[10, # [10,
11]], # First request hits EOS, second continues # 11]], # First request hits EOS, second continues
spec_token_ids=None, # spec_token_ids=None,
logprobs=None, # logprobs=None,
prompt_logprobs_dict={}) # prompt_logprobs_dict={})
scheduler.update_from_output(scheduler_output, model_output) # scheduler.update_from_output(scheduler_output, model_output)
# Verify first request stopped, second continues # # Verify first request stopped, second continues
assert len(scheduler.running) == 1 # assert len(scheduler.running) == 1
assert scheduler.running[0].request_id == requests[1].request_id # assert scheduler.running[0].request_id == requests[1].request_id
assert requests[0].status == RequestStatus.FINISHED_STOPPED # assert requests[0].status == RequestStatus.FINISHED_STOPPED
assert requests[0].request_id in scheduler.finished_req_ids # assert requests[0].request_id in scheduler.finished_req_ids
assert list(requests[0].output_token_ids) == [EOS_TOKEN_ID] # assert list(requests[0].output_token_ids) == [EOS_TOKEN_ID]
assert list(requests[1].output_token_ids) == [10, 11] # assert list(requests[1].output_token_ids) == [10, 11]
# Test case 2: Stop on custom stop token # # Test case 2: Stop on custom stop token
scheduler = create_scheduler(num_speculative_tokens=2) # scheduler = create_scheduler(num_speculative_tokens=2)
requests = create_requests(num_requests=2, # requests = create_requests(num_requests=2,
max_tokens=10, # max_tokens=10,
stop_token_ids=[42, 43]) # stop_token_ids=[42, 43])
for req in requests: # for req in requests:
req.num_computed_tokens = req.num_tokens # req.num_computed_tokens = req.num_tokens
scheduler.requests[req.request_id] = req # scheduler.requests[req.request_id] = req
scheduler.running.append(req) # scheduler.running.append(req)
scheduler_output = SchedulerOutput(scheduled_new_reqs=[], # scheduler_output = SchedulerOutput(scheduled_new_reqs=[],
scheduled_cached_reqs=[], # scheduled_cached_reqs=[],
num_scheduled_tokens={ # num_scheduled_tokens={
requests[0].request_id: 3, # requests[0].request_id: 3,
requests[1].request_id: 2 # requests[1].request_id: 2
}, # },
total_num_scheduled_tokens=5, # total_num_scheduled_tokens=5,
scheduled_encoder_inputs={}, # scheduled_encoder_inputs={},
scheduled_spec_decode_tokens={ # scheduled_spec_decode_tokens={
requests[0].request_id: [10, 42], # requests[0].request_id: [10, 42],
requests[1].request_id: [13] # requests[1].request_id: [13]
}, # },
num_common_prefix_blocks=0, # num_common_prefix_blocks=0,
finished_req_ids=set(), # finished_req_ids=set(),
free_encoder_input_ids=[], # free_encoder_input_ids=[],
structured_output_request_ids={}, # structured_output_request_ids={},
grammar_bitmask=None) # grammar_bitmask=None)
model_output = ModelRunnerOutput( # model_output = ModelRunnerOutput(
req_ids=[req.request_id for req in requests], # req_ids=[req.request_id for req in requests],
req_id_to_index={ # req_id_to_index={
req.request_id: i # req.request_id: i
for i, req in enumerate(requests) # for i, req in enumerate(requests)
}, # },
sampled_token_ids=[[10, 42, 12], # sampled_token_ids=[[10, 42, 12],
[13, 14]], # First request hits stop token # [13, 14]], # First request hits stop token
spec_token_ids=None, # spec_token_ids=None,
logprobs=None, # logprobs=None,
prompt_logprobs_dict={}) # prompt_logprobs_dict={})
scheduler.update_from_output(scheduler_output, model_output) # scheduler.update_from_output(scheduler_output, model_output)
# Verify first request stopped on custom token # # Verify first request stopped on custom token
assert len(scheduler.running) == 1 # assert len(scheduler.running) == 1
assert scheduler.running[0].request_id == requests[1].request_id # assert scheduler.running[0].request_id == requests[1].request_id
assert requests[0].status == RequestStatus.FINISHED_STOPPED # assert requests[0].status == RequestStatus.FINISHED_STOPPED
assert requests[0].stop_reason == 42 # assert requests[0].stop_reason == 42
assert requests[0].request_id in scheduler.finished_req_ids # assert requests[0].request_id in scheduler.finished_req_ids
assert list(requests[0].output_token_ids) == [10, 42] # assert list(requests[0].output_token_ids) == [10, 42]
assert list(requests[1].output_token_ids) == [13, 14] # assert list(requests[1].output_token_ids) == [13, 14]
# Test case 3: Stop on max tokens # # Test case 3: Stop on max tokens
scheduler = create_scheduler(num_speculative_tokens=2) # scheduler = create_scheduler(num_speculative_tokens=2)
requests = create_requests(num_requests=2, max_tokens=2) # requests = create_requests(num_requests=2, max_tokens=2)
for req in requests: # for req in requests:
req.num_computed_tokens = req.num_tokens # req.num_computed_tokens = req.num_tokens
scheduler.requests[req.request_id] = req # scheduler.requests[req.request_id] = req
scheduler.running.append(req) # scheduler.running.append(req)
scheduler_output = SchedulerOutput(scheduled_new_reqs=[], # scheduler_output = SchedulerOutput(scheduled_new_reqs=[],
scheduled_cached_reqs=[], # scheduled_cached_reqs=[],
num_scheduled_tokens={ # num_scheduled_tokens={
requests[0].request_id: 3, # requests[0].request_id: 3,
requests[1].request_id: 1 # requests[1].request_id: 1
}, # },
total_num_scheduled_tokens=4, # total_num_scheduled_tokens=4,
scheduled_encoder_inputs={}, # scheduled_encoder_inputs={},
scheduled_spec_decode_tokens={ # scheduled_spec_decode_tokens={
requests[0].request_id: [10, 11], # requests[0].request_id: [10, 11],
requests[1].request_id: [] # requests[1].request_id: []
}, # },
num_common_prefix_blocks=0, # num_common_prefix_blocks=0,
finished_req_ids=set(), # finished_req_ids=set(),
free_encoder_input_ids=[], # free_encoder_input_ids=[],
structured_output_request_ids={}, # structured_output_request_ids={},
grammar_bitmask=None) # grammar_bitmask=None)
model_output = ModelRunnerOutput( # model_output = ModelRunnerOutput(
req_ids=[req.request_id for req in requests], # req_ids=[req.request_id for req in requests],
req_id_to_index={ # req_id_to_index={
req.request_id: i # req.request_id: i
for i, req in enumerate(requests) # for i, req in enumerate(requests)
}, # },
sampled_token_ids=[[10, 11, 12], # sampled_token_ids=[[10, 11, 12],
[13]], # First request exceeds max_tokens # [13]], # First request exceeds max_tokens
spec_token_ids=None, # spec_token_ids=None,
logprobs=None, # logprobs=None,
prompt_logprobs_dict={}) # prompt_logprobs_dict={})
scheduler.update_from_output(scheduler_output, model_output) # scheduler.update_from_output(scheduler_output, model_output)
# Verify first request stopped due to length # # Verify first request stopped due to length
assert len(scheduler.running) == 1 # assert len(scheduler.running) == 1
assert scheduler.running[0].request_id == requests[1].request_id # assert scheduler.running[0].request_id == requests[1].request_id
assert requests[0].status == RequestStatus.FINISHED_LENGTH_CAPPED # assert requests[0].status == RequestStatus.FINISHED_LENGTH_CAPPED
assert requests[0].request_id in scheduler.finished_req_ids # assert requests[0].request_id in scheduler.finished_req_ids
assert list(requests[0].output_token_ids) == [10, 11 # assert list(requests[0].output_token_ids) == [10, 11
] # Truncated to max_tokens # ] # Truncated to max_tokens
assert list(requests[1].output_token_ids) == [13] # assert list(requests[1].output_token_ids) == [13]
# Test case 4: Ignore EOS flag # # Test case 4: Ignore EOS flag
scheduler = create_scheduler(num_speculative_tokens=2) # scheduler = create_scheduler(num_speculative_tokens=2)
requests = create_requests(num_requests=1, max_tokens=10) # requests = create_requests(num_requests=1, max_tokens=10)
requests[0].sampling_params.ignore_eos = True # requests[0].sampling_params.ignore_eos = True
requests[0].num_computed_tokens = requests[0].num_tokens # requests[0].num_computed_tokens = requests[0].num_tokens
scheduler.requests[requests[0].request_id] = requests[0] # scheduler.requests[requests[0].request_id] = requests[0]
scheduler.running.append(requests[0]) # scheduler.running.append(requests[0])
scheduler_output = SchedulerOutput( # scheduler_output = SchedulerOutput(
scheduled_new_reqs=[], # scheduled_new_reqs=[],
scheduled_cached_reqs=[], # scheduled_cached_reqs=[],
num_scheduled_tokens={requests[0].request_id: 3}, # num_scheduled_tokens={requests[0].request_id: 3},
total_num_scheduled_tokens=3, # total_num_scheduled_tokens=3,
scheduled_encoder_inputs={}, # scheduled_encoder_inputs={},
scheduled_spec_decode_tokens={ # scheduled_spec_decode_tokens={
requests[0].request_id: [EOS_TOKEN_ID, 10] # requests[0].request_id: [EOS_TOKEN_ID, 10]
}, # },
num_common_prefix_blocks=0, # num_common_prefix_blocks=0,
finished_req_ids=set(), # finished_req_ids=set(),
free_encoder_input_ids=[], # free_encoder_input_ids=[],
structured_output_request_ids={}, # structured_output_request_ids={},
grammar_bitmask=None) # grammar_bitmask=None)
model_output = ModelRunnerOutput( # model_output = ModelRunnerOutput(
req_ids=[requests[0].request_id], # req_ids=[requests[0].request_id],
req_id_to_index={requests[0].request_id: 0}, # req_id_to_index={requests[0].request_id: 0},
sampled_token_ids=[[EOS_TOKEN_ID, 10, 11]], # sampled_token_ids=[[EOS_TOKEN_ID, 10, 11]],
spec_token_ids=None, # spec_token_ids=None,
logprobs=None, # logprobs=None,
prompt_logprobs_dict={}) # prompt_logprobs_dict={})
scheduler.update_from_output(scheduler_output, model_output) # scheduler.update_from_output(scheduler_output, model_output)
# Verify request continues past EOS # # Verify request continues past EOS
assert len(scheduler.running) == 1 # assert len(scheduler.running) == 1
assert not requests[0].is_finished() # assert not requests[0].is_finished()
assert list(requests[0].output_token_ids) == [EOS_TOKEN_ID, 10, 11] # assert list(requests[0].output_token_ids) == [EOS_TOKEN_ID, 10, 11]
@pytest.mark.parametrize("enable_prefix_caching, prompt_logprobs", [ @pytest.mark.parametrize("enable_prefix_caching, prompt_logprobs", [
...@@ -687,103 +687,103 @@ def test_schedule_concurrent_batches(enable_prefix_caching: Optional[bool], ...@@ -687,103 +687,103 @@ def test_schedule_concurrent_batches(enable_prefix_caching: Optional[bool],
# Note - these test cases mirror some of those in test_rejection_sampler.py # Note - these test cases mirror some of those in test_rejection_sampler.py
@pytest.mark.parametrize( # @pytest.mark.parametrize(
"spec_tokens,output_tokens,expected", # "spec_tokens,output_tokens,expected",
[ # [
([[1, 2, 3]], [[1, 2, 3, 4]], (1, 3, 3, [1, 1, 1])), # perfect match # ([[1, 2, 3]], [[1, 2, 3, 4]], (1, 3, 3, [1, 1, 1])), # perfect match
([[1, 2, 3]], [[1, 5]], (1, 3, 1, [1, 0, 0])), # early mismatch # ([[1, 2, 3]], [[1, 5]], (1, 3, 1, [1, 0, 0])), # early mismatch
([[1, 2], [3]], [[1, 2, 5], [3, 4]], # ([[1, 2], [3]], [[1, 2, 5], [3, 4]],
(2, 3, 3, [2, 1])), # multiple sequences # (2, 3, 3, [2, 1])), # multiple sequences
([[1]], [[1, 2]], (1, 1, 1, [1])), # single token sequence # ([[1]], [[1, 2]], (1, 1, 1, [1])), # single token sequence
([[]], [[5]], (0, 0, 0, [0])), # empty sequence # ([[]], [[5]], (0, 0, 0, [0])), # empty sequence
([[1, 2, 3], [4, 5, 6]], [[1, 2, 7], [4, 8]], # ([[1, 2, 3], [4, 5, 6]], [[1, 2, 7], [4, 8]],
(2, 6, 3, [2, 1, 0])), # multiple mismatches # (2, 6, 3, [2, 1, 0])), # multiple mismatches
]) # ])
def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected): # def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected):
"""Test scheduling behavior with speculative decoding. # """Test scheduling behavior with speculative decoding.
This test verifies that: # This test verifies that:
1. Speculated tokens get scheduled correctly # 1. Speculated tokens get scheduled correctly
2. Spec decoding stats properly count number of draft and accepted tokens # 2. Spec decoding stats properly count number of draft and accepted tokens
""" # """
num_spec_tokens = max(1, max(len(t) for t in spec_tokens)) # num_spec_tokens = max(1, max(len(t) for t in spec_tokens))
scheduler = create_scheduler(num_speculative_tokens=num_spec_tokens) # scheduler = create_scheduler(num_speculative_tokens=num_spec_tokens)
requests = create_requests(num_requests=len(spec_tokens), num_tokens=1) # requests = create_requests(num_requests=len(spec_tokens), num_tokens=1)
req_ids = [] # req_ids = []
req_to_index = {} # req_to_index = {}
for i, request in enumerate(requests): # for i, request in enumerate(requests):
scheduler.add_request(request) # scheduler.add_request(request)
req_ids.append(request.request_id) # req_ids.append(request.request_id)
req_to_index[request.request_id] = i # req_to_index[request.request_id] = i
# Schedule a decode, which will also draft speculative tokens # # Schedule a decode, which will also draft speculative tokens
output = scheduler.schedule() # output = scheduler.schedule()
assert len(output.scheduled_new_reqs) == len(requests) # assert len(output.scheduled_new_reqs) == len(requests)
assert output.total_num_scheduled_tokens == len(requests) # assert output.total_num_scheduled_tokens == len(requests)
for i in range(len(requests)): # for i in range(len(requests)):
req_id = requests[i].request_id # req_id = requests[i].request_id
assert output.num_scheduled_tokens[req_id] == 1 # assert output.num_scheduled_tokens[req_id] == 1
assert req_id not in output.scheduled_spec_decode_tokens # assert req_id not in output.scheduled_spec_decode_tokens
model_runner_output = ModelRunnerOutput( # model_runner_output = ModelRunnerOutput(
req_ids=req_ids, # req_ids=req_ids,
req_id_to_index=req_to_index, # req_id_to_index=req_to_index,
sampled_token_ids=[[0] for _ in range(len(requests))], # sampled_token_ids=[[0] for _ in range(len(requests))],
spec_token_ids=spec_tokens, # spec_token_ids=spec_tokens,
logprobs=None, # logprobs=None,
prompt_logprobs_dict={}, # prompt_logprobs_dict={},
) # )
engine_core_outputs = scheduler.update_from_output(output, # engine_core_outputs = scheduler.update_from_output(output,
model_runner_output) # model_runner_output)
for i in range(len(requests)): # for i in range(len(requests)):
running_req = scheduler.running[i] # running_req = scheduler.running[i]
# The prompt token # # The prompt token
assert running_req.num_computed_tokens == 1 # assert running_req.num_computed_tokens == 1
# The prompt token and the sampled token # # The prompt token and the sampled token
assert running_req.num_tokens == 2 # assert running_req.num_tokens == 2
# The prompt token, the sampled token, and the speculated tokens # # The prompt token, the sampled token, and the speculated tokens
assert running_req.num_tokens_with_spec == 2 + len(spec_tokens[i]) # assert running_req.num_tokens_with_spec == 2 + len(spec_tokens[i])
# No draft or accepted tokens counted yet # # No draft or accepted tokens counted yet
assert engine_core_outputs.scheduler_stats.spec_decoding_stats is None # assert engine_core_outputs.scheduler_stats.spec_decoding_stats is None
# Schedule the speculated tokens for validation # # Schedule the speculated tokens for validation
output = scheduler.schedule() # output = scheduler.schedule()
assert len(output.scheduled_new_reqs) == 0 # assert len(output.scheduled_new_reqs) == 0
# The sampled token and speculated tokens # # The sampled token and speculated tokens
assert output.total_num_scheduled_tokens == \ # assert output.total_num_scheduled_tokens == \
len(requests) + sum(len(ids) for ids in spec_tokens) # len(requests) + sum(len(ids) for ids in spec_tokens)
for i in range(len(requests)): # for i in range(len(requests)):
req_id = requests[i].request_id # req_id = requests[i].request_id
assert output.num_scheduled_tokens[req_id] == 1 + len(spec_tokens[i]) # assert output.num_scheduled_tokens[req_id] == 1 + len(spec_tokens[i])
if spec_tokens[i]: # if spec_tokens[i]:
assert len(output.scheduled_spec_decode_tokens[req_id]) == \ # assert len(output.scheduled_spec_decode_tokens[req_id]) == \
len(spec_tokens[i]) # len(spec_tokens[i])
else: # else:
assert req_id not in output.scheduled_spec_decode_tokens # assert req_id not in output.scheduled_spec_decode_tokens
model_runner_output = ModelRunnerOutput( # model_runner_output = ModelRunnerOutput(
req_ids=req_ids, # req_ids=req_ids,
req_id_to_index=req_to_index, # req_id_to_index=req_to_index,
sampled_token_ids=output_tokens, # sampled_token_ids=output_tokens,
spec_token_ids=None, # spec_token_ids=None,
logprobs=None, # logprobs=None,
prompt_logprobs_dict={}, # prompt_logprobs_dict={},
) # )
engine_core_outputs = scheduler.update_from_output(output, # engine_core_outputs = scheduler.update_from_output(output,
model_runner_output) # model_runner_output)
scheduler_stats = engine_core_outputs.scheduler_stats # scheduler_stats = engine_core_outputs.scheduler_stats
if expected[0] == 0: # if expected[0] == 0:
assert scheduler_stats.spec_decoding_stats is None # assert scheduler_stats.spec_decoding_stats is None
else: # else:
assert scheduler_stats.spec_decoding_stats is not None # assert scheduler_stats.spec_decoding_stats is not None
stats = scheduler_stats.spec_decoding_stats # stats = scheduler_stats.spec_decoding_stats
assert stats.num_drafts == expected[0] # assert stats.num_drafts == expected[0]
assert stats.num_draft_tokens == expected[1] # assert stats.num_draft_tokens == expected[1]
assert stats.num_accepted_tokens == expected[2] # assert stats.num_accepted_tokens == expected[2]
assert stats.num_accepted_tokens_per_pos == expected[3] # assert stats.num_accepted_tokens_per_pos == expected[3]
def _assert_right_scheduler_output( def _assert_right_scheduler_output(
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass from dataclasses import dataclass
import os
import pytest import pytest
from vllm import LLM, SamplingParams from vllm import LLM, SamplingParams
from ...core.block.e2e.test_correctness_sliding_window import (check_answers, from ...core.block.e2e.test_correctness_sliding_window import (check_answers,
prep_prompts) prep_prompts)
from ...utils import models_path_prefix
@dataclass @dataclass
...@@ -16,16 +18,16 @@ class TestConfig: ...@@ -16,16 +18,16 @@ class TestConfig:
model_config = { model_config = {
"bigcode/starcoder2-3b": TestConfig(4096, (800, 1100)), os.path.join(models_path_prefix, "bigcode/starcoder2-3b"): TestConfig(4096, (800, 1100)),
"google/gemma-2-2b-it": TestConfig(4096, (400, 800)), os.path.join(models_path_prefix, "google/gemma-2-2b-it"): TestConfig(4096, (400, 800)),
} }
@pytest.mark.parametrize( @pytest.mark.parametrize(
"model", "model",
[ [
"bigcode/starcoder2-3b", # sliding window only os.path.join(models_path_prefix, "bigcode/starcoder2-3b"), # sliding window only
"google/gemma-2-2b-it", # sliding window + full attention os.path.join(models_path_prefix, "google/gemma-2-2b-it"), # sliding window + full attention
]) ])
@pytest.mark.parametrize("batch_size", [5]) @pytest.mark.parametrize("batch_size", [5])
@pytest.mark.parametrize("seed", [1]) @pytest.mark.parametrize("seed", [1])
......
...@@ -4,9 +4,11 @@ from __future__ import annotations ...@@ -4,9 +4,11 @@ from __future__ import annotations
import random import random
from typing import Any from typing import Any
import os
import pytest import pytest
from vllm import LLM, SamplingParams from vllm import LLM, SamplingParams
from ...utils import models_path_prefix
@pytest.fixture @pytest.fixture
...@@ -49,14 +51,17 @@ def sampling_config(): ...@@ -49,14 +51,17 @@ def sampling_config():
@pytest.fixture @pytest.fixture
def model_name(): def model_name():
# return os.path.join(models_path_prefix, "meta-llama/Llama-3.1-8B-Instruct")
return "meta-llama/Llama-3.1-8B-Instruct" return "meta-llama/Llama-3.1-8B-Instruct"
def eagle_model_name(): def eagle_model_name():
# return os.path.join(models_path_prefix, "yuhuili/EAGLE-LLaMA3.1-Instruct-8B")
return "yuhuili/EAGLE-LLaMA3.1-Instruct-8B" return "yuhuili/EAGLE-LLaMA3.1-Instruct-8B"
def eagle3_model_name(): def eagle3_model_name():
# return os.path.join(models_path_prefix, "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B")
return "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B" return "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B"
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import os
import pytest import pytest
from vllm import LLM, SamplingParams from vllm import LLM, SamplingParams
from ...utils import fork_new_process_for_each_test from ...utils import fork_new_process_for_each_test, models_path_prefix
@fork_new_process_for_each_test @fork_new_process_for_each_test
@pytest.mark.parametrize("attn_backend", @pytest.mark.parametrize("attn_backend",
["FLASH_ATTN_VLLM_V1", "FLASHINFER_VLLM_V1"]) ["FLASH_ATTN_VLLM_V1"]) # "FLASHINFER_VLLM_V1"
def test_cascade_attention(example_system_message, monkeypatch, attn_backend): def test_cascade_attention(example_system_message, monkeypatch, attn_backend):
prompt = "\n<User>: Implement fibonacci sequence in Python.\n<Claude>:" prompt = "\n<User>: Implement fibonacci sequence in Python.\n<Claude>:"
...@@ -17,7 +18,7 @@ def test_cascade_attention(example_system_message, monkeypatch, attn_backend): ...@@ -17,7 +18,7 @@ def test_cascade_attention(example_system_message, monkeypatch, attn_backend):
m.setenv("VLLM_USE_V1", "1") m.setenv("VLLM_USE_V1", "1")
m.setenv("VLLM_ATTENTION_BACKEND", attn_backend) m.setenv("VLLM_ATTENTION_BACKEND", attn_backend)
llm = LLM(model="Qwen/Qwen2-1.5B-Instruct") llm = LLM(model=os.path.join(models_path_prefix, "Qwen/Qwen2-1.5B-Instruct"))
sampling_params = SamplingParams(temperature=0.0, max_tokens=100) sampling_params = SamplingParams(temperature=0.0, max_tokens=100)
# No cascade attention. # No cascade attention.
......
...@@ -3,11 +3,13 @@ ...@@ -3,11 +3,13 @@
import random import random
from typing import Optional from typing import Optional
import os
import pytest import pytest
from vllm import LLM, SamplingParams from vllm import LLM, SamplingParams
from ...utils import models_path_prefix
MODEL = "facebook/opt-125m" MODEL = os.path.join(models_path_prefix, "facebook/opt-125m")
DTYPE = "half" DTYPE = "half"
......
...@@ -20,6 +20,7 @@ from vllm.v1.engine import EngineCoreRequest ...@@ -20,6 +20,7 @@ from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.output_processor import (OutputProcessor, from vllm.v1.engine.output_processor import (OutputProcessor,
RequestOutputCollector) RequestOutputCollector)
from vllm.v1.metrics.stats import IterationStats from vllm.v1.metrics.stats import IterationStats
from ...utils import models_path_prefix
def _ref_convert_id_to_token( def _ref_convert_id_to_token(
...@@ -520,7 +521,7 @@ def test_stop_token(include_stop_str_in_output: bool, ...@@ -520,7 +521,7 @@ def test_stop_token(include_stop_str_in_output: bool,
dummy_test_vectors: dummy engine core outputs and other data structures dummy_test_vectors: dummy engine core outputs and other data structures
""" """
model_id = dummy_test_vectors.tokenizer.name_or_path model_id = dummy_test_vectors.tokenizer.name_or_path
if model_id != 'meta-llama/Llama-3.2-1B': if model_id != os.path.join(models_path_prefix, 'meta-llama/Llama-3.2-1B'):
raise AssertionError("Test requires meta-llama/Llama-3.2-1B but " raise AssertionError("Test requires meta-llama/Llama-3.2-1B but "
f"{model_id} is in use.") f"{model_id} is in use.")
do_logprobs = num_sample_logprobs is not None do_logprobs = num_sample_logprobs is not None
......
...@@ -7,6 +7,7 @@ import re ...@@ -7,6 +7,7 @@ import re
from enum import Enum from enum import Enum
from typing import Any from typing import Any
import os
import jsonschema import jsonschema
import pytest import pytest
from pydantic import BaseModel from pydantic import BaseModel
...@@ -15,22 +16,23 @@ from vllm.entrypoints.llm import LLM ...@@ -15,22 +16,23 @@ from vllm.entrypoints.llm import LLM
from vllm.outputs import RequestOutput from vllm.outputs import RequestOutput
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.sampling_params import GuidedDecodingParams, SamplingParams from vllm.sampling_params import GuidedDecodingParams, SamplingParams
from ....utils import models_path_prefix
PARAMS_MODELS_BACKENDS_TOKENIZER_MODE = [ PARAMS_MODELS_BACKENDS_TOKENIZER_MODE = [
("mistralai/Ministral-8B-Instruct-2410", "xgrammar:disable-any-whitespace", (os.path.join(models_path_prefix, "mistralai/Ministral-8B-Instruct-2410"), "xgrammar:disable-any-whitespace",
"auto"), "auto"),
("mistralai/Ministral-8B-Instruct-2410", "guidance:disable-any-whitespace", (os.path.join(models_path_prefix, "mistralai/Ministral-8B-Instruct-2410"), "guidance:disable-any-whitespace",
"auto"), "auto"),
("mistralai/Ministral-8B-Instruct-2410", "xgrammar:disable-any-whitespace", (os.path.join(models_path_prefix, "mistralai/Ministral-8B-Instruct-2410"), "xgrammar:disable-any-whitespace",
"mistral"), "mistral"),
("Qwen/Qwen2.5-1.5B-Instruct", "xgrammar:disable-any-whitespace", "auto"), (os.path.join(models_path_prefix, "Qwen/Qwen2.5-1.5B-Instruct"), "xgrammar:disable-any-whitespace", "auto"),
#FIXME: This test is flaky on CI thus disabled #FIXME: This test is flaky on CI thus disabled
#("Qwen/Qwen2.5-1.5B-Instruct", "guidance:disable-any-whitespace", "auto"), #("Qwen/Qwen2.5-1.5B-Instruct", "guidance:disable-any-whitespace", "auto"),
] ]
PARAMS_MODELS_TOKENIZER_MODE = [ PARAMS_MODELS_TOKENIZER_MODE = [
("mistralai/Ministral-8B-Instruct-2410", "auto"), (os.path.join(models_path_prefix, "mistralai/Ministral-8B-Instruct-2410"), "auto"),
("Qwen/Qwen2.5-1.5B-Instruct", "auto"), (os.path.join(models_path_prefix, "Qwen/Qwen2.5-1.5B-Instruct"), "auto"),
] ]
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
import itertools import itertools
from collections.abc import Generator from collections.abc import Generator
import os
import pytest import pytest
import torch import torch
...@@ -13,8 +14,9 @@ from tests.v1.sample.utils import ( ...@@ -13,8 +14,9 @@ from tests.v1.sample.utils import (
from vllm import SamplingParams from vllm import SamplingParams
from ...conftest import HfRunner, VllmRunner from ...conftest import HfRunner, VllmRunner
from ...utils import models_path_prefix
MODEL = "meta-llama/Llama-3.2-1B-Instruct" MODEL = os.path.join(models_path_prefix, "meta-llama/Llama-3.2-1B-Instruct")
DTYPE = "half" DTYPE = "half"
NONE = BatchLogprobsComposition.NONE NONE = BatchLogprobsComposition.NONE
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import os
import lm_eval import lm_eval
from ...utils import RemoteOpenAIServer from ...utils import RemoteOpenAIServer, models_path_prefix
# arc-easy uses prompt_logprobs=1, logprobs=1 # arc-easy uses prompt_logprobs=1, logprobs=1
TASK = "arc_easy" TASK = "arc_easy"
...@@ -11,7 +12,7 @@ RTOL = 0.03 ...@@ -11,7 +12,7 @@ RTOL = 0.03
EXPECTED_VALUE = 0.62 EXPECTED_VALUE = 0.62
# FIXME(rob): enable prefix caching once supported. # FIXME(rob): enable prefix caching once supported.
MODEL = "meta-llama/Llama-3.2-1B-Instruct" MODEL = os.path.join(models_path_prefix, "meta-llama/Llama-3.2-1B-Instruct")
MODEL_ARGS = f"pretrained={MODEL},enforce_eager=True,enable_prefix_caching=False" # noqa: E501 MODEL_ARGS = f"pretrained={MODEL},enforce_eager=True,enable_prefix_caching=False" # noqa: E501
SERVER_ARGS = [ SERVER_ARGS = [
"--enforce_eager", "--no_enable_prefix_caching", "--disable-log-requests" "--enforce_eager", "--no_enable_prefix_caching", "--disable-log-requests"
......
...@@ -4,11 +4,12 @@ import os ...@@ -4,11 +4,12 @@ import os
import pytest import pytest
from vllm import LLM, SamplingParams from vllm import LLM, SamplingParams
from ...utils import models_path_prefix
if os.getenv("VLLM_USE_V1", "0") != "1": if os.getenv("VLLM_USE_V1", "0") != "1":
pytest.skip("Test package requires V1", allow_module_level=True) pytest.skip("Test package requires V1", allow_module_level=True)
MODEL = "meta-llama/Llama-3.2-1B" MODEL = os.path.join(models_path_prefix, "meta-llama/Llama-3.2-1B")
PROMPT = "Hello my name is Robert and I" PROMPT = "Hello my name is Robert and I"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment