test_guided_processors.py 3.16 KB
Newer Older
1
2
# This unit test should be moved to a new
# tests/test_guided_decoding directory.
3
import pytest
4
import os
5
import torch
6
from transformers import AutoTokenizer
7

8
9
10
11
12
from vllm.entrypoints.openai.protocol import CompletionRequest
from vllm.model_executor.guided_decoding import (
    get_guided_decoding_logits_processor)
from vllm.model_executor.guided_decoding.outlines_logits_processors import (
    JSONLogitsProcessor, RegexLogitsProcessor)
13
from ...utils import models_path_prefix
14
15


16
def test_guided_logits_processors(sample_regex, sample_json_schema):
    """Smoke-test RegexLogitsProcessor and JSONLogitsProcessor directly.

    Each processor is invoked on a random logits vector together with the
    token ids of a related prompt. The return value is deliberately ignored:
    the test expects the vector itself to be altered in place while keeping
    its shape.
    """
    model_dir = os.path.join(models_path_prefix,
                             'HuggingFaceH4/zephyr-7b-beta')
    tokenizer = AutoTokenizer.from_pretrained(model_dir)

    regex_processor = RegexLogitsProcessor(sample_regex, tokenizer)
    json_processor = JSONLogitsProcessor(sample_json_schema,
                                         tokenizer,
                                         whitespace_pattern=None)

    regex_prompt = (
        f"Give an example IPv4 address with this regex: {sample_regex}")
    json_prompt = (
        f"Give an employee profile that fits this schema: {sample_json_schema}"
    )

    # Both processors go through the exact same check, regex first.
    for processor, prompt in ((regex_processor, regex_prompt),
                              (json_processor, json_prompt)):
        prompt_ids = tokenizer.encode(prompt)
        # 32000 entries: presumably the model's vocabulary size -- confirm.
        logits = torch.rand(32000)
        snapshot = logits.clone()
        processor(prompt_ids, logits)
        assert logits.shape == snapshot.shape
        assert not torch.allclose(logits, snapshot)
40
41
42
43


@pytest.mark.asyncio
@pytest.mark.parametrize("backend", ["outlines", "lm-format-enforcer"])
async def test_guided_logits_processor_black_box(backend: str, sample_regex,
                                                 sample_json_schema):
    """Exercise get_guided_decoding_logits_processor for each backend.

    For a regex-guided and then a JSON-guided CompletionRequest, the test
    obtains a logits processor from the factory, checks that one was
    returned, and verifies that applying it to a random logits vector yields
    a tensor of the same shape with different values.
    """
    model_dir = os.path.join(models_path_prefix,
                             'HuggingFaceH4/zephyr-7b-beta')
    tokenizer = AutoTokenizer.from_pretrained(model_dir)

    regex_prompt = (
        f"Give an example IPv4 address with this regex: {sample_regex}")
    json_prompt = (
        f"Give an employee profile that fits this schema: {sample_json_schema}"
    )

    # Each case pairs a prompt with the single guided-decoding field that is
    # set on the request; the regex case runs before the JSON case.
    cases = (
        (regex_prompt, {"guided_regex": sample_regex}),
        (json_prompt, {"guided_json": sample_json_schema}),
    )
    for prompt, guided_field in cases:
        prompt_ids = tokenizer.encode(prompt)
        request = CompletionRequest(model='test',
                                    prompt=prompt_ids,
                                    **guided_field)
        processor = await get_guided_decoding_logits_processor(
            backend, request, tokenizer)
        assert processor is not None

        # 32000 entries: presumably the model's vocabulary size -- confirm.
        logits = torch.rand(32000)
        snapshot = logits.clone()
        # Here the processor's return value is what gets checked.
        logits = processor(prompt_ids, logits)
        assert logits.shape == snapshot.shape
        assert not torch.allclose(logits, snapshot)