test_eagle_dp.py 2.42 KB
Newer Older
Rémi Delacourt's avatar
Rémi Delacourt committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
import os
from contextlib import AsyncExitStack
from dataclasses import replace

import pytest

from vllm import SamplingParams
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.sampling_params import RequestOutputKind
from vllm.v1.engine.async_llm import AsyncLLM

# Data-parallel size used by the test below; the DP_SIZE environment
# variable overrides the default of 2.
DP_SIZE = int(os.environ.get("DP_SIZE", 2))


@pytest.mark.asyncio
async def test_run_eagle_dp():
    """Check that EAGLE speculative decoding matches the baseline under DP.

    Two engines are created one after the other from the same base arguments
    (data-parallel size ``DP_SIZE``, "mp" backend): first with an EAGLE
    speculative config, then without one. Each generates exactly
    ``num_expected_tokens`` tokens at temperature 0 for the same prompt, and
    the two token sequences must be identical.
    """
    target_model = "meta-llama/Llama-3.1-8B-Instruct"
    draft_model = "yuhuili/EAGLE-LLaMA3.1-Instruct-8B"

    engine_args = AsyncEngineArgs(
        model=target_model,
        tokenizer_mode="auto",
        enforce_eager=False,
        tensor_parallel_size=int(os.getenv("TP_SIZE", 1)),
        data_parallel_size=DP_SIZE,
        data_parallel_backend="mp",  # ray takes more time
        trust_remote_code=True,
        max_model_len=16384,
    )

    # Same configuration as the baseline, plus the EAGLE draft model.
    eagle_engine_args = replace(
        engine_args,
        speculative_config={
            "model": draft_model,
            "method": "eagle",
            "num_speculative_tokens": 3,
        },
    )

    prompt = "This is a test of data parallel with eagle"
    num_expected_tokens = 100
    # min_tokens == max_tokens together with ignore_eos pins the output length,
    # and temperature=0 makes the two runs comparable token by token.
    sampling_params = SamplingParams(
        min_tokens=num_expected_tokens,
        max_tokens=num_expected_tokens,
        ignore_eos=True,
        output_kind=RequestOutputKind.FINAL_ONLY,
        temperature=0,
    )

    async def generate_with_timeout(given_engine: AsyncLLM):
        """Return the token ids of the first output yielded for the prompt.

        The timeout itself is applied by the caller through
        ``asyncio.wait_for``.

        Raises:
            AssertionError: if the output has the wrong number of tokens, or
                if the engine yields no output at all.
        """
        async for out in given_engine.generate(
            request_id="test-eagle-dp", prompt=prompt, sampling_params=sampling_params
        ):
            token_ids = out.outputs[0].token_ids
            assert len(token_ids) == num_expected_tokens
            return token_ids
        # Without this, an empty stream would return None from both runs and
        # the final equality check would pass vacuously (None == None).
        raise AssertionError("engine produced no output for the request")

    async def engine_create_and_generate(engine_args: AsyncEngineArgs):
        """Build an engine, generate once, and shut the engine down.

        Returns the generated token ids.
        """
        async with AsyncExitStack() as after:
            engine = AsyncLLM.from_engine_args(engine_args)
            # Registered right after creation so the engine is shut down even
            # when generation times out or an assertion fails.
            after.callback(engine.shutdown)

            token_ids = await asyncio.wait_for(
                generate_with_timeout(engine), timeout=30
            )

            assert not engine.output_processor.has_unfinished_requests()
        return token_ids

    token_ids_with_eagle = await engine_create_and_generate(eagle_engine_args)
    token_ids_no_eagle = await engine_create_and_generate(engine_args)

    # Test for correctness
    assert token_ids_with_eagle == token_ids_no_eagle