"vllm/vscode:/vscode.git/clone" did not exist on "b82fc1364d313a222bde7e9ac897fd847ad7b05b"
sequence.py 2.44 KB
Newer Older
lizhigong's avatar
lizhigong committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60


from typing import Union
from vllm.sequence import Sequence
from typing import Sequence as GenericSequence


class ZeroOverheadSequence(Sequence):
    """Sequence that tracks how many of its output tokens are finalized.

    ``effective_output_len`` counts the output tokens whose ids have been
    confirmed through :meth:`fix_last_token_id`. Output tokens stored in
    ``self.data`` beyond that count are treated as provisional: they are
    overwritten by :meth:`fix_last_token_id` and are hidden by the
    ``zero_overhead_*`` accessors and by
    :meth:`get_output_token_ids_to_return`.
    """

    def __init__(self, seq_id, inputs, block_size, eos_token_id=None,
                 lora_request=None, prompt_adapter_request=None):
        """Forward all arguments to the base class and start with no
        finalized output tokens."""
        super().__init__(seq_id, inputs, block_size, eos_token_id,
                         lora_request, prompt_adapter_request)
        # Number of leading output tokens whose ids are final.
        self.effective_output_len: int = 0

    def fix_last_token_id(self, token_id: int) -> None:
        """Overwrite the oldest provisional output token with ``token_id``.

        The slot directly after the finalized outputs is rewritten in every
        token buffer kept by ``self.data``; ``effective_output_len`` then
        grows by one.

        Args:
            token_id: The confirmed id for that output position.

        Raises:
            AssertionError: If there is no provisional output token left.
        """
        data = self.data
        # Negative index (counted from the end of the buffers) of the first
        # output token that is not finalized yet. The private buffer is
        # measured directly instead of going through the public property.
        effect_offset = self.effective_output_len - len(data._output_token_ids)
        assert effect_offset < 0, "no provisional output token left to fix"
        data._output_token_ids[effect_offset] = token_id
        # _new_appended_tokens may be shorter than the number of pending
        # tokens; only patch it when the slot actually exists.
        if len(data._new_appended_tokens) >= -effect_offset:
            data._new_appended_tokens[effect_offset] = token_id
        data._cached_all_token_ids[effect_offset] = token_id
        self.effective_output_len += 1

    def zero_overhead_get_output_token_ids(self) -> tuple[int, ...]:
        """Return only the finalized output token ids."""
        return self.data.output_token_ids[:self.effective_output_len]

    def zero_overhead_get_output_len(self) -> int:
        """Return the number of finalized output tokens."""
        return self.effective_output_len

    def zero_overhead_get_last_token_id(self) -> int:
        """Return the last finalized token id.

        Falls back to the last prompt token while no output token has been
        finalized.
        """
        if self.effective_output_len == 0:
            return self.data._prompt_token_ids[-1]
        return self.data._output_token_ids[self.effective_output_len - 1]

    def zero_overhead_get_len(self) -> int:
        """Return the prompt length plus the number of finalized outputs."""
        return self.effective_output_len + len(self.data._prompt_token_ids)

    def get_output_token_ids_to_return(
            self, delta: bool) -> Union[GenericSequence[int], int]:
        """Return finalized output token ids, optionally only the new ones.

        Args:
            delta: If False, every finalized output token id is returned.
                If True, only the tokens finalized since the last delta
                call are returned and the internal offset is advanced.

        Returns:
            A sequence of token ids, or a bare ``int`` when exactly one new
            token is reported in delta mode.
        """
        if not delta:
            return self.zero_overhead_get_output_token_ids()

        output_len = self.zero_overhead_get_output_len()

        # Number of tokens finalized since the previous delta call.
        # NOTE(review): _last_output_token_ids_offset is not defined in this
        # class; presumably the base class initializes it -- confirm.
        num_new_tokens = output_len - self._last_output_token_ids_offset
        self._last_output_token_ids_offset = output_len

        if num_new_tokens == 0:
            return []

        # _cached_all_token_ids shares its tail with _output_token_ids (the
        # same negative offsets are used in fix_last_token_id), so the
        # finalized tokens end `pending` entries before the end of the
        # buffer. Absolute bounds are used because a negative slice end of
        # 0 would silently produce an empty result when nothing is pending.
        cached = self.data._cached_all_token_ids
        pending = len(self.data._output_token_ids) - output_len
        end = len(cached) - pending

        if num_new_tokens == 1:
            # Optimization for the single decode token case
            # (which is what we have most of the time).
            return cached[end - 1]

        return cached[end - num_new_tokens:end]