"lib/llm/src/vscode:/vscode.git/clone" did not exist on "84e71e27d36e3db7168e673137ac9d6d10537efe"
Commit 31f6b24f authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge remote-tracking branch 'mirror/v0.8.2' into v0.8.2-ori

parents 89d1dd57 25f560a6
# SPDX-License-Identifier: Apache-2.0
import Cython.Compiler.Options
from Cython.Build import cythonize
from setuptools import setup
# Set Cython's module-level `annotate` option. Note that the cythonize()
# call below passes annotate=False explicitly.
Cython.Compiler.Options.annotate = True

# Python source files handed to cythonize(), grouped by area of the codebase.
infiles = [
    # Engine and output processing.
    "vllm/engine/llm_engine.py",
    "vllm/transformers_utils/detokenizer.py",
    "vllm/engine/output_processor/single_step.py",
    "vllm/outputs.py",
    "vllm/engine/output_processor/stop_checker.py",
    # Scheduling and sequence bookkeeping.
    "vllm/core/scheduler.py",
    "vllm/sequence.py",
    "vllm/core/block_manager.py",
    # Sampling and shared utilities.
    "vllm/model_executor/layers/sampler.py",
    "vllm/sampling_params.py",
    "vllm/utils.py",
]

# force=True is passed so every listed file is handed to the compiler on
# each invocation of this script.
setup(
    ext_modules=cythonize(
        infiles,
        annotate=False,
        force=True,
        compiler_directives=dict(language_level="3", infer_types=True),
    ))
# example usage: python3 build_cython.py build_ext --inplace
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import copy
import pickle
import pytest import pytest
import torch import torch
...@@ -10,32 +9,63 @@ from vllm.compilation.pass_manager import PostGradPassManager ...@@ -10,32 +9,63 @@ from vllm.compilation.pass_manager import PostGradPassManager
from vllm.config import CompilationConfig from vllm.config import CompilationConfig
# dummy custom pass that doesn't inherit
def simple_callable(graph: torch.fx.Graph): def simple_callable(graph: torch.fx.Graph):
pass pass
callable_uuid = CallableInductorPass(simple_callable, # Should fail to add directly to the pass manager
InductorPass.hash_source(__file__)) def test_bad_callable():
config = CompilationConfig().pass_config
pass_manager = PostGradPassManager()
pass_manager.configure(config)
with pytest.raises(AssertionError):
pass_manager.add(simple_callable) # noqa, type wrong on purpose
# Pass that inherits from InductorPass
class ProperPass(InductorPass):
    """Minimal ``InductorPass`` subclass used as a well-formed test pass.

    Its ``__call__`` accepts a graph and does nothing with it.
    """

    def __call__(self, graph: torch.fx.graph.Graph) -> None:
        pass
@pytest.mark.parametrize( @pytest.mark.parametrize(
"works, callable", "callable",
[ [
(False, simple_callable), ProperPass(),
(True, callable_uuid), # Can also wrap callables in CallableInductorPass for compliance
(True, CallableInductorPass(simple_callable)), CallableInductorPass(simple_callable),
CallableInductorPass(simple_callable,
InductorPass.hash_source(__file__))
], ],
) )
def test_pass_manager(works: bool, callable): def test_pass_manager_uuid(callable):
config = CompilationConfig().pass_config config = CompilationConfig().pass_config
pass_manager = PostGradPassManager() pass_manager = PostGradPassManager()
pass_manager.configure(config) pass_manager.configure(config)
# Try to add the callable to the pass manager # Check that UUID is different if the same pass is added 2x
if works:
pass_manager.add(callable) pass_manager.add(callable)
pickle.dumps(pass_manager) uuid1 = pass_manager.uuid()
else:
with pytest.raises(AssertionError):
pass_manager.add(callable) pass_manager.add(callable)
uuid2 = pass_manager.uuid()
assert uuid1 != uuid2
# UUID should be the same as the original one,
# as we constructed in the same way.
pass_manager2 = PostGradPassManager()
pass_manager2.configure(config)
pass_manager2.add(callable)
assert uuid1 == pass_manager2.uuid()
# UUID should be different due to config change
config2 = copy.deepcopy(config)
config2.enable_fusion = not config2.enable_fusion
pass_manager3 = PostGradPassManager()
pass_manager3.configure(config2)
pass_manager3.add(callable)
assert uuid1 != pass_manager3.uuid()
...@@ -175,6 +175,8 @@ TEXT_GENERATION_MODELS = { ...@@ -175,6 +175,8 @@ TEXT_GENERATION_MODELS = {
"inceptionai/jais-13b-chat": PPTestSettings.fast(), "inceptionai/jais-13b-chat": PPTestSettings.fast(),
"ai21labs/Jamba-tiny-dev": PPTestSettings.fast(), "ai21labs/Jamba-tiny-dev": PPTestSettings.fast(),
"meta-llama/Llama-3.2-1B-Instruct": PPTestSettings.detailed(), "meta-llama/Llama-3.2-1B-Instruct": PPTestSettings.detailed(),
# Tests TransformersModel
"ArthurZ/Ilama-3.2-1B": PPTestSettings.fast(),
"openbmb/MiniCPM-2B-sft-bf16": PPTestSettings.fast(), "openbmb/MiniCPM-2B-sft-bf16": PPTestSettings.fast(),
"openbmb/MiniCPM3-4B": PPTestSettings.fast(), "openbmb/MiniCPM3-4B": PPTestSettings.fast(),
# Uses Llama # Uses Llama
...@@ -243,6 +245,7 @@ TEST_MODELS = [ ...@@ -243,6 +245,7 @@ TEST_MODELS = [
# [LANGUAGE GENERATION] # [LANGUAGE GENERATION]
"microsoft/Phi-3.5-MoE-instruct", "microsoft/Phi-3.5-MoE-instruct",
"meta-llama/Llama-3.2-1B-Instruct", "meta-llama/Llama-3.2-1B-Instruct",
# "ArthurZ/Ilama-3.2-1B", NOTE: Uncomment after #13905
"ibm/PowerLM-3b", "ibm/PowerLM-3b",
# [LANGUAGE EMBEDDING] # [LANGUAGE EMBEDDING]
"intfloat/e5-mistral-7b-instruct", "intfloat/e5-mistral-7b-instruct",
......
...@@ -107,8 +107,10 @@ def test_get_gen_prompt(model, template, add_generation_prompt, ...@@ -107,8 +107,10 @@ def test_get_gen_prompt(model, template, add_generation_prompt,
# Call the function and get the result # Call the function and get the result
result = apply_hf_chat_template( result = apply_hf_chat_template(
tokenizer, tokenizer,
trust_remote_code=True,
conversation=mock_request.messages, conversation=mock_request.messages,
chat_template=mock_request.chat_template or template_content, chat_template=mock_request.chat_template or template_content,
tools=None,
add_generation_prompt=mock_request.add_generation_prompt, add_generation_prompt=mock_request.add_generation_prompt,
continue_final_message=mock_request.continue_final_message, continue_final_message=mock_request.continue_final_message,
) )
......
...@@ -87,7 +87,7 @@ async def test_single_chat_session_video(client: openai.AsyncOpenAI, ...@@ -87,7 +87,7 @@ async def test_single_chat_session_video(client: openai.AsyncOpenAI,
choice = chat_completion.choices[0] choice = chat_completion.choices[0]
assert choice.finish_reason == "length" assert choice.finish_reason == "length"
assert chat_completion.usage == openai.types.CompletionUsage( assert chat_completion.usage == openai.types.CompletionUsage(
completion_tokens=10, prompt_tokens=6299, total_tokens=6309) completion_tokens=10, prompt_tokens=6287, total_tokens=6297)
message = choice.message message = choice.message
message = chat_completion.choices[0].message message = chat_completion.choices[0].message
...@@ -180,7 +180,7 @@ async def test_single_chat_session_video_base64encoded( ...@@ -180,7 +180,7 @@ async def test_single_chat_session_video_base64encoded(
choice = chat_completion.choices[0] choice = chat_completion.choices[0]
assert choice.finish_reason == "length" assert choice.finish_reason == "length"
assert chat_completion.usage == openai.types.CompletionUsage( assert chat_completion.usage == openai.types.CompletionUsage(
completion_tokens=10, prompt_tokens=6299, total_tokens=6309) completion_tokens=10, prompt_tokens=6287, total_tokens=6297)
message = choice.message message = choice.message
message = chat_completion.choices[0].message message = chat_completion.choices[0].message
......
...@@ -4,10 +4,13 @@ import warnings ...@@ -4,10 +4,13 @@ import warnings
from typing import Optional from typing import Optional
import pytest import pytest
from packaging.version import Version
from transformers import __version__ as TRANSFORMERS_VERSION
from vllm.assets.image import ImageAsset from vllm.assets.image import ImageAsset
from vllm.config import ModelConfig from vllm.config import ModelConfig
from vllm.entrypoints.chat_utils import (_try_extract_ast, load_chat_template, from vllm.entrypoints.chat_utils import (_resolve_hf_chat_template,
_try_extract_ast, load_chat_template,
parse_chat_messages, parse_chat_messages,
parse_chat_messages_futures, parse_chat_messages_futures,
resolve_chat_template_content_format) resolve_chat_template_content_format)
...@@ -23,8 +26,10 @@ EXAMPLES_DIR = VLLM_PATH / "examples" ...@@ -23,8 +26,10 @@ EXAMPLES_DIR = VLLM_PATH / "examples"
PHI3V_MODEL_ID = "microsoft/Phi-3.5-vision-instruct" PHI3V_MODEL_ID = "microsoft/Phi-3.5-vision-instruct"
ULTRAVOX_MODEL_ID = "fixie-ai/ultravox-v0_5-llama-3_2-1b" ULTRAVOX_MODEL_ID = "fixie-ai/ultravox-v0_5-llama-3_2-1b"
QWEN2VL_MODEL_ID = "Qwen/Qwen2-VL-2B-Instruct" QWEN2VL_MODEL_ID = "Qwen/Qwen2-VL-2B-Instruct"
QWEN25VL_MODEL_ID = "Qwen/Qwen2.5-VL-3B-Instruct"
MLLAMA_MODEL_ID = "meta-llama/Llama-3.2-11B-Vision-Instruct" MLLAMA_MODEL_ID = "meta-llama/Llama-3.2-11B-Vision-Instruct"
LLAMA_GUARD_MODEL_ID = "meta-llama/Llama-Guard-3-1B" LLAMA_GUARD_MODEL_ID = "meta-llama/Llama-Guard-3-1B"
HERMES_MODEL_ID = "NousResearch/Hermes-3-Llama-3.1-8B"
@pytest.fixture(scope="function") @pytest.fixture(scope="function")
...@@ -703,25 +708,70 @@ def test_multimodal_image_parsing_matches_hf(model, image_url): ...@@ -703,25 +708,70 @@ def test_multimodal_image_parsing_matches_hf(model, image_url):
vllm_result = apply_hf_chat_template( vllm_result = apply_hf_chat_template(
tokenizer, tokenizer,
trust_remote_code=model_config.trust_remote_code,
conversation=conversation, conversation=conversation,
chat_template=None, chat_template=None,
tools=None,
add_generation_prompt=True, add_generation_prompt=True,
) )
assert hf_result == vllm_result assert hf_result == vllm_result
@pytest.mark.parametrize(
    "model",
    [
        QWEN2VL_MODEL_ID,  # tokenizer.chat_template is of type str
        HERMES_MODEL_ID,  # tokenizer.chat_template is of type dict
    ])
@pytest.mark.parametrize("use_tools", [True, False])
def test_resolve_hf_chat_template(sample_json_schema, model, use_tools):
    """Check that the resolved chat template is always a ``str``.

    Runs for each parametrized model, both with a single dummy tool
    definition and with no tools at all.
    """
    # Only the underlying tokenizer of the group is needed here.
    tokenizer = TokenizerGroup(
        model,
        enable_lora=False,
        max_num_seqs=5,
        max_input_length=None,
    ).tokenizer

    tools = None
    if use_tools:
        dummy_tool = {
            "type": "function",
            "function": {
                "name": "dummy_function_name",
                "description": "This is a dummy function",
                "parameters": sample_json_schema,
            },
        }
        tools = [dummy_tool]

    # chat_template=None: let the resolver pick the tokenizer's own template.
    resolved_template = _resolve_hf_chat_template(
        tokenizer,
        chat_template=None,
        tools=tools,
        trust_remote_code=True,
    )
    assert isinstance(resolved_template, str)
# yapf: disable # yapf: disable
@pytest.mark.parametrize( @pytest.mark.parametrize(
("model", "expected_format"), ("model", "expected_format"),
[(PHI3V_MODEL_ID, "string"), [(PHI3V_MODEL_ID, "string"),
(QWEN2VL_MODEL_ID, "openai"), (QWEN2VL_MODEL_ID, "openai"),
(QWEN25VL_MODEL_ID, "openai"),
(ULTRAVOX_MODEL_ID, "string"), (ULTRAVOX_MODEL_ID, "string"),
(MLLAMA_MODEL_ID, "openai"), (MLLAMA_MODEL_ID, "openai"),
(LLAMA_GUARD_MODEL_ID, "openai")], (LLAMA_GUARD_MODEL_ID, "openai")],
) )
# yapf: enable # yapf: enable
def test_resolve_content_format_hf_defined(model, expected_format): def test_resolve_content_format_hf_defined(model, expected_format):
if model == QWEN25VL_MODEL_ID and Version(TRANSFORMERS_VERSION) < Version(
"4.49.0"):
pytest.skip("Qwen2.5-VL requires transformers>=4.49.0")
tokenizer_group = TokenizerGroup( tokenizer_group = TokenizerGroup(
model, model,
enable_lora=False, enable_lora=False,
...@@ -730,7 +780,13 @@ def test_resolve_content_format_hf_defined(model, expected_format): ...@@ -730,7 +780,13 @@ def test_resolve_content_format_hf_defined(model, expected_format):
) )
tokenizer = tokenizer_group.tokenizer tokenizer = tokenizer_group.tokenizer
chat_template = tokenizer.chat_template # Test detecting the tokenizer's chat_template
chat_template = _resolve_hf_chat_template(
tokenizer,
chat_template=None,
tools=None,
trust_remote_code=True,
)
assert isinstance(chat_template, str) assert isinstance(chat_template, str)
print("[TEXT]") print("[TEXT]")
...@@ -740,8 +796,10 @@ def test_resolve_content_format_hf_defined(model, expected_format): ...@@ -740,8 +796,10 @@ def test_resolve_content_format_hf_defined(model, expected_format):
resolved_format = resolve_chat_template_content_format( resolved_format = resolve_chat_template_content_format(
None, # Test detecting the tokenizer's chat_template None, # Test detecting the tokenizer's chat_template
None,
"auto", "auto",
tokenizer, tokenizer,
trust_remote_code=True,
) )
assert resolved_format == expected_format assert resolved_format == expected_format
...@@ -791,8 +849,10 @@ def test_resolve_content_format_examples(template_path, expected_format): ...@@ -791,8 +849,10 @@ def test_resolve_content_format_examples(template_path, expected_format):
resolved_format = resolve_chat_template_content_format( resolved_format = resolve_chat_template_content_format(
chat_template, chat_template,
None,
"auto", "auto",
dummy_tokenizer, dummy_tokenizer,
trust_remote_code=True,
) )
assert resolved_format == expected_format assert resolved_format == expected_format
# SPDX-License-Identifier: Apache-2.0
from vllm import SamplingParams
from vllm.config import LoadFormat
# Model identifier passed to the runner in the test below.
test_model = "openai-community/gpt2"
# Prompts submitted together in a single generate() call.
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
# Create a sampling params object.
# A fixed seed (0) is supplied alongside temperature and top_p.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95, seed=0)
def test_model_loader_download_files(vllm_runner):
    """Generate with the FASTSAFETENSORS load format and expect output.

    Args:
        vllm_runner: fixture called with the model name and load format;
            its result is used as a context manager yielding the LLM.
    """
    runner = vllm_runner(test_model, load_format=LoadFormat.FASTSAFETENSORS)
    with runner as llm:
        deserialized_outputs = llm.generate(prompts, sampling_params)
        # Any truthy (e.g. non-empty) result counts as success.
        assert deserialized_outputs
# SPDX-License-Identifier: Apache-2.0
import glob
import tempfile
import huggingface_hub.constants
import torch
from vllm.model_executor.model_loader.weight_utils import (
download_weights_from_hf, fastsafetensors_weights_iterator,
safetensors_weights_iterator)
def test_fastsafetensors_model_loader():
    """Compare tensors from the fastsafetensors and safetensors iterators.

    Downloads the gpt2 ``*.safetensors`` files into a temporary directory,
    loads them through both iterators, and checks that every tensor from
    the fastsafetensors path has the same dtype, shape and values as the
    tensor of the same name from the regular safetensors path.
    """
    with tempfile.TemporaryDirectory() as tmpdir:
        # NOTE(review): this overwrites huggingface_hub's module-level flag
        # and nothing in this function restores the previous value.
        huggingface_hub.constants.HF_HUB_OFFLINE = False
        download_weights_from_hf("openai-community/gpt2",
                                 allow_patterns=["*.safetensors"],
                                 cache_dir=tmpdir)
        safetensors = glob.glob(f"{tmpdir}/**/*.safetensors", recursive=True)
        assert len(safetensors) > 0

        # Materialise both iterators into name -> tensor mappings
        # (fastsafetensors first, then the regular loader).
        fastsafetensors_tensors = dict(
            fastsafetensors_weights_iterator(safetensors, True))
        hf_safetensors_tensors = dict(
            safetensors_weights_iterator(safetensors, True))

        assert len(fastsafetensors_tensors) == len(hf_safetensors_tensors)

        for name, fast_tensor in fastsafetensors_tensors.items():
            # Move to CPU before the element-wise comparison below.
            fast_tensor = fast_tensor.to('cpu')
            reference = hf_safetensors_tensors[name]
            assert fast_tensor.dtype == reference.dtype
            assert fast_tensor.shape == reference.shape
            assert torch.all(fast_tensor.eq(hf_safetensors_tensors[name]))
# Allow running this test directly as a script, without pytest.
if __name__ == "__main__":
    test_fastsafetensors_model_loader()
...@@ -606,6 +606,51 @@ def test_marlin_qqq_gemm( ...@@ -606,6 +606,51 @@ def test_marlin_qqq_gemm(
assert max_diff < 0.04 assert max_diff < 0.04
def test_marlin_gemm_subset_input():
    """Run gptq_marlin_gemm on an activation sliced out of a larger tensor.

    The activation is a (size_m, size_k) window cut from a tensor twice as
    large in both dimensions, starting at offset 8 along each axis. The
    kernel output is compared against ``torch.matmul`` on the reference
    weights returned by ``marlin_quantize``.
    """
    quant_type = scalar_types.uint4b8
    group_size = 128
    size_m, size_k, size_n = 32, 1024, 2048

    # Allocate the oversized activation first, then take the sub-window.
    # NOTE(review): presumably this yields a non-contiguous view, which is
    # what the test is meant to exercise -- confirm against rand_data.
    oversized = rand_data((size_m * 2, size_k * 2))
    a_input = oversized[8:size_m + 8, 8:size_k + 8]
    b_weight = rand_data((size_k, size_n))

    w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize(
        b_weight, quant_type, group_size, False)

    marlin_zp = marlin_make_empty_g_idx(marlin_s.device)
    workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_MIN_THREAD_N,
                                GPTQ_MARLIN_MAX_PARALLEL)

    # Problem sizes are read back from the actual operands.
    m_dim = a_input.shape[0]
    n_dim = b_weight.shape[1]
    k_dim = a_input.shape[1]

    output = ops.gptq_marlin_gemm(
        a_input,
        marlin_q_w,
        marlin_s,
        marlin_zp,
        g_idx,
        sort_indices,
        workspace.scratch,
        quant_type,
        m_dim,
        n_dim,
        k_dim,
        is_k_full=True,
        has_zp=False,
        use_atomic_add=False,
        use_fp32_reduce=True,
        is_zp_float=False,
    )
    output_ref = torch.matmul(a_input, w_ref)

    torch.cuda.synchronize()

    assert compute_max_diff(output, output_ref) < 0.04
def test_marlin_gemm_opcheck(): def test_marlin_gemm_opcheck():
size_m = 2048 size_m = 2048
size_n = 4096 size_n = 4096
......
...@@ -3,8 +3,11 @@ ...@@ -3,8 +3,11 @@
Run `pytest tests/kernels/test_moe.py`. Run `pytest tests/kernels/test_moe.py`.
""" """
import pytest import pytest
import torch import torch
from torch.nn import Parameter
from torch.nn import functional as F
from transformers import MixtralConfig from transformers import MixtralConfig
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
...@@ -37,6 +40,7 @@ TOP_KS = [2, 6] ...@@ -37,6 +40,7 @@ TOP_KS = [2, 6]
@pytest.mark.parametrize("topk", TOP_KS) @pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("ep_size", EP_SIZE) @pytest.mark.parametrize("ep_size", EP_SIZE)
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("padding", [True, False])
def test_fused_moe( def test_fused_moe(
m: int, m: int,
n: int, n: int,
...@@ -45,6 +49,7 @@ def test_fused_moe( ...@@ -45,6 +49,7 @@ def test_fused_moe(
topk: int, topk: int,
ep_size: int, ep_size: int,
dtype: torch.dtype, dtype: torch.dtype,
padding: bool,
): ):
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
...@@ -65,7 +70,8 @@ def test_fused_moe( ...@@ -65,7 +70,8 @@ def test_fused_moe(
else: else:
e_map = None e_map = None
triton_output = fused_moe(a, torch_output = torch_moe(a, w1, w2, score, topk, e_map)
iterative_output = iterative_moe(a,
w1, w1,
w2, w2,
score, score,
...@@ -73,9 +79,15 @@ def test_fused_moe( ...@@ -73,9 +79,15 @@ def test_fused_moe(
global_num_experts=e, global_num_experts=e,
expert_map=e_map, expert_map=e_map,
renormalize=False) renormalize=False)
torch_output = torch_moe(a, w1, w2, score, topk, e_map)
torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0) # Pad the weight if moe padding is enabled
iterative_output = iterative_moe(a, if padding:
w1 = F.pad(w1, (0, 128), "constant", 0)[..., 0:-128]
torch.cuda.empty_cache()
w2 = F.pad(w2, (0, 128), "constant", 0)[..., 0:-128]
torch.cuda.empty_cache()
triton_output = fused_moe(a,
w1, w1,
w2, w2,
score, score,
...@@ -83,6 +95,7 @@ def test_fused_moe( ...@@ -83,6 +95,7 @@ def test_fused_moe(
global_num_experts=e, global_num_experts=e,
expert_map=e_map, expert_map=e_map,
renormalize=False) renormalize=False)
torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0)
torch.testing.assert_close(iterative_output, torch.testing.assert_close(iterative_output,
torch_output, torch_output,
atol=2e-2, atol=2e-2,
...@@ -202,8 +215,9 @@ def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int, ...@@ -202,8 +215,9 @@ def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int,
@pytest.mark.parametrize("dtype", @pytest.mark.parametrize("dtype",
[torch.float32, torch.float16, torch.bfloat16]) [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("padding", [True, False])
@torch.inference_mode() @torch.inference_mode()
def test_mixtral_moe(dtype: torch.dtype): def test_mixtral_moe(dtype: torch.dtype, padding: bool):
"""Make sure our Mixtral MoE implementation agrees with the one from """Make sure our Mixtral MoE implementation agrees with the one from
huggingface.""" huggingface."""
...@@ -233,6 +247,17 @@ def test_mixtral_moe(dtype: torch.dtype): ...@@ -233,6 +247,17 @@ def test_mixtral_moe(dtype: torch.dtype):
# vLLM uses 1D query [num_tokens, hidden_dim] # vLLM uses 1D query [num_tokens, hidden_dim]
vllm_inputs = hf_inputs.flatten(0, 1) vllm_inputs = hf_inputs.flatten(0, 1)
# Pad the weight if moe padding is enabled
if padding:
vllm_moe.experts.w13_weight = Parameter(F.pad(
vllm_moe.experts.w13_weight, (0, 128), "constant", 0)[..., 0:-128],
requires_grad=False)
torch.cuda.empty_cache()
vllm_moe.experts.w2_weight = Parameter(F.pad(
vllm_moe.experts.w2_weight, (0, 128), "constant", 0)[..., 0:-128],
requires_grad=False)
torch.cuda.empty_cache()
# Run forward passes for both MoE blocks # Run forward passes for both MoE blocks
hf_states, _ = hf_moe.forward(hf_inputs) hf_states, _ = hf_moe.forward(hf_inputs)
vllm_states = vllm_moe.forward(vllm_inputs) vllm_states = vllm_moe.forward(vllm_inputs)
......
...@@ -39,7 +39,10 @@ def ensure_system_prompt(messages: list[dict[str, Any]], ...@@ -39,7 +39,10 @@ def ensure_system_prompt(messages: list[dict[str, Any]],
# universal args for all models go here. also good if you need to test locally # universal args for all models go here. also good if you need to test locally
# and change type or KV cache quantization or something. # and change type or KV cache quantization or something.
ARGS: list[str] = ["--enable-auto-tool-choice", "--max-model-len", "1024"] ARGS: list[str] = [
"--enable-auto-tool-choice", "--max-model-len", "1024", "--max-num-seqs",
"256"
]
CONFIGS: dict[str, ServerConfig] = { CONFIGS: dict[str, ServerConfig] = {
"hermes": { "hermes": {
......
...@@ -5,11 +5,15 @@ import os ...@@ -5,11 +5,15 @@ import os
import tempfile import tempfile
import depyf import depyf
import pytest
from vllm.config import CompilationLevel from vllm.config import CompilationLevel
temp_dir = tempfile.mkdtemp()
with depyf.prepare_debug(temp_dir): @pytest.mark.skip(reason="Not working; needs investigation.")
def test_tpu_compilation():
temp_dir = tempfile.mkdtemp()
with depyf.prepare_debug(temp_dir):
from vllm import LLM, SamplingParams from vllm import LLM, SamplingParams
prompts = [ prompts = [
...@@ -46,51 +50,51 @@ with depyf.prepare_debug(temp_dir): ...@@ -46,51 +50,51 @@ with depyf.prepare_debug(temp_dir):
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
assert generated_text.startswith(answer) assert generated_text.startswith(answer)
compiled_codes = sorted( compiled_codes = sorted(
glob.glob(os.path.join(temp_dir, "__transformed_code*.py"))) glob.glob(os.path.join(temp_dir, "__transformed_code*.py")))
for i, compiled_code in enumerate(compiled_codes): for i, compiled_code in enumerate(compiled_codes):
print("{} file: {}".format(i + 1, compiled_code)) print("{} file: {}".format(i + 1, compiled_code))
# We should only trigger Dynamo compilation 4 times: # We should only trigger Dynamo compilation 4 times:
# 1. forward pass (symbolic) # 1. forward pass (symbolic)
# 2. compute_logits (symbolic) # 2. compute_logits (symbolic)
# 3. forward pass (shape 16) # 3. forward pass (shape 16)
# 4. forward pass (shape 32) # 4. forward pass (shape 32)
# and later calls should not trigger Dynamo compilation again. # and later calls should not trigger Dynamo compilation again.
# NOTE: It might still trigger XLA compilation. # NOTE: It might still trigger XLA compilation.
# Check we have 4 compiled codes # Check we have 4 compiled codes
assert len(compiled_codes) == 4 assert len(compiled_codes) == 4
kv_cache_prefix = "kv_cache" kv_cache_prefix = "kv_cache"
attn_prefix = "ragged_paged_attention" attn_prefix = "ragged_paged_attention"
# Check all the compilations are as expected # Check all the compilations are as expected
compiled_fns = sorted( compiled_fns = sorted(
glob.glob(os.path.join(temp_dir, "__compiled_fn*Captured*.py"))) glob.glob(os.path.join(temp_dir, "__compiled_fn*Captured*.py")))
for i, compiled_fn in enumerate(compiled_fns): for i, compiled_fn in enumerate(compiled_fns):
print("{} file: {}".format(i + 1, compiled_fn)) print("{} file: {}".format(i + 1, compiled_fn))
# The first compilation is symbolic, so it should not have any kv_caches # The first compilation is symbolic, so it should not have any kv_caches
with open(compiled_fns[0]) as f: with open(compiled_fns[0]) as f:
content = f.read() content = f.read()
assert kv_cache_prefix not in content assert kv_cache_prefix not in content
# The second compilation is symbolic, so it should not have any kv_caches # The second compilation is symbolic, so it should not have any kv_caches
with open(compiled_fns[1]) as f: with open(compiled_fns[1]) as f:
content = f.read() content = f.read()
assert kv_cache_prefix not in content assert kv_cache_prefix not in content
# The third compilation is shape 16, so it should have kv_caches and the # The third compilation is shape 16, so it should have kv_caches and the
# ragged_paged_attention # ragged_paged_attention
with open(compiled_fns[2]) as f: with open(compiled_fns[2]) as f:
content = f.read() content = f.read()
assert (kv_cache_prefix in content and attn_prefix in content) assert (kv_cache_prefix in content and attn_prefix in content)
# The fourth compilation is shape 32, so it should have kv_caches and the # The fourth compilation is shape 32, so it should have kv_caches and the
# ragged_paged_attention # ragged_paged_attention
with open(compiled_fns[3]) as f: with open(compiled_fns[3]) as f:
content = f.read() content = f.read()
assert (kv_cache_prefix in content and attn_prefix in content) assert (kv_cache_prefix in content and attn_prefix in content)
...@@ -11,11 +11,13 @@ from tests.v1.engine.utils import (NUM_PROMPT_LOGPROBS_UNDER_TEST, ...@@ -11,11 +11,13 @@ from tests.v1.engine.utils import (NUM_PROMPT_LOGPROBS_UNDER_TEST,
STOP_STRINGS, STOP_STRINGS,
DummyOutputProcessorTestVectors, DummyOutputProcessorTestVectors,
MockEngineCore) MockEngineCore)
from vllm.outputs import CompletionOutput, RequestOutput
from vllm.sampling_params import RequestOutputKind, SamplingParams from vllm.sampling_params import RequestOutputKind, SamplingParams
from vllm.sequence import PromptLogprobs, SampleLogprobs from vllm.sequence import PromptLogprobs, SampleLogprobs
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.v1.engine import EngineCoreRequest from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.output_processor import OutputProcessor from vllm.v1.engine.output_processor import (OutputProcessor,
RequestOutputCollector)
from vllm.v1.metrics.stats import IterationStats from vllm.v1.metrics.stats import IterationStats
...@@ -834,3 +836,88 @@ def test_iteration_stats(dummy_test_vectors): ...@@ -834,3 +836,88 @@ def test_iteration_stats(dummy_test_vectors):
assert iteration_stats.num_prompt_tokens == 0 assert iteration_stats.num_prompt_tokens == 0
assert iteration_stats.num_generation_tokens == num_active assert iteration_stats.num_generation_tokens == num_active
@pytest.mark.asyncio
async def test_request_output_collector():
    """Exercise RequestOutputCollector in DELTA mode with 1, 2 and 3 puts.

    Each case puts one or more outputs and then awaits a single ``get()``.
    After every ``get()`` the collector must be empty again, and when
    several outputs were put, their text, token ids and logprobs must
    come back merged into one output.
    """
    NUM_REQS = 3
    TEXT = "a"

    def make_outputs() -> list[RequestOutput]:
        """Build NUM_REQS single-token outputs; only the last is finished."""
        built = []
        for idx in range(NUM_REQS):
            is_last = idx == NUM_REQS - 1
            completion = CompletionOutput(
                index=0,
                text=TEXT,
                token_ids=[idx],
                # Same value as the original `idx + 1 * 1.0` expression.
                cumulative_logprob=idx + 1.0,
                logprobs=[{
                    "a": idx,
                    "b": idx
                }],
                finish_reason="length" if is_last else None,
            )
            built.append(
                RequestOutput(
                    request_id="my-request-id",
                    prompt=None,
                    prompt_token_ids=[1, 2, 3],
                    prompt_logprobs=None,
                    outputs=[completion],
                    finished=is_last,
                ))
        return built

    collector = RequestOutputCollector(RequestOutputKind.DELTA)

    def assert_drained() -> None:
        """After get(), the ready flag is cleared and nothing is buffered."""
        assert not collector.ready.is_set()
        assert collector.output is None

    def assert_merged(merged: RequestOutput, count: int) -> None:
        """Text, token_ids, and logprobs should get merged."""
        completion = merged.outputs[0]
        assert completion.text == TEXT * count
        for actual, expected in zip(completion.token_ids, range(count)):
            assert actual == expected
        assert len(completion.logprobs) == count
        # Cumulative logprobs should be the last one.
        assert completion.cumulative_logprob == 1.0 * count

    # CASE 1: Put then get.
    outputs = make_outputs()
    collector.put(outputs[0])
    output = await collector.get()
    assert_drained()
    assert output.outputs[0].text == "a"
    assert output.outputs[0].token_ids == [0]

    # CASE 2: 2 puts then get.
    num_to_put = 2
    outputs = make_outputs()
    for pending in outputs[:num_to_put]:
        collector.put(pending)
    output = await collector.get()
    assert_drained()
    assert not output.finished
    assert_merged(output, num_to_put)

    # CASE 3: Put all 3 (including a finished).
    num_to_put = 3
    outputs = make_outputs()
    for pending in outputs[:num_to_put]:
        collector.put(pending)
    output = await collector.get()
    assert_drained()
    assert output.finished
    assert output.outputs[0].finish_reason == "length"
    assert_merged(output, num_to_put)
...@@ -13,7 +13,7 @@ from vllm.entrypoints.llm import LLM ...@@ -13,7 +13,7 @@ from vllm.entrypoints.llm import LLM
from vllm.outputs import RequestOutput from vllm.outputs import RequestOutput
from vllm.sampling_params import GuidedDecodingParams, SamplingParams from vllm.sampling_params import GuidedDecodingParams, SamplingParams
GUIDED_DECODING_BACKENDS_V1 = ["xgrammar"] GUIDED_DECODING_BACKENDS_V1 = ["xgrammar", "guidance"]
MODELS_TO_TEST = [ MODELS_TO_TEST = [
"Qwen/Qwen2.5-1.5B-Instruct", "mistralai/Ministral-8B-Instruct-2410" "Qwen/Qwen2.5-1.5B-Instruct", "mistralai/Ministral-8B-Instruct-2410"
] ]
...@@ -30,12 +30,13 @@ def test_guided_json_completion( ...@@ -30,12 +30,13 @@ def test_guided_json_completion(
model_name: str, model_name: str,
): ):
monkeypatch.setenv("VLLM_USE_V1", "1") monkeypatch.setenv("VLLM_USE_V1", "1")
llm = LLM(model=model_name, max_model_len=1024) llm = LLM(model=model_name,
sampling_params = SamplingParams(temperature=1.0, max_model_len=1024,
guided_decoding_backend=guided_decoding_backend)
sampling_params = SamplingParams(
temperature=1.0,
max_tokens=1000, max_tokens=1000,
guided_decoding=GuidedDecodingParams( guided_decoding=GuidedDecodingParams(json=sample_json_schema))
json=sample_json_schema,
backend=guided_decoding_backend))
outputs = llm.generate(prompts=[ outputs = llm.generate(prompts=[
f"Give an example JSON for an employee profile " f"Give an example JSON for an employee profile "
f"that fits this schema: {sample_json_schema}" f"that fits this schema: {sample_json_schema}"
...@@ -111,13 +112,14 @@ def test_guided_json_object( ...@@ -111,13 +112,14 @@ def test_guided_json_object(
model_name: str, model_name: str,
): ):
monkeypatch.setenv("VLLM_USE_V1", "1") monkeypatch.setenv("VLLM_USE_V1", "1")
llm = LLM(model=model_name, max_model_len=1024) llm = LLM(model=model_name,
sampling_params = SamplingParams(temperature=1.0, max_model_len=1024,
guided_decoding_backend=guided_decoding_backend)
sampling_params = SamplingParams(
temperature=1.0,
max_tokens=100, max_tokens=100,
n=2, n=2,
guided_decoding=GuidedDecodingParams( guided_decoding=GuidedDecodingParams(json_object=True))
json_object=True,
backend=guided_decoding_backend))
outputs = llm.generate( outputs = llm.generate(
prompts=("Generate a JSON object with curly braces for a person with " prompts=("Generate a JSON object with curly braces for a person with "
...@@ -137,12 +139,20 @@ def test_guided_json_object( ...@@ -137,12 +139,20 @@ def test_guided_json_object(
# Parse to verify it is valid JSON # Parse to verify it is valid JSON
parsed_json = json.loads(generated_text) parsed_json = json.loads(generated_text)
assert isinstance(parsed_json, dict) allowed_types: tuple[type, ...] = (dict, )
if guided_decoding_backend == "xgrammar":
# TODO - we are currently too permissive with xgrammar and
# allow any valid json (typically comes back as a list or
# object). We can fix this by specifying a jsonschema of
# {"type": "object"}, but we need this fix in a release
# first: https://github.com/mlc-ai/xgrammar/pull/264
allowed_types = (dict, list)
assert isinstance(parsed_json, allowed_types)
@pytest.mark.skip_global_cleanup @pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend", @pytest.mark.parametrize("guided_decoding_backend",
GUIDED_DECODING_BACKENDS_V1) GUIDED_DECODING_BACKENDS_V1 + ["auto"])
@pytest.mark.parametrize("model_name", MODELS_TO_TEST) @pytest.mark.parametrize("model_name", MODELS_TO_TEST)
def test_guided_json_unsupported_schema( def test_guided_json_unsupported_schema(
monkeypatch: pytest.MonkeyPatch, monkeypatch: pytest.MonkeyPatch,
...@@ -151,12 +161,14 @@ def test_guided_json_unsupported_schema( ...@@ -151,12 +161,14 @@ def test_guided_json_unsupported_schema(
model_name: str, model_name: str,
): ):
monkeypatch.setenv("VLLM_USE_V1", "1") monkeypatch.setenv("VLLM_USE_V1", "1")
llm = LLM(model=model_name, max_model_len=1024) llm = LLM(model=model_name,
sampling_params = SamplingParams(temperature=1.0, max_model_len=1024,
guided_decoding_backend=guided_decoding_backend)
sampling_params = SamplingParams(
temperature=1.0,
max_tokens=1000, max_tokens=1000,
guided_decoding=GuidedDecodingParams( guided_decoding=GuidedDecodingParams(json=unsupported_json_schema))
json=unsupported_json_schema, if guided_decoding_backend == "xgrammar":
backend=guided_decoding_backend))
with pytest.raises(ValueError, with pytest.raises(ValueError,
match="The provided JSON schema contains features " match="The provided JSON schema contains features "
"not supported by xgrammar."): "not supported by xgrammar."):
...@@ -166,6 +178,26 @@ def test_guided_json_unsupported_schema( ...@@ -166,6 +178,26 @@ def test_guided_json_unsupported_schema(
] * 2, ] * 2,
sampling_params=sampling_params, sampling_params=sampling_params,
use_tqdm=True) use_tqdm=True)
else:
# This should work for both "guidance" and "auto".
outputs = llm.generate(
prompts=("Give an example JSON object for a grade "
"that fits this schema: "
f"{unsupported_json_schema}"),
sampling_params=sampling_params,
use_tqdm=True)
assert outputs is not None
for output in outputs:
assert output is not None
assert isinstance(output, RequestOutput)
generated_text = output.outputs[0].text
assert generated_text is not None
print(generated_text)
# Parse to verify it is valid JSON
parsed_json = json.loads(generated_text)
assert isinstance(parsed_json, dict)
@pytest.mark.skip_global_cleanup @pytest.mark.skip_global_cleanup
...@@ -179,13 +211,14 @@ def test_guided_grammar_ebnf( ...@@ -179,13 +211,14 @@ def test_guided_grammar_ebnf(
model_name: str, model_name: str,
): ):
monkeypatch.setenv("VLLM_USE_V1", "1") monkeypatch.setenv("VLLM_USE_V1", "1")
llm = LLM(model=model_name, max_model_len=1024) llm = LLM(model=model_name,
sampling_params = SamplingParams(temperature=0.8, max_model_len=1024,
guided_decoding_backend=guided_decoding_backend)
sampling_params = SamplingParams(
temperature=0.8,
top_p=0.95, top_p=0.95,
max_tokens=1000, max_tokens=1000,
guided_decoding=GuidedDecodingParams( guided_decoding=GuidedDecodingParams(grammar=sample_sql_ebnf))
grammar=sample_sql_ebnf,
backend=guided_decoding_backend))
outputs = llm.generate( outputs = llm.generate(
prompts=("Generate a sql statement that selects col_1 from " prompts=("Generate a sql statement that selects col_1 from "
"table_1 where it is equal to 1"), "table_1 where it is equal to 1"),
...@@ -222,13 +255,14 @@ def test_guided_grammar_lark( ...@@ -222,13 +255,14 @@ def test_guided_grammar_lark(
model_name: str, model_name: str,
): ):
monkeypatch.setenv("VLLM_USE_V1", "1") monkeypatch.setenv("VLLM_USE_V1", "1")
llm = LLM(model=model_name, max_model_len=1024) llm = LLM(model=model_name,
sampling_params = SamplingParams(temperature=0.8, max_model_len=1024,
guided_decoding_backend=guided_decoding_backend)
sampling_params = SamplingParams(
temperature=0.8,
top_p=0.95, top_p=0.95,
max_tokens=1000, max_tokens=1000,
guided_decoding=GuidedDecodingParams( guided_decoding=GuidedDecodingParams(grammar=sample_sql_lark))
grammar=sample_sql_lark,
backend=guided_decoding_backend))
outputs = llm.generate( outputs = llm.generate(
prompts=("Generate a sql statement that selects col_1 from " prompts=("Generate a sql statement that selects col_1 from "
"table_1 where it is equal to 1"), "table_1 where it is equal to 1"),
...@@ -269,16 +303,15 @@ def test_guided_grammar_ebnf_invalid( ...@@ -269,16 +303,15 @@ def test_guided_grammar_ebnf_invalid(
model_name: str, model_name: str,
): ):
monkeypatch.setenv("VLLM_USE_V1", "1") monkeypatch.setenv("VLLM_USE_V1", "1")
llm = LLM(model=model_name, max_model_len=1024) llm = LLM(model=model_name,
sampling_params = SamplingParams(temperature=0.8, max_model_len=1024,
guided_decoding_backend=guided_decoding_backend)
sampling_params = SamplingParams(
temperature=0.8,
top_p=0.95, top_p=0.95,
max_tokens=1000, max_tokens=1000,
guided_decoding=GuidedDecodingParams( guided_decoding=GuidedDecodingParams(grammar="not a grammar"))
grammar="not a grammar", with pytest.raises(ValueError, match="Failed to convert the grammar "):
backend=guided_decoding_backend))
with pytest.raises(ValueError,
match="Failed to convert the grammar "
"from Lark to EBNF."):
llm.generate( llm.generate(
prompts=("Generate a sql statement that selects col_1 from " prompts=("Generate a sql statement that selects col_1 from "
"table_1 where it is equal to 1"), "table_1 where it is equal to 1"),
...@@ -298,12 +331,13 @@ def test_guided_regex( ...@@ -298,12 +331,13 @@ def test_guided_regex(
model_name: str, model_name: str,
): ):
monkeypatch.setenv("VLLM_USE_V1", "1") monkeypatch.setenv("VLLM_USE_V1", "1")
llm = LLM(model=model_name, max_model_len=1024) llm = LLM(model=model_name,
sampling_params = SamplingParams(temperature=0.8, max_model_len=1024,
guided_decoding_backend=guided_decoding_backend)
sampling_params = SamplingParams(
temperature=0.8,
top_p=0.95, top_p=0.95,
guided_decoding=GuidedDecodingParams( guided_decoding=GuidedDecodingParams(regex=sample_regex))
regex=sample_regex,
backend=guided_decoding_backend))
outputs = llm.generate( outputs = llm.generate(
prompts=[ prompts=[
f"Give an example IPv4 address with this regex: {sample_regex}" f"Give an example IPv4 address with this regex: {sample_regex}"
...@@ -335,12 +369,13 @@ def test_guided_choice_completion( ...@@ -335,12 +369,13 @@ def test_guided_choice_completion(
model_name: str, model_name: str,
): ):
monkeypatch.setenv("VLLM_USE_V1", "1") monkeypatch.setenv("VLLM_USE_V1", "1")
llm = LLM(model=model_name, max_model_len=1024) llm = LLM(model=model_name,
sampling_params = SamplingParams(temperature=0.8, max_model_len=1024,
guided_decoding_backend=guided_decoding_backend)
sampling_params = SamplingParams(
temperature=0.8,
top_p=0.95, top_p=0.95,
guided_decoding=GuidedDecodingParams( guided_decoding=GuidedDecodingParams(choice=sample_guided_choice))
choice=sample_guided_choice,
backend=guided_decoding_backend))
outputs = llm.generate( outputs = llm.generate(
prompts="The best language for type-safe systems programming is ", prompts="The best language for type-safe systems programming is ",
sampling_params=sampling_params, sampling_params=sampling_params,
......
...@@ -36,6 +36,8 @@ def create_logits_tensor(output_token_ids: list[list[int]], ...@@ -36,6 +36,8 @@ def create_logits_tensor(output_token_ids: list[list[int]],
def create_sampling_metadata( def create_sampling_metadata(
all_greedy: bool, all_greedy: bool,
temperature: Optional[torch.Tensor] = None, temperature: Optional[torch.Tensor] = None,
top_k: Optional[torch.Tensor] = None,
top_p: Optional[torch.Tensor] = None,
generators: Optional[dict[int, Any]] = None, generators: Optional[dict[int, Any]] = None,
) -> SamplingMetadata: ) -> SamplingMetadata:
"""Create a v1 sampling metadata object with all_greedy set """Create a v1 sampling metadata object with all_greedy set
...@@ -52,8 +54,8 @@ def create_sampling_metadata( ...@@ -52,8 +54,8 @@ def create_sampling_metadata(
temperature=temperature, temperature=temperature,
all_greedy=all_greedy, all_greedy=all_greedy,
all_random=not all_greedy, all_random=not all_greedy,
top_p=None, top_p=top_p,
top_k=None, top_k=top_k,
min_p=torch.empty(1, ), min_p=torch.empty(1, ),
generators=generators, generators=generators,
max_num_logprobs=0, max_num_logprobs=0,
...@@ -462,3 +464,147 @@ def estimate_rejection_sampling_pdf( ...@@ -462,3 +464,147 @@ def estimate_rejection_sampling_pdf(
density=True) density=True)
return hist.hist return hist.hist
def _test_masked_logits(
    rejection_sampler,
    batch_size: int,
    num_draft_tokens: int,
    vocab_size: int,
    target_logits: torch.Tensor,
    unmasked_indices: torch.Tensor,
    sampling_metadata: SamplingMetadata,
):
    """Run the rejection sampler once and check every emitted token is allowed.

    Random draft probabilities and draft tokens are generated here; the caller
    supplies the target logits and the sampling metadata.

    Args:
        rejection_sampler: Callable invoked with the spec-decode metadata,
            draft probabilities, target logits, bonus tokens and sampling
            metadata.
        batch_size: Number of requests.
        num_draft_tokens: Number of draft tokens per request.
        vocab_size: Width of the probability / logit rows.
        target_logits: Logits handed to the sampler unchanged.
        unmasked_indices: Indexable by flat token position; entry ``i`` holds
            the token ids that are acceptable at position ``i``.
        sampling_metadata: Sampling metadata handed to the sampler unchanged.

    Raises:
        AssertionError: If a non-placeholder output token is not contained in
            the corresponding entry of ``unmasked_indices``.
    """
    num_tokens = batch_size * num_draft_tokens

    # Draft distribution: softmax over uniformly random scores.
    random_scores = torch.rand((num_tokens, vocab_size),
                               dtype=torch.float32,
                               device=DEVICE)
    draft_probs = F.softmax(random_scores, dim=-1)

    # One draft token per row, regrouped as a per-request nested list.
    drafted = torch.multinomial(draft_probs, num_samples=1)
    drafted = drafted.reshape(batch_size, num_draft_tokens).tolist()

    spec_decode_metadata = SpecDecodeMetadata.make_dummy(
        drafted,
        device=DEVICE,
    )

    # The sampler requires bonus tokens even though this check ignores them.
    bonus_token_ids = torch.zeros((batch_size, 1),
                                  dtype=torch.int64,
                                  device=DEVICE)

    sampler_output = rejection_sampler(
        spec_decode_metadata,
        draft_probs=draft_probs,
        target_logits=target_logits,
        bonus_token_ids=bonus_token_ids,
        sampling_metadata=sampling_metadata,
    )

    # Drop the trailing bonus column and flatten to one entry per draft token.
    sampled = sampler_output[:, :-1].flatten().tolist()

    for i in range(num_tokens):
        token_id = sampled[i]
        # Placeholder entries carry no sampled token, so they are not checked.
        if token_id != PLACEHOLDER_TOKEN_ID:
            assert token_id in unmasked_indices[i]
@pytest.mark.parametrize("top_k", [1, 5, 99])
def test_top_k(rejection_sampler, top_k):
    """Check that rejection sampling with top-k only emits top-k tokens.

    Each token row gets ``top_k`` randomly chosen favoured vocabulary ids whose
    logits are raised slightly above the rest; the shared checker then asserts
    that every sampled token is one of the favoured ids for its row.
    """
    vocab_size = 100
    batch_size = 100
    num_draft_tokens = 3
    num_tokens = batch_size * num_draft_tokens

    # One random selection of `top_k` distinct vocabulary ids per token row.
    top_k_indices = torch.stack([
        torch.randperm(vocab_size, device=DEVICE)[:top_k]
        for _ in range(num_tokens)
    ])

    # Start from all-equal logits and nudge only the favoured ids upwards.
    # The margin is deliberately small: if the top-k mask works, the other
    # ids are never sampled even though their logits are almost as large.
    target_logits = torch.zeros((num_tokens, vocab_size), device=DEVICE)
    for row, favoured in enumerate(top_k_indices):
        target_logits[row, favoured] += 0.1

    temperature = torch.ones(batch_size, dtype=torch.float32, device=DEVICE)
    top_k_per_request = torch.tensor([top_k] * batch_size,
                                     device=DEVICE,
                                     dtype=torch.int64)
    sampling_metadata = create_sampling_metadata(
        all_greedy=False,
        temperature=temperature,
        top_k=top_k_per_request,
    )

    _test_masked_logits(
        rejection_sampler,
        batch_size=batch_size,
        num_draft_tokens=num_draft_tokens,
        vocab_size=vocab_size,
        target_logits=target_logits,
        unmasked_indices=top_k_indices,
        sampling_metadata=sampling_metadata,
    )
@pytest.mark.parametrize("top_p", [0.5, 0.9, 0.99])
def test_top_p(rejection_sampler, top_p):
    """Check that rejection sampling with top-p only emits in-nucleus tokens.

    The expected nucleus for every token row is computed here from the target
    logits (ascending sort, softmax, cumulative sum); the shared checker then
    asserts that every sampled token belongs to the nucleus of its row.
    """
    vocab_size = 100
    batch_size = 100
    num_draft_tokens = 3
    num_tokens = batch_size * num_draft_tokens

    # Random (standard-normal) target logits, one row per draft token.
    target_logits = torch.randn((num_tokens, vocab_size), device=DEVICE)
    temperature = torch.ones(batch_size, dtype=torch.float32, device=DEVICE)

    # `temperature` holds one value per request while the logits hold one row
    # per draft token. Dividing by the raw (batch_size,) tensor would
    # broadcast along the vocab axis and only runs when
    # vocab_size == batch_size. Expand it to a (num_tokens, 1) column instead;
    # rows are grouped per request, matching the
    # reshape(batch_size, num_draft_tokens) done in _test_masked_logits.
    per_token_temperature = temperature.repeat_interleave(
        num_draft_tokens).unsqueeze(-1)
    rescaled_logits = target_logits / per_token_temperature

    # Ascending sort: the low-probability tail comes first, so the cumulative
    # sum measures how much mass would be discarded up to each position.
    logits_sort, logits_idx = rescaled_logits.sort(dim=-1, descending=False)
    probs_sort = logits_sort.softmax(dim=-1)
    probs_sum = probs_sort.cumsum(dim=-1)
    top_p_mask = probs_sum <= 1 - top_p
    # Always keep at least the most likely token.
    top_p_mask[:, -1] = False

    # Vocabulary ids that survive the mask, per token row.
    top_p_indices = [
        logits_idx[i][~top_p_mask[i]].tolist() for i in range(num_tokens)
    ]

    sampling_metadata = create_sampling_metadata(
        all_greedy=False,
        temperature=temperature,
        top_p=torch.tensor([top_p] * batch_size,
                           device=DEVICE,
                           dtype=torch.float32),
    )

    _test_masked_logits(
        rejection_sampler,
        batch_size=batch_size,
        num_draft_tokens=num_draft_tokens,
        vocab_size=vocab_size,
        target_logits=target_logits,
        unmasked_indices=top_p_indices,
        sampling_metadata=sampling_metadata,
    )
...@@ -22,12 +22,13 @@ from vllm.attention.backends.utils import ( ...@@ -22,12 +22,13 @@ from vllm.attention.backends.utils import (
compute_slot_mapping_start_idx, get_num_prefill_decode_query_kv_tokens, compute_slot_mapping_start_idx, get_num_prefill_decode_query_kv_tokens,
get_seq_len_block_table_args, is_all_cross_attn_metadata_set, get_seq_len_block_table_args, is_all_cross_attn_metadata_set,
is_all_encoder_attn_metadata_set, is_block_tables_empty) is_all_encoder_attn_metadata_set, is_block_tables_empty)
from vllm.fa_utils import get_flash_attn_version
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.multimodal import MultiModalPlaceholderMap from vllm.multimodal import MultiModalPlaceholderMap
from vllm.utils import async_tensor_h2d, make_tensor_with_pad from vllm.utils import async_tensor_h2d, make_tensor_with_pad
from vllm.vllm_flash_attn import (flash_attn_varlen_func, from vllm.vllm_flash_attn import (flash_attn_varlen_func,
flash_attn_with_kvcache) flash_attn_with_kvcache)
from vllm.vllm_flash_attn.fa_utils import (flash_attn_supports_fp8,
get_flash_attn_version)
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.worker.model_runner import (ModelInputForGPUBuilder, from vllm.worker.model_runner import (ModelInputForGPUBuilder,
...@@ -632,10 +633,13 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -632,10 +633,13 @@ class FlashAttentionImpl(AttentionImpl):
self.kv_cache_dtype = kv_cache_dtype self.kv_cache_dtype = kv_cache_dtype
self.vllm_flash_attn_version = get_flash_attn_version( self.vllm_flash_attn_version = get_flash_attn_version(
requires_alibi=self.alibi_slopes is not None) requires_alibi=self.alibi_slopes is not None)
if (is_quantized_kv_cache(self.kv_cache_dtype) if is_quantized_kv_cache(self.kv_cache_dtype) and (
and self.vllm_flash_attn_version != 3): not self.kv_cache_dtype.startswith("fp8")
or not flash_attn_supports_fp8()):
raise NotImplementedError( raise NotImplementedError(
"Only FlashAttention3 supports FP8 KV cache") f"FlashAttention does not support {self.kv_cache_dtype} "
"kv-cache on this device "
f"(FA supports fp8 = {flash_attn_supports_fp8()}).")
if logits_soft_cap is None: if logits_soft_cap is None:
# In flash-attn, setting logits_soft_cap as 0 means no soft cap. # In flash-attn, setting logits_soft_cap as 0 means no soft cap.
logits_soft_cap = 0 logits_soft_cap = 0
...@@ -704,6 +708,10 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -704,6 +708,10 @@ class FlashAttentionImpl(AttentionImpl):
logits_soft_cap: Optional[float] = self.logits_soft_cap logits_soft_cap: Optional[float] = self.logits_soft_cap
fp8_attention = kv_cache_dtype.startswith("fp8") fp8_attention = kv_cache_dtype.startswith("fp8")
if fp8_attention and not flash_attn_supports_fp8():
raise NotImplementedError(
"FlashAttention does not support FP8 kv-cache on this device.")
if kv_cache.numel() > 0: if kv_cache.numel() > 0:
key_cache = kv_cache[0] key_cache = kv_cache[0]
value_cache = kv_cache[1] value_cache = kv_cache[1]
......
...@@ -206,7 +206,6 @@ from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping, ...@@ -206,7 +206,6 @@ from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping,
compute_slot_mapping_start_idx, compute_slot_mapping_start_idx,
is_block_tables_empty) is_block_tables_empty)
from vllm.attention.ops.triton_merge_attn_states import merge_attn_states from vllm.attention.ops.triton_merge_attn_states import merge_attn_states
from vllm.fa_utils import get_flash_attn_version
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearBase, RowParallelLinear, LinearBase, RowParallelLinear,
UnquantizedLinearMethod) UnquantizedLinearMethod)
...@@ -215,6 +214,7 @@ from vllm.model_executor.layers.rotary_embedding import ( ...@@ -215,6 +214,7 @@ from vllm.model_executor.layers.rotary_embedding import (
from vllm.multimodal import MultiModalPlaceholderMap from vllm.multimodal import MultiModalPlaceholderMap
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import async_tensor_h2d, cdiv, make_tensor_with_pad, round_down from vllm.utils import async_tensor_h2d, cdiv, make_tensor_with_pad, round_down
from vllm.vllm_flash_attn.fa_utils import get_flash_attn_version
try: try:
from vllm.vllm_flash_attn import flash_attn_varlen_func from vllm.vllm_flash_attn import flash_attn_varlen_func
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import hashlib import hashlib
import importlib.metadata
import inspect import inspect
import json
import types import types
from abc import ABC, abstractmethod from typing import Any, Callable, Dict, Optional, Union
from typing import Any, Callable, Optional, Union
import torch import torch
from packaging.version import Version
from torch import fx from torch import fx
if Version(importlib.metadata.version('torch')) >= Version("2.6"):
from torch._inductor.custom_graph_pass import CustomGraphPass
else:
# CustomGraphPass is not present in 2.5 or lower, import our version
from .torch25_custom_graph_pass import ( # noqa: yapf
Torch25CustomGraphPass as CustomGraphPass)
class InductorPass(ABC):
"""
General custom inductor pass interface.
"""
@abstractmethod class InductorPass(CustomGraphPass):
def __call__(self, graph: torch.fx.Graph):
""" """
Execute the pass on the given graph. A custom graph pass that uses a hash of its source as the UUID.
This is defined as a convenience and should work in most cases.
""" """
raise NotImplementedError
def uuid(self) -> Any: def uuid(self) -> Any:
""" """
...@@ -48,7 +51,16 @@ class InductorPass(ABC): ...@@ -48,7 +51,16 @@ class InductorPass(ABC):
else: else:
src_str = inspect.getsource(src.__class__) src_str = inspect.getsource(src.__class__)
hasher.update(src_str.encode("utf-8")) hasher.update(src_str.encode("utf-8"))
return hasher.digest() return hasher.hexdigest()
@staticmethod
def hash_dict(dict_: Dict[Any, Any]):
"""
Utility method to hash a dictionary, can alternatively be used for uuid.
:return: A sha256 hash of the json rep of the dictionary.
"""
encoded = json.dumps(dict_, sort_keys=True).encode("utf-8")
return hashlib.sha256(encoded).hexdigest()
class CallableInductorPass(InductorPass): class CallableInductorPass(InductorPass):
...@@ -61,25 +73,10 @@ class CallableInductorPass(InductorPass): ...@@ -61,25 +73,10 @@ class CallableInductorPass(InductorPass):
callable: Callable[[fx.Graph], None], callable: Callable[[fx.Graph], None],
uuid: Optional[Any] = None): uuid: Optional[Any] = None):
self.callable = callable self.callable = callable
if uuid is None: self._uuid = self.hash_source(callable) if uuid is None else uuid
uuid = InductorPass.hash_source(callable)
self._uuid = uuid
def __call__(self, graph: torch.fx.Graph): def __call__(self, graph: torch.fx.Graph):
self.callable(graph) self.callable(graph)
def uuid(self) -> Any: def uuid(self) -> Any:
return self._uuid return self._uuid
def __getstate__(self):
"""
Pickling occurs in the Inductor code cache if a pass is not given to
the pass manager but is instead directly added to config as a pass.
See PostGradPassManager for more.
TODO(torch==2.6), use the `uuid` method in CustomGraphPass instead.
"""
return self._uuid
def __setstate__(self, state):
raise ValueError("Cannot unpickle CallableInductorPass")
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from typing import Any, Dict, List from typing import List
import torch
from torch import fx as fx from torch import fx as fx
from vllm.config import CompilationConfig from vllm.config import CompilationConfig
...@@ -10,29 +9,18 @@ from vllm.logger import init_logger ...@@ -10,29 +9,18 @@ from vllm.logger import init_logger
from .fix_functionalization import FixFunctionalizationPass from .fix_functionalization import FixFunctionalizationPass
from .fusion import FusionPass from .fusion import FusionPass
from .inductor_pass import InductorPass from .inductor_pass import CustomGraphPass, InductorPass
from .noop_elimination import NoOpEliminationPass from .noop_elimination import NoOpEliminationPass
logger = init_logger(__name__) logger = init_logger(__name__)
class PlaceHolder: class PostGradPassManager(CustomGraphPass):
pass
if torch.__version__ < "2.6":
Parent = PlaceHolder # type: ignore
else:
Parent = torch._inductor.custom_graph_pass.CustomGraphPass # type: ignore
class PostGradPassManager(Parent):
""" """
The pass manager for post-grad passes. The pass manager for post-grad passes.
It handles configuration, adding custom passes, and running passes. It handles configuration, adding custom passes, and running passes.
It also supports pickling, which is used by the Inductor code cache. It supports uuid for the Inductor code cache. That includes torch<2.6
TODO(torch==2.6), use CustomGraphPass support using pickling (in .inductor_pass.CustomGraphPass).
(torch._inductor.custom_graph_pass.CustomGraphPass)
The order of the post-grad post-passes is: The order of the post-grad post-passes is:
1. passes (constructor parameter) 1. passes (constructor parameter)
...@@ -67,27 +55,13 @@ class PostGradPassManager(Parent): ...@@ -67,27 +55,13 @@ class PostGradPassManager(Parent):
self.passes.append(pass_) self.passes.append(pass_)
def uuid(self): def uuid(self):
return self.__getstate__()
def __getstate__(self) -> Dict[str, List[Any]]:
""" """
Custom pickling for the pass manager, as some passes cannot be pickled. The PostGradPassManager is set as a custom pass in the Inductor and
Pickling occurs because the pass manager is set as the value of affects compilation caching. Its uuid depends on the UUIDs of all
`config["post_grad_custom_post_pass"]` in the Inductor config. dependent passes and the pass config. See InductorPass for more info.
The config is pickled to act as a key in the Inductor code cache.
Any other passes in the config are pickled as well.
TODO(torch==2.6), use the `uuid` method in CustomGraphPass instead.
""" """
state = {"pass_config": self.pass_config.uuid(), "passes": []} state = {"pass_config": self.pass_config.uuid(), "passes": []}
for pass_ in self.passes: for pass_ in self.passes:
state["passes"].append(pass_.uuid()) state["passes"].append(pass_.uuid())
state["passes"].append(self.fix_functionalization.uuid()) state["passes"].append(self.fix_functionalization.uuid())
return state return InductorPass.hash_dict(state)
def __setstate__(self, state):
"""
Do not allow unpickling of the pass manager.
If this is needed in the future, it should properly pickle the passes.
"""
raise ValueError("Cannot unpickle PostGradPassManager")
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment