import torch

from text_generation_server.utils.tokens import (
    StopSequenceCriteria,
    StoppingCriteria,
    FinishReason,
    batch_top_tokens,
)


def test_stop_sequence_criteria():
    """A stop-sequence criteria fires only on an exact match of the full stop string.

    Prefixes of the stop string, and the stop string followed by extra text,
    must NOT be treated as a match.
    """
    criteria = StopSequenceCriteria("/test;")

    assert not criteria("/")  # bare prefix: no match
    assert not criteria("/test")  # longer prefix, still incomplete
    assert criteria("/test;")  # exact stop sequence: match
    assert not criteria("/test; ")  # trailing character after the sequence: no match

19
20
21
22
23
24
25
26
27
def test_stop_sequence_criteria_escape():
    criteria = StopSequenceCriteria("<|stop|>")

    assert not criteria("<")
    assert not criteria("<|stop")
    assert criteria("<|stop|>")
    assert not criteria("<|stop|> ")


28
29
30
def test_stopping_criteria():
    criteria = StoppingCriteria(0, [StopSequenceCriteria("/test;")], max_new_tokens=5)
    assert criteria(65827, "/test") == (False, None)
31
    assert criteria(30, ";") == (True, FinishReason.FINISH_REASON_STOP_SEQUENCE)
32
33


34
35
36
def test_stopping_criteria_eos():
    criteria = StoppingCriteria(0, [StopSequenceCriteria("/test;")], max_new_tokens=5)
    assert criteria(1, "") == (False, None)
37
    assert criteria(0, "") == (True, FinishReason.FINISH_REASON_EOS_TOKEN)
38
39
40


def test_stopping_criteria_max():
41
42
43
44
45
    criteria = StoppingCriteria(0, [StopSequenceCriteria("/test;")], max_new_tokens=5)
    assert criteria(1, "") == (False, None)
    assert criteria(1, "") == (False, None)
    assert criteria(1, "") == (False, None)
    assert criteria(1, "") == (False, None)
46
    assert criteria(1, "") == (True, FinishReason.FINISH_REASON_LENGTH)
def test_batch_top_tokens():
    """batch_top_tokens returns per-request top-n token ids and logprobs.

    The logprob row [-1, -3, -4, -2, -3] sorts (descending) to tokens
    [0, 3, 1, 4, 2] with logprobs [-1, -2, -3, -3, -4].  Each batch entry
    asks for a different n (0..5); note that ties at the n-th logprob
    (positions 1 and 4 both at -3) are both included, so n=3 and n=4
    yield the same four tokens.
    """
    top_n_tokens = [0, 2, 3, 4, 5]
    top_n_tokens_tensor = torch.tensor(top_n_tokens)
    inp_logprobs = torch.tensor([[-1.0, -3.0, -4.0, -2.0, -3.0]] * 5)
    accepted_ids = torch.ones_like(top_n_tokens_tensor)

    topn_tok_ids, topn_tok_logprobs = batch_top_tokens(
        top_n_tokens, top_n_tokens_tensor, inp_logprobs, accepted_ids
    )

    assert topn_tok_ids[0] == [[]]  # n=0 requests no top tokens
    assert topn_tok_ids[1] == [[0, 3]]
    assert topn_tok_ids[2] == [[0, 3, 1, 4]]  # tie at -3 pulls in a 4th token
    assert topn_tok_ids[3] == [[0, 3, 1, 4]]
    assert topn_tok_ids[4] == [[0, 3, 1, 4, 2]]

    assert topn_tok_logprobs[0] == [[]]
    assert topn_tok_logprobs[1] == [[-1, -2]]
    assert topn_tok_logprobs[2] == [[-1, -2, -3, -3]]
    assert topn_tok_logprobs[3] == [[-1, -2, -3, -3]]
    assert topn_tok_logprobs[4] == [[-1, -2, -3, -3, -4]]

    # Now let's make second member of the batch be speculated: it accepts
    # two tokens, so the logprob tensor carries an extra row per request
    # and entry 1 gets two top-token lists back.
    inp_logprobs = torch.tensor([[-1.0, -3.0, -4.0, -2.0, -3.0]] * 5 * 2)
    accepted_ids[1] = 2
    topn_tok_ids, topn_tok_logprobs = batch_top_tokens(
        top_n_tokens, top_n_tokens_tensor, inp_logprobs, accepted_ids
    )

    assert topn_tok_ids[0] == [[]]
    assert topn_tok_ids[1] == [[0, 3], [0, 3]]  # one list per accepted token
    assert topn_tok_ids[2] == [[0, 3, 1, 4]]
    assert topn_tok_ids[3] == [[0, 3, 1, 4]]
    assert topn_tok_ids[4] == [[0, 3, 1, 4, 2]]

    assert topn_tok_logprobs[0] == [[]]
    assert topn_tok_logprobs[1] == [[-1, -2], [-1, -2]]
    assert topn_tok_logprobs[2] == [[-1, -2, -3, -3]]
    assert topn_tok_logprobs[3] == [[-1, -2, -3, -3]]
    assert topn_tok_logprobs[4] == [[-1, -2, -3, -3, -4]]