test_eagle3.py 1.51 KB
Newer Older
1
2
3
4
5
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import torch

6
7
from vllm.model_executor.models.interfaces import supports_eagle3

8
9
10

@pytest.mark.parametrize(
    "model_path",
    ["nm-testing/SpeculatorLlama3-1-8B-Eagle3-converted-0717-quantized"],
)
def test_llama(vllm_runner, example_prompts, model_path, monkeypatch):
    """Smoke-test the Llama Eagle3 speculator checkpoint.

    Loads ``model_path`` through the ``vllm_runner`` fixture in bfloat16,
    asserts that applying ``supports_eagle3`` to the loaded model yields a
    truthy result, then greedily generates up to 20 tokens for
    ``example_prompts`` and asserts that the outputs are truthy.
    """
    # Set environment variable for V1 engine serialization.
    # NOTE(review): presumably required so that `apply_model` can ship the
    # `supports_eagle3` callable to the engine -- confirm.
    monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")

    with vllm_runner(model_path, dtype=torch.bfloat16) as vllm_model:
        # NOTE(review): if `apply_model` returns a per-worker list, this
        # assertion only checks that the list is non-empty -- confirm.
        eagle3_supported = vllm_model.apply_model(supports_eagle3)
        assert eagle3_supported

        vllm_outputs = vllm_model.generate_greedy(
            example_prompts,
            max_tokens=20,
        )
        # Printed so the generations show up in the captured test output.
        print(vllm_outputs)
        assert vllm_outputs
24
25
26
27
28


@pytest.mark.parametrize(
    "model_path",
    ["nm-testing/Speculator-Qwen3-8B-Eagle3-converted-071-quantized"],
)
def test_qwen(vllm_runner, example_prompts, model_path, monkeypatch):
    """Smoke-test the Qwen3 Eagle3 speculator checkpoint.

    Loads ``model_path`` through the ``vllm_runner`` fixture in bfloat16,
    asserts that applying ``supports_eagle3`` to the loaded model yields a
    truthy result, then greedily generates up to 20 tokens for
    ``example_prompts`` and asserts that the outputs are truthy.
    """
    # Set environment variable for V1 engine serialization.
    # NOTE(review): presumably required so that `apply_model` can ship the
    # `supports_eagle3` callable to the engine -- confirm.
    monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")

    with vllm_runner(model_path, dtype=torch.bfloat16) as vllm_model:
        # NOTE(review): if `apply_model` returns a per-worker list, this
        # assertion only checks that the list is non-empty -- confirm.
        eagle3_supported = vllm_model.apply_model(supports_eagle3)
        assert eagle3_supported

        vllm_outputs = vllm_model.generate_greedy(
            example_prompts,
            max_tokens=20,
        )
        # Printed so the generations show up in the captured test output.
        print(vllm_outputs)
        assert vllm_outputs