sequence.py 2.63 KB
Newer Older
chenzk's avatar
chenzk committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
from dataclasses import dataclass, field
from enum import Enum, auto
from itertools import count
from typing import List

from vllm.kvprune.compression.compression_config import SequenceCompressionParams
from vllm.kvprune.config.sampling_params import SamplingParams


class SequenceStatus(Enum):
    WAITING = auto()
    RUNNING = auto()
    FINISHED = auto()


@dataclass
class Sequence:
    """
    Represents a single user request / sequence being generated.
    """

    _counter = count()

    prompt_token_ids: List[int]
    completion_token_ids: List[int] = field(default_factory=list)
    sampling_params: SamplingParams = field(default_factory=SamplingParams)
    compression_params: SequenceCompressionParams = field(
        default_factory=SequenceCompressionParams
    )
    status: SequenceStatus = SequenceStatus.WAITING

    seq_id: int = field(default_factory=lambda: next(Sequence._counter), init=False)
    num_tokens_processed: int = 0

    @property
    def num_prompt_tokens(self) -> int:
        return len(self.prompt_token_ids)

    @property
    def num_generated_tokens(self) -> int:
        return len(self.completion_token_ids)

    def add_new_token(self, token_id: int) -> None:
        if len(self.completion_token_ids) == 0:
            self.num_tokens_processed += self.num_prompt_tokens
        self.completion_token_ids.append(token_id)
        self.num_tokens_processed += 1

    def tokens_to_retain_per_layer(self, num_kv_heads: int) -> int:
        n = int(
            self.compression_params.compression_ratio
            * self.num_prompt_tokens
            * num_kv_heads
        )
        return max(1, n)

    def __getstate__(self):
        return dict(
            prompt_token_ids=list(self.prompt_token_ids),
            completion_token_ids=list(self.completion_token_ids),
            sampling_params=self.sampling_params,
            compression_params=self.compression_params,
            status=self.status,
            seq_id=self.seq_id,
            num_tokens_processed=self.num_tokens_processed,
        )

    def __setstate__(self, state):
        self.prompt_token_ids = list(state["prompt_token_ids"])
        self.completion_token_ids = list(state["completion_token_ids"])
        self.sampling_params = state["sampling_params"]
        self.compression_params = state["compression_params"]
        self.status = state["status"]
        self.seq_id = state["seq_id"]
        self.num_tokens_processed = state["num_tokens_processed"]

    @property
    def prompt_len(self) -> int:
        return len(self.prompt_token_ids)

    @property
    def completion_len(self) -> int:
        return len(self.completion_token_ids)