"tests/vscode:/vscode.git/clone" did not exist on "d7219bcda3e6508cb14881bec303e2d0ab68c898"
sequence.py 2.78 KB
Newer Older
lizhigong's avatar
lizhigong committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21


from typing import Union
from vllm.sequence import Sequence
from typing import Sequence as GenericSequence


class ZeroOverheadSequence(Sequence):
    def __init__(
        self,
        seq_id,
        inputs,
        block_size,
        eos_token_id=None,
        lora_request=None,
        prompt_adapter_request=None,
    ):
        """Create the sequence with no effective output tokens yet.

        Every argument is forwarded, positionally and unchanged, to the base
        ``Sequence`` constructor.
        """
        base_args = (
            seq_id,
            inputs,
            block_size,
            eos_token_id,
            lora_request,
            prompt_adapter_request,
        )
        super().__init__(*base_args)
        # Count of output slots holding real token ids; output slots past this
        # count are placeholders that ``fix_last_token_id`` fills in order.
        self.effective_output_len: int = 0

    def fix_last_token_id(self, token_id: int) -> None:
        """Overwrite the oldest placeholder output token with ``token_id``.

        Output slots at index ``effective_output_len`` and beyond hold
        placeholder ids. This writes ``token_id`` into the first of them in
        every buffer that mirrors the output tokens, then counts that slot as
        effective.

        Args:
            token_id: The real token id to store in place of the placeholder.

        Raises:
            AssertionError: If no placeholder slot is left to fix (checked
                only when assertions are enabled).
        """
        # Negative offset, from the end of the output buffer, of the first
        # placeholder slot. Take the length of the private buffer: the public
        # ``output_token_ids`` property may materialise a full copy on every
        # access, and only the length is needed here.
        effect_offset = (
            self.effective_output_len - len(self.data._output_token_ids)
        )
        assert effect_offset < 0, "no placeholder output token left to fix"
        self.data._output_token_ids[effect_offset] = token_id
        # ``_new_appended_tokens`` may be shorter than the output buffer; only
        # patch it when the placeholder still falls inside it.
        if len(self.data._new_appended_tokens) >= -effect_offset:
            self.data._new_appended_tokens[effect_offset] = token_id
        # Assumes ``_cached_all_token_ids`` ends with the output tokens
        # (prompt + output), so the same negative offset hits the same slot.
        self.data._cached_all_token_ids[effect_offset] = token_id
        self.effective_output_len += 1
    
lizhigong's avatar
lizhigong committed
22
23
24
25
26
    def remove_last_place_holder(self, count):
        """Drop the trailing ``count`` placeholder tokens from the sequence data.

        Truncates the output, newly-appended and cached all-token buffers by
        ``count`` entries each, and rolls ``_num_computed_tokens`` back by the
        same amount.

        Args:
            count: Number of trailing tokens to remove. ``0`` is a no-op.

        Raises:
            ValueError: If ``count`` is negative.
        """
        if count < 0:
            raise ValueError(f"count must be non-negative, got {count}")
        if count == 0:
            # ``seq[:-0]`` is ``seq[:0]``: slicing with 0 would empty every
            # buffer (prompt tokens included) instead of removing nothing.
            return
        self.data._output_token_ids = self.data._output_token_ids[:-count]
        self.data._new_appended_tokens = (
            self.data._new_appended_tokens[:-count])
        self.data._cached_all_token_ids = (
            self.data._cached_all_token_ids[:-count])
        self.data._num_computed_tokens -= count
lizhigong's avatar
lizhigong committed
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65

    def zero_overhead_get_output_token_ids(self) -> tuple[int, ...]:
        """Return the output token ids, leaving out trailing placeholder slots."""
        all_outputs = self.data.output_token_ids
        return all_outputs[:self.effective_output_len]
    
    def zero_overhead_get_output_len(self) -> int:
        """Return the number of effective (non-placeholder) output tokens."""
        effective_len: int = self.effective_output_len
        return effective_len
    
    def zero_overhead_get_last_token_id(self) -> int:
        """Return the most recent effective token id.

        While no output token is effective yet, this is the last prompt token;
        otherwise it is the output token just before the first placeholder.
        """
        n_out = self.effective_output_len
        if n_out:
            return self.data._output_token_ids[n_out - 1]
        return self.data._prompt_token_ids[-1]
    
    def zero_overhead_get_len(self) -> int:
        """Return the prompt length plus the number of effective output tokens."""
        prompt_len = len(self.data._prompt_token_ids)
        return self.effective_output_len + prompt_len
    
    def get_output_token_ids_to_return(
            self, delta: bool) -> Union[GenericSequence[int], int]:
        """Return effective output tokens for reporting.

        Only the first ``effective_output_len`` output tokens are considered;
        trailing placeholder slots are never returned.

        Args:
            delta: If False, return every effective output token. If True,
                return only the tokens that became effective since the
                previous delta call, and advance the internal offset.

        Returns:
            In delta mode: a single token id when exactly one new token is
            available, an empty list when none are, otherwise a list of the
            new token ids. Otherwise: all effective output token ids.
        """
        if not delta:
            return self.zero_overhead_get_output_token_ids()

        output_len = self.zero_overhead_get_output_len()

        # Get the number of new tokens
        num_new_tokens = output_len - self._last_output_token_ids_offset
        self._last_output_token_ids_offset = output_len

        # Index the new tokens relative to the end of the *effective* output
        # (``output_len``), not the end of the buffers: slots past
        # ``output_len`` are placeholders that have not been fixed yet, and
        # ``_cached_all_token_ids`` additionally starts with the prompt.
        output_ids = self.data._output_token_ids

        if num_new_tokens == 1:
            # Optimization for single decode token case
            # (which is what we have most of the time)
            return output_ids[output_len - 1]

        if num_new_tokens == 0:
            return []

        return list(output_ids[output_len - num_new_tokens:output_len])