utils.py 3.6 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import os
4
5
6
7
8
import random

import pytest
import torch

9
from vllm.attention.utils.fa_utils import flash_attn_supports_mla
10
from vllm.platforms import current_platform
11
from vllm.utils.flashinfer import has_flashinfer
12
13

# Pytest marker that skips a test unless the current platform is CUDA and
# reports device capability 80 or higher.
skip_unsupported = pytest.mark.skipif(
    not (current_platform.is_cuda() and current_platform.has_device_capability(80)),
    # Supports testing on Ampere and Ada Lovelace devices.
    # Note: For devices with SM < 90, batch invariance does not support CUDA Graphs.
    reason="Requires CUDA and >= Ampere (SM80)",
)

20
21
# Attention backend names that are always included; further backends are
# appended to this list conditionally later in the module.
BACKENDS: list[str] = [
    "FLASH_ATTN",
    "TRITON_MLA",
]

25
26
27
# Only add the FlashInfer backend when has_flashinfer() returns a truthy value.
if has_flashinfer():
    BACKENDS.append("FLASHINFER")

28
# Only add the FlashAttention MLA backend when flash_attn_supports_mla()
# returns a truthy value.
if flash_attn_supports_mla():
    BACKENDS.append("FLASH_ATTN_MLA")

# Fallback model used by resolve_model_name() when the VLLM_TEST_MODEL
# environment variable is not set.
DEFAULT_MODEL = "Qwen/Qwen3-1.7B"
# Model that resolve_model_name() substitutes for backends whose name ends in
# "MLA" when the selected model is still DEFAULT_MODEL.
MLA_MODEL = "deepseek-ai/DeepSeek-V2-Lite-Chat"


def resolve_model_name(backend: str) -> str:
    """Choose the model name to use for ``backend``.

    The ``VLLM_TEST_MODEL`` environment variable takes precedence; when it is
    unset, ``DEFAULT_MODEL`` is used. If the backend name ends in ``"MLA"``
    and the selected model equals ``DEFAULT_MODEL``, ``MLA_MODEL`` is
    returned instead.

    Args:
        backend: Attention backend name, e.g. an entry of ``BACKENDS``.

    Returns:
        The model name to load for this backend.
    """
    requested = os.getenv("VLLM_TEST_MODEL", DEFAULT_MODEL)
    # Swap in the MLA model only when the caller did not pick a different one.
    swap_to_mla = backend.endswith("MLA") and requested == DEFAULT_MODEL
    return MLA_MODEL if swap_to_mla else requested

42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101

def _random_prompt(min_words: int = 1024, max_words: int = 1024 * 2) -> str:
    # Generate more realistic prompts that will actually produce varied tokens
    # Use a mix of common English text patterns

    prompt_templates = [
        # Question-answer style
        "Question: What is the capital of France?\nAnswer: The capital of France is",
        "Q: How does photosynthesis work?\nA: Photosynthesis is the process by which",
        "User: Can you explain quantum mechanics?\nAssistant: Quantum mechanics is",
        # Story/narrative style
        "Once upon a time in a distant galaxy, there lived",
        "The old man walked slowly down the street, remembering",
        "In the year 2157, humanity finally discovered",
        # Technical/code style
        "To implement a binary search tree in Python, first we need to",
        "The algorithm works by iterating through the array and",
        "Here's how to optimize database queries using indexing:",
        # Factual/informative style
        "The Renaissance was a period in European history that",
        "Climate change is caused by several factors including",
        "The human brain contains approximately 86 billion neurons which",
        # Conversational style
        "I've been thinking about getting a new laptop because",
        "Yesterday I went to the store and bought",
        "My favorite thing about summer is definitely",
    ]

    # Pick a random template
    base_prompt = random.choice(prompt_templates)

    if max_words < min_words:
        max_words = min_words
    target_words = random.randint(min_words, max_words)

    if target_words > 50:
        # For longer prompts, repeat context
        padding_text = (
            " This is an interesting topic that deserves more explanation. "
            * (target_words // 50)
        )
        base_prompt = base_prompt + padding_text

    return base_prompt


def _extract_step_logprobs(request_output):
    if getattr(request_output, "outputs", None):
        inner = request_output.outputs[0]
        if hasattr(inner, "logprobs") and inner.logprobs is not None:
            t = torch.tensor(
                [
                    inner.logprobs[i][tid].logprob
                    for i, tid in enumerate(inner.token_ids)
                ],
                dtype=torch.float32,
            )
            return t, inner.token_ids

    return None, None
102
103
104
105


def is_device_capability_below_90() -> bool:
    """Return True when the current platform does not report capability 90."""
    meets_capability_90 = current_platform.has_device_capability(90)
    return not meets_capability_90