states.py 5.42 KB
Newer Older
Woosuk Kwon's avatar
Woosuk Kwon committed
1
2
3
4
5
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import numpy as np
import torch

6
from vllm.v1.worker.gpu.buffer_utils import StagedWriteTensor, UvaBackedTensor
Woosuk Kwon's avatar
Woosuk Kwon committed
7
8
9
10
11
12
13
14


class RequestState:
    """Slot-indexed bookkeeping for the requests currently being tracked.

    Every tracked request is assigned a slot index in ``[0, max_num_reqs)``.
    The per-request quantities (prompt/prefill/total lengths, token ids,
    computed-token counters, last sampled token, draft tokens) are stored in
    arrays/tensors whose first dimension is indexed by that slot.

    Writes to the ``StagedWriteTensor`` / ``UvaBackedTensor`` members made by
    ``add_request`` are only staged (or written to the ``.np`` view); call
    ``apply_staged_writes`` to flush them.
    """

    def __init__(
        self,
        max_num_reqs: int,
        max_model_len: int,
        max_num_batched_tokens: int,
        num_speculative_steps: int,
        vocab_size: int,
        device: torch.device,
    ):
        """Allocate the per-request state tables.

        Args:
            max_num_reqs: Number of request slots (first dimension of every
                per-request buffer).
            max_model_len: Second dimension of the ``all_token_ids`` table.
            max_num_batched_tokens: Stored on the instance; not otherwise
                used in this class.
            num_speculative_steps: Second dimension of ``draft_tokens``.
            vocab_size: Stored on the instance; not otherwise used in this
                class.
            device: Device on which the torch tensors are allocated.
        """
        self.max_num_reqs = max_num_reqs
        self.max_model_len = max_model_len
        self.max_num_batched_tokens = max_num_batched_tokens
        self.num_speculative_steps = num_speculative_steps
        self.vocab_size = vocab_size
        self.device = device

        # Bidirectional mapping between request ids and slot indices, plus
        # the pool of slots that are currently unassigned.
        self.req_id_to_index: dict[str, int] = {}
        self.index_to_req_id: dict[int, str] = {}
        self.free_indices = list(range(max_num_reqs))

        # NOTE(woosuk): This tensor can be extremely large (e.g., several GBs)
        # depending on the configured max_num_reqs and max_model_len.
        # To save GPU memory, we use UVA instead of GPU for this tensor.
        self.all_token_ids = StagedWriteTensor(
            (self.max_num_reqs, self.max_model_len),
            dtype=torch.int32,
            device=device,
            uva_instead_of_gpu=True,
        )
        # NOTE(woosuk): Distinguish clearly between prompt_len and prefill_len:
        # - prompt_len: Number of tokens in the user-provided prompt.
        # - prefill_len: Number of tokens passed into the model runner.
        #   This can include the prompt and additional partial output tokens,
        #   so prefill_len >= prompt_len.
        # Usually, prefill_len equals prompt_len, but in cases such as resumption after
        # preemption, prefill_len may be greater. Differentiating between these values
        # is crucial, as certain features such as prompt logprobs or frequency penalties
        # must treat prompt and output tokens separately.
        self.prompt_len = UvaBackedTensor(self.max_num_reqs, dtype=torch.int32)
        self.prefill_len = UvaBackedTensor(self.max_num_reqs, dtype=torch.int32)
        # total_len = prompt_len + output_len. It grows as the request progresses.
        self.total_len = StagedWriteTensor(
            self.max_num_reqs, dtype=torch.int32, device=device
        )

        # Number of computed tokens.
        self.num_computed_prefill_tokens = np.zeros(self.max_num_reqs, dtype=np.int32)
        self.num_computed_tokens = StagedWriteTensor(
            self.max_num_reqs, dtype=torch.int32, device=device
        )

        # Last sampled tokens.
        self.last_sampled_tokens = torch.zeros(
            self.max_num_reqs, 1, dtype=torch.int64, device=device
        )

        # Draft tokens.
        self.draft_tokens = torch.zeros(
            self.max_num_reqs,
            self.num_speculative_steps,
            dtype=torch.int64,
            device=device,
        )

        self.next_prefill_tokens = torch.zeros(
            self.max_num_reqs, dtype=torch.int32, device=device
        )

    @property
    def num_reqs(self) -> int:
        """Number of requests that currently hold a slot."""
        return len(self.req_id_to_index)

    def add_request(
        self,
        req_id: str,
        prompt_len: int,
        all_token_ids: list[int],
        num_computed_tokens: int,
    ) -> None:
        """Assign a free slot to ``req_id`` and record its initial state.

        Args:
            req_id: Identifier of the request to register.
            prompt_len: Number of tokens in the user-provided prompt.
            all_token_ids: Tokens known for the request so far; its length is
                recorded as the request's prefill length.
            num_computed_tokens: Number of tokens already computed for the
                request.

        Raises:
            AssertionError: If no slot is free, or if
                ``len(all_token_ids) < prompt_len``. Both checks run before
                any state is modified, so a failed call leaves the object
                untouched.
        """
        # Validate everything up front so that a rejected request neither
        # consumes a slot nor leaves a half-registered id mapping behind.
        assert len(self.free_indices) > 0, "No free indices"
        prefill_len = len(all_token_ids)
        assert prefill_len >= prompt_len, (
            f"prefill_len {prefill_len} < prompt_len {prompt_len}"
        )

        # NOTE(review): a req_id that is already registered is not rejected
        # here; its mapping would be overwritten. Presumably callers remove a
        # request before re-adding it -- confirm against the call sites.
        req_idx = self.free_indices.pop()
        self.req_id_to_index[req_id] = req_idx
        self.index_to_req_id[req_idx] = req_id

        self.prompt_len.np[req_idx] = prompt_len
        self.prefill_len.np[req_idx] = prefill_len
        self.total_len.stage_write_elem(req_idx, prefill_len)
        self.all_token_ids.stage_write(req_idx, 0, all_token_ids)
        self.num_computed_prefill_tokens[req_idx] = num_computed_tokens
        self.num_computed_tokens.stage_write_elem(req_idx, num_computed_tokens)

        if 0 < num_computed_tokens <= prefill_len:
            # For PD disagg or resumed requests: set last_sampled to the last
            # computed token so the first decode step gets the right input_id.
            # For fresh prefill requests (num_computed_tokens == 0) the tensor
            # is not read by combine_sampled_and_draft_tokens so we skip the
            # write. Use a slice assignment rather than scalar indexing so the
            # write is dispatched through fill_ without a host/device sync.
            self.last_sampled_tokens[req_idx : req_idx + 1] = all_token_ids[
                num_computed_tokens - 1
            ]
        # Clear draft tokens left in this slot by a previous occupant.
        self.draft_tokens[req_idx].zero_()

    def apply_staged_writes(self) -> None:
        """Flush the writes staged by ``add_request`` to their backing tensors."""
        self.prompt_len.copy_to_uva()
        self.prefill_len.copy_to_uva()
        self.total_len.apply_write()
        self.all_token_ids.apply_write()
        self.num_computed_tokens.apply_write()

    def remove_request(self, req_id: str) -> bool:
        """Release the slot held by ``req_id``.

        Returns:
            ``True`` if the request was registered and its slot was returned
            to the free pool, ``False`` if ``req_id`` was unknown.
        """
        req_idx = self.req_id_to_index.pop(req_id, None)
        if req_idx is None:
            # Request not found.
            return False
        self.index_to_req_id.pop(req_idx, None)
        self.free_indices.append(req_idx)
        return True

    def any_prefills(self, idx_mapping_np: np.ndarray) -> bool:
        """Report whether any selected request still has prefill tokens left.

        Args:
            idx_mapping_np: Slot indices of the requests to inspect.

        Returns:
            ``True`` if, for at least one selected slot,
            ``num_computed_prefill_tokens`` is smaller than ``prefill_len``.
        """
        # np.any yields a numpy.bool_; convert so the result honours the
        # declared ``bool`` return type (e.g. for ``is True`` checks).
        return bool(
            np.any(
                self.num_computed_prefill_tokens[idx_mapping_np]
                < self.prefill_len.np[idx_mapping_np]
            )
        )