input_metadata.py 3.73 KB
Newer Older
1
from typing import Dict, List, Optional, Tuple
Woosuk Kwon's avatar
Woosuk Kwon committed
2
3

import torch
Woosuk Kwon's avatar
Woosuk Kwon committed
4
from xformers.ops import AttentionBias
Woosuk Kwon's avatar
Woosuk Kwon committed
5

6
from vllm.sampling_params import SamplingParams, SamplingType
Woosuk Kwon's avatar
Woosuk Kwon committed
7
from vllm.sequence import SequenceData
8

Woosuk Kwon's avatar
Woosuk Kwon committed
9
10

class InputMetadata:
    """Metadata for input sequences. Used for PagedAttention.

    Args:
        seq_groups: List of (seq_ids, sampling_params).
        seq_data: Seq_id -> SequenceData.
        prompt_lens: Lengths of prompts.
        slot_mapping: The address to write the new KV to of each token.
        context_lens: The length of attention context for each generation
            token. Its first dimension defines the number of generation
            tokens.
        max_context_len: The maximum context length.
        block_tables: The block tables. (Seq id -> list of physical block)
            Must have one row per generation token; its second dimension,
            when non-empty, is the maximum number of blocks per sequence.
        selected_token_indices: Token indices passed through for later use
            (presumably the positions whose outputs are sampled; stored
            as-is and not interpreted by this class).
        categorized_sample_indices: SamplingType -> index tensor, stored
            as-is and not interpreted by this class.
        sliding_window: Optional sliding-window size. When given,
            ``to_cache`` is filled with the flat token positions whose KV
            entries should be written to the cache; otherwise ``to_cache``
            stays ``None``.
    """

    def __init__(
        self,
        seq_groups: List[Tuple[List[int], SamplingParams]],
        seq_data: Dict[int, SequenceData],
        prompt_lens: List[int],
        slot_mapping: torch.Tensor,
        context_lens: torch.Tensor,
        max_context_len: int,
        block_tables: torch.Tensor,
        selected_token_indices: torch.Tensor,
        categorized_sample_indices: Dict[SamplingType, torch.Tensor],
        sliding_window: Optional[int] = None,
    ) -> None:
        self.seq_groups = seq_groups
        self.seq_data = seq_data
        self.prompt_lens = prompt_lens
        self.slot_mapping = slot_mapping
        self.context_lens = context_lens
        self.max_context_len = max_context_len
        self.block_tables = block_tables
        self.selected_token_indices = selected_token_indices
        self.categorized_sample_indices = categorized_sample_indices

        # Longest prompt in the batch (0 when there are no prompts).
        self.max_prompt_len = max(prompt_lens) if prompt_lens else 0

        self.to_cache: Optional[torch.Tensor] = None
        if sliding_window is not None:
            # We need to keep the positions of sliding windows within
            # the key / value tables, this is helpful to know which
            # elements we need to cache.
            #
            # Prompt tokens are indexed as if every prompt occupied
            # `max_prompt_len` consecutive positions (start_idx advances by
            # max_prompt_len per prompt). For each prompt, only its last
            # `sliding_window` real tokens are kept.
            to_cache, start_idx = [], 0
            for prompt_len in self.prompt_lens:
                to_cache.extend(
                    range(
                        start_idx + max(0, prompt_len - sliding_window),
                        start_idx + prompt_len,
                    ))
                start_idx += self.max_prompt_len
            # Every position after the prompt region (up to the length of
            # slot_mapping) is always cached.
            to_cache.extend(range(start_idx, slot_mapping.shape[0]))
            self.to_cache = torch.tensor(to_cache,
                                         dtype=torch.int32,
                                         device=self.slot_mapping.device)

        self.num_prompts = len(prompt_lens)
        # Prompt-token count includes padding up to max_prompt_len.
        self.num_prompt_tokens = self.num_prompts * self.max_prompt_len
        self.num_generation_tokens = context_lens.shape[0]
        # An empty block_tables tensor may not have a meaningful second
        # dimension, so fall back to 0.
        if block_tables.numel() > 0:
            self.max_num_blocks_per_seq = block_tables.shape[1]
        else:
            self.max_num_blocks_per_seq = 0
        # Internal invariant: one block-table row per generation token.
        assert block_tables.shape[0] == self.num_generation_tokens

        # Set during the execution of the first attention op.
        self.attn_bias: Optional[AttentionBias] = None

    def __repr__(self) -> str:
        # Print only useful metadata.
        return (
            f'InputMetadata('
            f'num_prompt_tokens={self.num_prompt_tokens}, '
            f'num_prompts={self.num_prompts}, '
            f'prompt_lens={self.prompt_lens}, '
            f'num_generation_tokens={self.num_generation_tokens}, '
            f'context_lens={self.context_lens}, '
            f'max_context_len={self.max_context_len}, '
            f'max_num_blocks_per_seq={self.max_num_blocks_per_seq}, '
            f'block_tables={self.block_tables}, '
            f'selected_token_indices={self.selected_token_indices}, '
            f'categorized_sample_indices={self.categorized_sample_indices}, '
            f'slot_mapping={self.slot_mapping})')