sample.py
from typing import Dict, List, Tuple

import torch
import torch.nn as nn

from cacheflow.models import InputMetadata


class Sampler(nn.Module):
    """Samples the next token for each sequence from the model's hidden states."""

    def __init__(
        self,
        embedding: torch.Tensor,  # Token embedding weight: [vocab_size, hidden_size].
    ) -> None:
        super().__init__()
        self.embedding = embedding.t()  # [hidden_size, vocab_size]

    def forward(
        self,
        hidden_states: torch.Tensor,
        input_metadata: InputMetadata,
    ) -> Dict[int, Tuple[int, int]]:
        # Get the hidden states of the last tokens.
        # For each prompt, only the final token's hidden state is needed to
        # predict the next token; generation tokens are already last tokens.
        start_idx = 0
        last_token_indices: List[int] = []
        for prompt_len in input_metadata.prompt_lens:
            last_token_indices.append(start_idx + prompt_len - 1)
            start_idx += prompt_len
        last_token_indices.extend(
            range(start_idx, start_idx + input_metadata.num_generation_tokens))
        hidden_states = hidden_states[last_token_indices]

        # Get the logits for the next tokens: [num_seqs, vocab_size].
        logits = torch.matmul(hidden_states, self.embedding)

        # Sample the next tokens (greedy/argmax for now).
        # TODO(woosuk): Implement other sampling methods.
        next_token_ids = torch.argmax(logits, dim=-1)
        next_token_ids = next_token_ids.tolist()
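        # A possible extension, sketched here as a comment (not the author's
        # implementation): temperature sampling instead of argmax, assuming a
        # scalar `temperature` hyperparameter, e.g.
        #   probs = torch.softmax(logits / temperature, dim=-1)
        #   next_token_ids = torch.multinomial(probs, num_samples=1).squeeze(-1).tolist()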

        # Return the next tokens.
        next_tokens: Dict[int, Tuple[int, int]] = {}
        for seq_id, token_id in zip(input_metadata.seq_ids, next_token_ids):
            next_tokens[seq_id] = (seq_id, token_id)
        return next_tokens
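

# Usage sketch (illustrative only, not part of the original module). It drives
# the Sampler with a stand-in metadata object: the real InputMetadata in
# cacheflow.models carries more fields, but only the three attributes read by
# forward() are assumed here (prompt_lens, num_generation_tokens, seq_ids).
# Running it also assumes the cacheflow package is importable for the import
# at the top of this file.
if __name__ == "__main__":
    from types import SimpleNamespace

    vocab_size, hidden_size = 32000, 4096
    embedding = torch.randn(vocab_size, hidden_size)
    sampler = Sampler(embedding)

    # Two prompts (lengths 3 and 5) plus two sequences already in the
    # generation phase, i.e. 3 + 5 + 2 = 10 token positions of hidden states.
    metadata = SimpleNamespace(
        prompt_lens=[3, 5],
        num_generation_tokens=2,
        seq_ids=[0, 1, 2, 3],
    )
    num_tokens = sum(metadata.prompt_lens) + metadata.num_generation_tokens
    hidden_states = torch.randn(num_tokens, hidden_size)

    # Greedily samples one token per sequence and prints
    # {seq_id: (seq_id, next_token_id), ...}.
    print(sampler(hidden_states, metadata))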