# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import torch


@pytest.mark.parametrize(
    "model_path",
    ["nm-testing/SpeculatorLlama3-1-8B-Eagle3-converted-0717-quantized"],
)
def test_llama(vllm_runner, example_prompts, model_path):
    """Smoke test for the parametrized Llama 3.1 8B checkpoint.

    Loads ``model_path`` through the ``vllm_runner`` fixture in bfloat16,
    greedily generates with ``max_tokens=20`` for ``example_prompts``, prints
    the result, and requires that something was returned.
    """
    with vllm_runner(model_path, dtype=torch.bfloat16) as vllm_model:
        vllm_outputs = vllm_model.generate_greedy(example_prompts,
                                                  max_tokens=20)
        # Printed for manual inspection of the generations.
        print(vllm_outputs)
        # Only checks that generation produced output; the content is not
        # compared against any reference.
        assert vllm_outputs


@pytest.mark.parametrize(
    "model_path",
    ["nm-testing/Speculator-Qwen3-8B-Eagle3-converted-071-quantized"],
)
def test_qwen(vllm_runner, example_prompts, model_path):
    """Smoke test for the parametrized Qwen3 8B checkpoint.

    Loads ``model_path`` through the ``vllm_runner`` fixture in bfloat16,
    greedily generates with ``max_tokens=20`` for ``example_prompts``, prints
    the result, and requires that something was returned.
    """
    with vllm_runner(model_path, dtype=torch.bfloat16) as runner:
        outputs = runner.generate_greedy(example_prompts, max_tokens=20)
        # Echo the generations so they can be inspected when debugging.
        print(outputs)
        # Non-emptiness is the only property checked here.
        assert outputs