simple_frontend.py 2.2 KB
Newer Older
1
import time
Woosuk Kwon's avatar
Woosuk Kwon committed
2
from typing import List, Optional, Tuple
Woosuk Kwon's avatar
Woosuk Kwon committed
3
4
5

from transformers import AutoTokenizer

Woosuk Kwon's avatar
Woosuk Kwon committed
6
from cacheflow.logger import init_logger
Woosuk Kwon's avatar
Woosuk Kwon committed
7
from cacheflow.sampling_params import SamplingParams
8
from cacheflow.sequence import Sequence, SequenceGroup
Woosuk Kwon's avatar
Woosuk Kwon committed
9
10
11
from cacheflow.utils import Counter


Woosuk Kwon's avatar
Woosuk Kwon committed
12
13
14
# Module-level logger for this file, created via cacheflow.logger.init_logger
# with this module's name.
logger = init_logger(__name__)


15
class SimpleFrontend:
    """Turn text prompts into queued sequence groups and log decoded outputs.

    Each query becomes one SequenceGroup paired with its SamplingParams; the
    pairs accumulate in ``self.inputs`` until drained by ``get_inputs``.
    """

    def __init__(
        self,
        model_name: str,
        block_size: int,
    ) -> None:
        # Forwarded to every Sequence this frontend creates.
        self.block_size = block_size

        # Hugging Face tokenizer for ``model_name``.
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        # Independent id sources: one for groups, one for individual sequences.
        self.seq_group_counter = Counter()
        self.seq_counter = Counter()
        # Pending (group, params) pairs not yet handed out by get_inputs().
        self.inputs: List[Tuple[SequenceGroup, SamplingParams]] = []

    def add_eos_token(self, sampling_params: SamplingParams) -> SamplingParams:
        """Register the tokenizer's EOS id as a stop token for generation.

        ``sampling_params.stop_token_ids`` is updated in place; the same object
        is returned so the call can be chained.
        """
        eos_id = self.tokenizer.eos_token_id
        sampling_params.stop_token_ids.add(eos_id)
        return sampling_params

    def query(
        self,
        prompt: str,
        sampling_params: SamplingParams,
    ) -> None:
        """Tokenize ``prompt`` and enqueue it with ``sampling_params``."""
        self._add_query(self.tokenizer.encode(prompt), sampling_params)

    def _add_query(
        self,
        token_ids: List[int],
        sampling_params: SamplingParams,
        arrival_time: Optional[float] = None,
    ) -> None:
        """Build a group of ``sampling_params.n`` sequences and queue it.

        Args:
            token_ids: Prompt token ids given to every sequence in the group.
            sampling_params: Sampling configuration; ``n`` sets the group size.
            arrival_time: Timestamp stored on the group; ``time.time()`` at
                call time when omitted.
        """
        stamp = time.time() if arrival_time is None else arrival_time

        # Sequence ids are drawn first, then the group id.
        # NOTE(review): every Sequence receives the same token_ids list;
        # presumably Sequence copies it internally — confirm.
        seqs: List[Sequence] = [
            Sequence(next(self.seq_counter), token_ids, block_size=self.block_size)
            for _ in range(sampling_params.n)
        ]
        seq_group = SequenceGroup(next(self.seq_group_counter), seqs, stamp)
        self.inputs.append((seq_group, sampling_params))

    def get_inputs(self) -> List[Tuple[SequenceGroup, SamplingParams]]:
        """Return every queued (group, params) pair and reset the queue."""
        pending, self.inputs = self.inputs, []
        return pending

    def print_response(
        self,
        seq_group: SequenceGroup,
    ) -> None:
        """Log each sequence's decoded text.

        Special tokens are skipped during decoding and surrounding whitespace
        is stripped before logging.
        """
        for seq in seq_group.seqs:
            text = self.tokenizer.decode(
                seq.get_token_ids(), skip_special_tokens=True
            ).strip()
            logger.info(f"Seq {seq.seq_id}: {text!r}")