# test_get_eos.py
"""
This test file includes some cases where it is inappropriate to
only get the `eos_token_id` from the tokenizer as defined by
:meth:`vllm.LLMEngine._get_eos_token_id`.
"""
from vllm.transformers_utils.config import try_get_generation_config
from vllm.transformers_utils.tokenizer import get_tokenizer
8
9
from ..utils import models_path_prefix
import os
10
11
import pytest
from vllm.utils import is_hip
12
13


14
15
# Skipped whenever is_hip() reports true (presumably a ROCm/HIP build --
# confirm against vllm.utils.is_hip).
@pytest.mark.skipif(is_hip(),
                    reason="Consistent with NV.")
def test_get_llama3_eos_token():
    """Check that Llama-3's tokenizer and generation config disagree on EOS.

    The tokenizer loaded for this model is expected to report a single
    ``eos_token_id`` (128009), while the generation config is expected to
    list two ids (``[128001, 128009]``), so reading the tokenizer alone
    would miss one of them.
    """
    # The model is resolved relative to the shared test model directory.
    model_name = os.path.join(models_path_prefix,
                              "meta-llama/Meta-Llama-3-8B-Instruct")

    tokenizer = get_tokenizer(model_name)
    assert tokenizer.eos_token_id == 128009

    generation_config = try_get_generation_config(model_name,
                                                  trust_remote_code=False)
    # The helper must return a config here for the comparison to mean anything.
    assert generation_config is not None
    assert generation_config.eos_token_id == [128001, 128009]


def test_get_blip2_eos_token():
    """Check that BLIP-2's tokenizer and generation config disagree on EOS.

    The tokenizer loaded for this model is expected to report
    ``eos_token_id == 2``, while the generation config is expected to
    report ``50118``, so the two sources cannot be used interchangeably.
    """
    # The model is resolved relative to the shared test model directory.
    model_name = os.path.join(models_path_prefix,
                              "Salesforce/blip2-opt-2.7b")

    tokenizer = get_tokenizer(model_name)
    assert tokenizer.eos_token_id == 2

    generation_config = try_get_generation_config(model_name,
                                                  trust_remote_code=False)
    # The helper must return a config here for the comparison to mean anything.
    assert generation_config is not None
    assert generation_config.eos_token_id == 50118