"vscode:/vscode.git/clone" did not exist on "c757a15f0f8ed54b7f85b849d5e075226fabbcd9"
utils.py 3.95 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import os
4
5
6
7
8
9
import random

import pytest
import torch

from vllm.platforms import current_platform
10
11
12
13
from vllm.transformers_utils.config import get_config
from vllm.transformers_utils.model_arch_config_convertor import (
    ModelArchConfigConvertorBase,
)
14
from vllm.v1.attention.backends.fa_utils import flash_attn_supports_mla
15
16

# Skip marker for tests that need a CUDA device with compute capability >= 8.0.
# Supports testing on Ampere and Ada Lovelace devices.
# Note: For devices with SM < 90, batch invariance does not support CUDA Graphs.
skip_unsupported = pytest.mark.skipif(
    not (current_platform.is_cuda() and current_platform.has_device_capability(80)),
    reason="Requires CUDA and >= Ampere (SM80)",
)

23
24
25
# Model used when the VLLM_TEST_MODEL environment variable is not set.
DEFAULT_MODEL = "Qwen/Qwen3-1.7B"
# Model under test; VLLM_TEST_MODEL overrides the default when present.
TEST_MODEL = os.environ.get("VLLM_TEST_MODEL", DEFAULT_MODEL)

26
27
# Attention backends exercised by default. This name may be rebound further
# down in the module when the requested test model is an MLA model.
BACKENDS: list[str] = [
    "FLASH_ATTN",
    "TRITON_ATTN",
]

31
32
33
34
# FlashInfer temporarily disabled due to invariant CTA sizes.
# See FlashInfer issue #2424
# if has_flashinfer():
#     BACKENDS.append("FLASHINFER")
35

36
37
38
39
40
41
42
# MLA backends are only exercised when a test model was explicitly requested
# through VLLM_TEST_MODEL and that model reports itself as an MLA model.
if os.getenv("VLLM_TEST_MODEL"):
    config = get_config(TEST_MODEL, trust_remote_code=False)
    if ModelArchConfigConvertorBase(
        config, config.get_text_config()
    ).is_deepseek_mla():
        # TRITON_MLA is always used; FLASH_ATTN_MLA only when FlashAttention
        # reports MLA support.
        BACKENDS = (
            ["TRITON_MLA", "FLASH_ATTN_MLA"]
            if flash_attn_supports_mla()
            else ["TRITON_MLA"]
        )
43

44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82

def _random_prompt(min_words: int = 1024, max_words: int = 1024 * 2) -> str:
    """Return a randomly chosen English prompt opener, optionally padded.

    A template is picked at random, then a target size is drawn uniformly
    from ``[min_words, max_words]`` (``max_words`` is raised to ``min_words``
    when it is smaller). If the target exceeds 50, a fixed filler sentence is
    repeated ``target_words // 50`` times and prepended to the template.

    Note that the filler is added once per 50 target words, so the number of
    words in the result is not ``target_words``; see the TODO below.

    Args:
        min_words: Lower bound (inclusive) for the drawn target size.
        max_words: Upper bound (inclusive) for the drawn target size.

    Returns:
        The template, preceded by the repeated filler for targets above 50.
    """
    # Generate more realistic prompts that will actually produce varied tokens
    # Use a mix of common English text patterns
    prompt_templates = [
        # Question-answer style
        "Question: What is the capital of France?\nAnswer: The capital of France is",
        "Q: How does photosynthesis work?\nA: Photosynthesis is the process by which",
        "User: Can you explain quantum mechanics?\nAssistant: Quantum mechanics is",
        # Story/narrative style
        "Once upon a time in a distant galaxy, there lived",
        "The old man walked slowly down the street, remembering",
        "In the year 2157, humanity finally discovered",
        # Technical/code style
        "To implement a binary search tree in Python, first we need to",
        "The algorithm works by iterating through the array and",
        "Here's how to optimize database queries using indexing:",
        # Factual/informative style
        "The Renaissance was a period in European history that",
        "Climate change is caused by several factors including",
        "The human brain contains approximately 86 billion neurons which",
        # Conversational style
        "I've been thinking about getting a new laptop because",
        "Yesterday I went to the store and bought",
        "My favorite thing about summer is definitely",
    ]

    # Draw order (template first, then size) is kept stable so that seeded
    # runs stay reproducible.
    base_prompt = random.choice(prompt_templates)

    # An inverted range collapses to exactly ``min_words``.
    target_words = random.randint(min_words, max(min_words, max_words))

    if target_words > 50:
        # For longer prompts, repeat context
        filler = " This is an interesting topic that deserves more explanation. "
        # TODO: Update to * (target_words // 10) to better align with word ratio
        base_prompt = filler * (target_words // 50) + base_prompt

    return base_prompt


def _extract_step_logprobs(request_output):
    """Collect the logprob of each generated token of the first output.

    Returns a ``(tensor, token_ids)`` pair, where ``tensor`` is a float32
    tensor holding, for every position ``i``, the ``logprob`` attribute of
    ``logprobs[i][token_ids[i]]``. Returns ``(None, None)`` when
    ``request_output`` has no (or empty) ``outputs``, or when the first
    output carries no logprobs.
    """
    completions = getattr(request_output, "outputs", None)
    if not completions:
        return None, None

    first = completions[0]
    if getattr(first, "logprobs", None) is None:
        return None, None

    # Look up each sampled token id in the logprob mapping for its own step.
    per_step = []
    for step, token_id in enumerate(first.token_ids):
        per_step.append(first.logprobs[step][token_id].logprob)

    return torch.tensor(per_step, dtype=torch.float32), first.token_ids
105
106
107
108


def is_device_capability_below_90() -> bool:
    """Return True when the current platform lacks device capability 90."""
    meets_capability_90 = current_platform.has_device_capability(90)
    return not meets_capability_90