test_token_classification.py 5.12 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
import random

import numpy as np
6
7
8
9
10
import pytest
import torch
from transformers import AutoModelForTokenClassification

from tests.models.utils import softmax
11
from vllm.platforms import current_platform
12
13


14
15
16
17
18
19
20
21
22
23
24
25
26
27
@pytest.fixture(autouse=True)
def seed_everything():
    """Seed all random number generators for reproducibility.

    Seeds Python's ``random``, NumPy, and PyTorch (CPU, plus every CUDA
    device when CUDA is available) with a fixed seed of 0, and forces
    cuDNN into deterministic, non-benchmarking mode while the test runs.

    The cuDNN flags are process-global, so their previous values are
    restored on teardown to avoid leaking this configuration into tests
    that run later in the same session.
    """
    seed = 0
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    # Remember the global cuDNN settings so they can be put back afterwards.
    cudnn = torch.backends.cudnn
    saved_flags = (cudnn.deterministic, cudnn.benchmark)
    cudnn.deterministic = True
    cudnn.benchmark = False
    try:
        yield
    finally:
        cudnn.deterministic, cudnn.benchmark = saved_flags


28
29
30
31
@pytest.mark.parametrize("model", ["boltuix/NeuroBERT-NER"])
# The float32 is required for this tiny model to pass the test.
@pytest.mark.parametrize("dtype", ["float"])
@torch.inference_mode
def test_bert_models(
    hf_runner,
    vllm_runner,
    example_prompts,
    model: str,
    dtype: str,
) -> None:
    """Check that vLLM token classification matches HF Transformers.

    Runs ``token_classify`` on ``example_prompts`` with vLLM, runs the same
    prompts one at a time through ``AutoModelForTokenClassification`` via
    the HF runner, applies ``softmax`` to the HF logits, and asserts the
    two sets of outputs agree within tolerance.

    Args:
        hf_runner: Fixture that builds the HF reference model context.
        vllm_runner: Fixture that builds the vLLM model context.
        example_prompts: Prompts to classify.
        model: Model identifier under test.
        dtype: Dtype string passed to both runners.
    """
    with vllm_runner(model, max_model_len=None, dtype=dtype) as vllm_model:
        vllm_outputs = vllm_model.token_classify(example_prompts)

    # Use eager attention on ROCm to avoid HF Transformers flash attention
    # accuracy issues: https://github.com/vllm-project/vllm/issues/30167
    hf_model_kwargs = {}
    if current_platform.is_rocm():
        hf_model_kwargs["attn_implementation"] = "eager"

    with hf_runner(
        model,
        dtype=dtype,
        auto_cls=AutoModelForTokenClassification,
        model_kwargs=hf_model_kwargs,
    ) as hf_model:
        tokenizer = hf_model.tokenizer
        hf_outputs = []
        for prompt in example_prompts:
            inputs = tokenizer([prompt], return_tensors="pt")
            inputs = hf_model.wrap_device(inputs)
            output = hf_model.model(**inputs)
            hf_outputs.append(softmax(output.logits[0]))

    # zip() stops at the shorter sequence, so a missing output on either
    # side would otherwise go unnoticed.
    assert len(hf_outputs) == len(vllm_outputs)

    # Compare the softmaxed HF outputs with the vLLM outputs per prompt.
    for hf_output, vllm_output in zip(hf_outputs, vllm_outputs):
        hf_output = hf_output.detach().clone().cpu().float()
        vllm_output = vllm_output.detach().clone().cpu().float()
        torch.testing.assert_close(hf_output, vllm_output, atol=3.2e-2, rtol=1e-3)
67
68
69
70


@pytest.mark.parametrize("model", ["disham993/electrical-ner-ModernBERT-base"])
@pytest.mark.parametrize("dtype", ["float"])
@pytest.mark.flaky(reruns=3)
@torch.inference_mode
def test_modernbert_models(
    hf_runner,
    vllm_runner,
    example_prompts,
    model: str,
    dtype: str,
) -> None:
    """Check that vLLM ModernBERT token classification matches HF.

    Runs ``token_classify`` on ``example_prompts`` with vLLM, runs the same
    prompts one at a time through ``AutoModelForTokenClassification`` via
    the HF runner, applies ``softmax`` to the HF logits, and asserts the
    two sets of outputs agree within tolerance. The test is marked flaky
    with up to three reruns.

    Args:
        hf_runner: Fixture that builds the HF reference model context.
        vllm_runner: Fixture that builds the vLLM model context.
        example_prompts: Prompts to classify.
        model: Model identifier under test.
        dtype: Dtype string passed to both runners.
    """
    # NOTE: https://github.com/vllm-project/vllm/pull/32403
    # `disham993/electrical-ner-ModernBERT-base` is a randomly initialized
    # model, which can cause numerical precision variance and edge cases.
    # We use @flaky(reruns=3) to mitigate intermittent failures.
    print(
        f"\n[NOTE] Testing {model} (randomly initialized weights) - "
        "flaky tolerance enabled due to numerical precision variance."
    )

    with vllm_runner(model, max_model_len=None, dtype=dtype) as vllm_model:
        vllm_outputs = vllm_model.token_classify(example_prompts)

    # Use eager attention on ROCm to avoid HF Transformers flash attention
    # accuracy issues: https://github.com/vllm-project/vllm/issues/30167
    hf_model_kwargs = {}
    if current_platform.is_rocm():
        hf_model_kwargs["attn_implementation"] = "eager"

    with hf_runner(
        model,
        dtype=dtype,
        auto_cls=AutoModelForTokenClassification,
        model_kwargs=hf_model_kwargs,
    ) as hf_model:
        tokenizer = hf_model.tokenizer
        hf_outputs = []
        for prompt in example_prompts:
            inputs = tokenizer([prompt], return_tensors="pt")
            inputs = hf_model.wrap_device(inputs)
            output = hf_model.model(**inputs)
            hf_outputs.append(softmax(output.logits[0]))

    # zip() stops at the shorter sequence, so a missing output on either
    # side would otherwise go unnoticed.
    assert len(hf_outputs) == len(vllm_outputs)

    # Compare the softmaxed HF outputs with the vLLM outputs per prompt.
    for hf_output, vllm_output in zip(hf_outputs, vllm_outputs):
        hf_output = hf_output.detach().clone().cpu().float()
        vllm_output = vllm_output.detach().clone().cpu().float()
        torch.testing.assert_close(hf_output, vllm_output, atol=3.2e-2, rtol=1e-3)
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144


@pytest.mark.parametrize("model", ["bd2lcco/Qwen3-0.6B-finetuned"])
@pytest.mark.parametrize("dtype", ["float"])
@torch.inference_mode
def test_auto_conversion(
    hf_runner,
    vllm_runner,
    example_prompts,
    model: str,
    dtype: str,
) -> None:
    """Check vLLM token classification against HF for this checkpoint.

    Runs ``token_classify`` on ``example_prompts`` with vLLM (with
    ``max_model_len=1024``), runs the same prompts one at a time through
    ``AutoModelForTokenClassification`` via the HF runner, applies
    ``softmax`` to the HF logits, and asserts the two sets of outputs
    agree within an absolute tolerance of 1e-2.

    Args:
        hf_runner: Fixture that builds the HF reference model context.
        vllm_runner: Fixture that builds the vLLM model context.
        example_prompts: Prompts to classify.
        model: Model identifier under test.
        dtype: Dtype string passed to both runners.
    """
    with vllm_runner(model, max_model_len=1024, dtype=dtype) as vllm_model:
        vllm_outputs = vllm_model.token_classify(example_prompts)

    with hf_runner(
        model, dtype=dtype, auto_cls=AutoModelForTokenClassification
    ) as hf_model:
        tokenizer = hf_model.tokenizer
        hf_outputs = []
        for prompt in example_prompts:
            inputs = tokenizer([prompt], return_tensors="pt")
            inputs = hf_model.wrap_device(inputs)
            output = hf_model.model(**inputs)
            hf_outputs.append(softmax(output.logits[0]))

    # zip() stops at the shorter sequence, so a missing output on either
    # side would otherwise go unnoticed.
    assert len(hf_outputs) == len(vllm_outputs)

    # Compare the softmaxed HF outputs with the vLLM outputs per prompt.
    for hf_output, vllm_output in zip(hf_outputs, vllm_outputs):
        hf_output = hf_output.detach().clone().cpu().float()
        vllm_output = vllm_output.detach().clone().cpu().float()
        # rtol=1e-5 is torch.allclose's default, so the numeric tolerance is
        # the same as the previous `torch.allclose(..., atol=1e-2)` check;
        # assert_close additionally rejects shape mismatches and reports
        # which elements differ on failure.
        torch.testing.assert_close(hf_output, vllm_output, atol=1e-2, rtol=1e-5)