# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import numpy as np
import torch

from vllm.v1.worker.gpu.buffer_utils import StagedWriteTensor, UvaBackedTensor


class RequestState:
    """Per-request state stored in a fixed pool of ``max_num_reqs`` slots.

    Each active request owns one slot index in ``[0, max_num_reqs)``. Every
    per-request buffer below is allocated once with ``max_num_reqs`` rows and
    is indexed by that slot. ``add_request`` hands out slots and
    ``remove_request`` returns them.

    ``add_request`` stages its writes. Call ``apply_staged_writes`` to flush
    them; it calls ``copy_to_uva()`` / ``apply_write()`` on each buffer.
    """

    def __init__(
        self,
        max_num_reqs: int,
        max_model_len: int,
        max_num_batched_tokens: int,
        num_speculative_steps: int,
        vocab_size: int,
        device: torch.device,
        model_dtype: torch.dtype,
        cache_draft_logits: bool,
    ):
        """Allocate all per-slot buffers.

        Args:
            max_num_reqs: Number of request slots (rows in every buffer).
            max_model_len: Maximum number of tokens stored per request
                (columns of ``all_token_ids``).
            max_num_batched_tokens: Stored on the instance; not read
                elsewhere in this class.
            num_speculative_steps: Width of ``draft_tokens`` and
                ``draft_logits``.
            vocab_size: Last dimension of ``draft_logits``.
            device: Device for the device-side buffers.
            model_dtype: Dtype of ``draft_logits``.
            cache_draft_logits: If True, allocate ``draft_logits``;
                otherwise it stays ``None``.
        """
        self.max_num_reqs = max_num_reqs
        self.max_model_len = max_model_len
        self.max_num_batched_tokens = max_num_batched_tokens
        self.num_speculative_steps = num_speculative_steps
        self.vocab_size = vocab_size
        self.device = device

        # Slot bookkeeping: req_id <-> slot index, plus the pool of free slots.
        # add_request pops from the end and remove_request appends, so slots
        # are reused in LIFO order.
        self.req_id_to_index: dict[str, int] = {}
        self.index_to_req_id: dict[int, str] = {}
        self.free_indices = list(range(max_num_reqs))

        # NOTE(woosuk): This tensor can be extremely large (e.g., several GBs)
        # depending on the configured max_num_reqs and max_model_len.
        # To save GPU memory, we use UVA instead of GPU for this tensor.
        self.all_token_ids = StagedWriteTensor(
            (self.max_num_reqs, self.max_model_len),
            dtype=torch.int32,
            device=device,
            uva_instead_of_gpu=True,
        )
        # NOTE(woosuk): Distinguish clearly between prompt_len and prefill_len:
        # - prompt_len: Number of tokens in the user-provided prompt.
        # - prefill_len: Number of tokens passed into the model runner.
        #   This can include the prompt and additional partial output tokens,
        #   so prefill_len >= prompt_len.
        # Usually, prefill_len equals prompt_len, but in cases such as resumption after
        # preemption, prefill_len may be greater. Differentiating between these values
        # is crucial, as certain features such as prompt logprobs or frequency penalties
        # must treat prompt and output tokens separately.
        self.prompt_len = UvaBackedTensor(self.max_num_reqs, dtype=torch.int32)
        self.prefill_len = UvaBackedTensor(self.max_num_reqs, dtype=torch.int32)
        # total_len = prompt_len + output_len. It grows as the request progresses.
        self.total_len = StagedWriteTensor(
            self.max_num_reqs, dtype=torch.int32, device=device
        )

        # Number of computed tokens.
        # num_computed_prefill_tokens is a host-side numpy array used by
        # any_prefills(). num_computed_tokens is the staged device-visible
        # counterpart.
        self.num_computed_prefill_tokens = np.zeros(self.max_num_reqs, dtype=np.int32)
        self.num_computed_tokens = StagedWriteTensor(
            self.max_num_reqs, dtype=torch.int32, device=device
        )

        # Last sampled tokens, shape (max_num_reqs, 1).
        self.last_sampled_tokens = torch.zeros(
            self.max_num_reqs, 1, dtype=torch.int64, device=device
        )

        # Draft tokens, shape (max_num_reqs, num_speculative_steps).
        self.draft_tokens = torch.zeros(
            self.max_num_reqs,
            self.num_speculative_steps,
            dtype=torch.int64,
            device=device,
        )

        # Draft token logits, shape
        # (max_num_reqs, num_speculative_steps, vocab_size); only allocated
        # when cache_draft_logits is set.
        # NOTE: This tensor maintains the "processed" logits after applying temperature,
        # top-p, etc.
        self.draft_logits: torch.Tensor | None = None
        if cache_draft_logits:
            self.draft_logits = torch.zeros(
                self.max_num_reqs,
                self.num_speculative_steps,
                self.vocab_size,
                dtype=model_dtype,
                device=device,
            )

        # NOTE(review): per-slot int32 buffer; nothing in this class writes
        # it, so how it is populated must be checked at the call sites.
        self.next_prefill_tokens = torch.zeros(
            self.max_num_reqs, dtype=torch.int32, device=device
        )

    @property
    def num_reqs(self) -> int:
        """Number of requests currently holding a slot."""
        return len(self.req_id_to_index)

    def add_request(
        self,
        req_id: str,
        prompt_len: int,
        all_token_ids: list[int],
        num_computed_tokens: int,
    ) -> None:
        """Allocate a slot for ``req_id`` and stage its initial state.

        All inputs are validated before a slot is taken, so a rejected
        request leaves the slot pool and the id mappings untouched.

        Args:
            req_id: Request identifier.
            prompt_len: Number of tokens in the user-provided prompt.
            all_token_ids: Tokens passed to the model for this request's
                prefill (the prompt plus any partial output tokens). Its
                length becomes ``prefill_len``.
            num_computed_tokens: Number of those tokens already computed.

        Raises:
            AssertionError: If no slot is free, if ``len(all_token_ids)`` is
                smaller than ``prompt_len``, or if it exceeds
                ``max_model_len`` (the width of an ``all_token_ids`` row).
        """
        assert self.free_indices, "No free indices"
        prefill_len = len(all_token_ids)
        assert prefill_len >= prompt_len, (
            f"prefill_len {prefill_len} < prompt_len {prompt_len}"
        )
        assert prefill_len <= self.max_model_len, (
            f"prefill_len {prefill_len} > max_model_len {self.max_model_len}"
        )

        # NOTE(review): re-adding a req_id that is still registered would
        # orphan its previous slot; callers presumably remove it first.
        req_idx = self.free_indices.pop()
        self.req_id_to_index[req_id] = req_idx
        self.index_to_req_id[req_idx] = req_id

        self.prompt_len.np[req_idx] = prompt_len
        self.prefill_len.np[req_idx] = prefill_len
        self.total_len.stage_write_elem(req_idx, prefill_len)
        self.all_token_ids.stage_write(req_idx, 0, all_token_ids)
        self.num_computed_prefill_tokens[req_idx] = num_computed_tokens
        self.num_computed_tokens.stage_write_elem(req_idx, num_computed_tokens)

    def apply_staged_writes(self) -> None:
        """Flush the writes that ``add_request`` staged to their backing buffers."""
        self.prompt_len.copy_to_uva()
        self.prefill_len.copy_to_uva()
        self.total_len.apply_write()
        self.all_token_ids.apply_write()
        self.num_computed_tokens.apply_write()

    def remove_request(self, req_id: str) -> None:
        """Release the slot held by ``req_id``; do nothing if it is unknown.

        Only the id mappings and the free-slot pool are updated. Buffer
        contents for the slot are not cleared.
        """
        req_idx = self.req_id_to_index.pop(req_id, None)
        if req_idx is None:
            # Request not found.
            return
        self.index_to_req_id.pop(req_idx, None)
        self.free_indices.append(req_idx)

    def any_prefills(self, idx_mapping_np: np.ndarray) -> bool:
        """Return True if any listed slot has prefill tokens left to compute.

        Args:
            idx_mapping_np: Slot indices to check (used for numpy fancy
                indexing).

        Within this class, ``num_computed_prefill_tokens`` is only written
        by ``add_request``. It is presumably advanced by the caller as
        prefill progresses — verify at the call sites.
        """
        # np.any returns numpy.bool_; convert so the result is a real bool
        # as annotated (identity checks such as `is True` then work).
        return bool(
            np.any(
                self.num_computed_prefill_tokens[idx_mapping_np]
                < self.prefill_len.np[idx_mapping_np]
            )
        )