import torch
from text_generation_server.utils.tokens import (
    StopSequenceCriteria,
    StoppingCriteria,
    FinishReason,
    batch_top_tokens,
)

def test_stop_sequence_criteria():
    """An exact stop-sequence match fires; prefixes and supersets do not."""
    criteria = StopSequenceCriteria("/test;")

    # Incomplete prefixes must not trigger the stop.
    for partial in ("/", "/test"):
        assert not criteria(partial)
    # Only the exact sequence triggers; trailing text past it does not.
    assert criteria("/test;")
    assert not criteria("/test; ")


def test_stop_sequence_criteria_escape():
    """A stop sequence containing regex metacharacters is matched literally."""
    criteria = StopSequenceCriteria("<|stop|>")

    # Incomplete prefixes of the special-token string must not trigger.
    for partial in ("<", "<|stop"):
        assert not criteria(partial)
    # Exact match triggers; extra trailing text does not.
    assert criteria("<|stop|>")
    assert not criteria("<|stop|> ")


def test_stopping_criteria():
    """Completing the stop string finishes with FINISH_REASON_STOP_SEQUENCE."""
    criteria = StoppingCriteria(0, [StopSequenceCriteria("/test;")], max_new_tokens=5)

    # Feed the stop sequence across two calls: the criteria accumulates text,
    # so only the second call (which completes "/test;") should stop.
    steps = [
        ((65827, "/test"), (False, None)),
        ((30, ";"), (True, FinishReason.FINISH_REASON_STOP_SEQUENCE)),
    ]
    for (token_id, text), expected in steps:
        assert criteria(token_id, text) == expected


def test_stopping_criteria_eos():
    """Emitting the EOS token id finishes with FINISH_REASON_EOS_TOKEN."""
    criteria = StoppingCriteria(0, [StopSequenceCriteria("/test;")], max_new_tokens=5)

    # A non-EOS token keeps generation going.
    assert criteria(1, "") == (False, None)
    # Token id 0 matches the eos id given to the constructor, so it stops.
    assert criteria(0, "") == (True, FinishReason.FINISH_REASON_EOS_TOKEN)


def test_stopping_criteria_max():
    """Reaching max_new_tokens finishes with FINISH_REASON_LENGTH."""
    criteria = StoppingCriteria(0, [StopSequenceCriteria("/test;")], max_new_tokens=5)

    # The first four tokens stay under the budget of five.
    for _ in range(4):
        assert criteria(1, "") == (False, None)
    # The fifth token exhausts max_new_tokens.
    assert criteria(1, "") == (True, FinishReason.FINISH_REASON_LENGTH)


def test_batch_top_tokens():
    """batch_top_tokens returns per-request top-n token ids and logprobs.

    With one accepted token per request, each request yields a single inner
    list; a request with two accepted (speculated) tokens yields two.
    """
    top_n_tokens = [0, 2, 3, 4, 5]
    top_n_tokens_tensor = torch.tensor(top_n_tokens)
    inp_logprobs = torch.tensor([[-1.0, -3.0, -4.0, -2.0, -3.0]] * 5)
    accepted_ids = torch.ones_like(top_n_tokens_tensor)

    topn_tok_ids, topn_tok_logprobs = batch_top_tokens(
        top_n_tokens, top_n_tokens_tensor, inp_logprobs, accepted_ids
    )

    # Expected values per request: n=0 yields nothing; ties at the n-th
    # logprob are included, which is why n=3 returns four tokens.
    expected_ids = [
        [[]],
        [[0, 3]],
        [[0, 3, 1, 4]],
        [[0, 3, 1, 4]],
        [[0, 3, 1, 4, 2]],
    ]
    expected_logprobs = [
        [[]],
        [[-1, -2]],
        [[-1, -2, -3, -3]],
        [[-1, -2, -3, -3]],
        [[-1, -2, -3, -3, -4]],
    ]
    for got, want in zip(topn_tok_ids, expected_ids):
        assert got == want
    for got, want in zip(topn_tok_logprobs, expected_logprobs):
        assert got == want

    # Now make the second member of the batch speculated: it accepts two
    # tokens, so the flat logprob tensor grows and that request must come
    # back with two inner lists instead of one.
    inp_logprobs = torch.tensor([[-1.0, -3.0, -4.0, -2.0, -3.0]] * 5 * 2)
    accepted_ids[1] = 2

    topn_tok_ids, topn_tok_logprobs = batch_top_tokens(
        top_n_tokens, top_n_tokens_tensor, inp_logprobs, accepted_ids
    )

    # Only request 1 changes: one identical top-n list per accepted token.
    expected_ids[1] = [[0, 3], [0, 3]]
    expected_logprobs[1] = [[-1, -2], [-1, -2]]
    for got, want in zip(topn_tok_ids, expected_ids):
        assert got == want
    for got, want in zip(topn_tok_logprobs, expected_logprobs):
        assert got == want