"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

"""Radix attention."""

import torch
from flashinfer.cascade import merge_state
from torch import nn

from sglang.global_config import global_config
from sglang.srt.layers.extend_attention import extend_attention_fwd
from sglang.srt.layers.token_attention import token_attention_fwd
from sglang.srt.model_executor.model_runner import (
    ForwardMode,
    InputMetadata,
    global_server_args_dict,
)


class RadixAttention(nn.Module):
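    """The attention layer used with SGLang's radix-tree KV cache (RadixAttention).

    It writes new key/value entries into the shared token-to-KV pool and
    dispatches the attention computation to FlashInfer kernels (the default)
    or Triton kernels (when FlashInfer is disabled), with separate paths for
    the extend (prefill) and decode forward modes.
    """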
    def __init__(
        self,
        num_heads: int,
        head_dim: int,
        scaling: float,
        num_kv_heads: int,
        layer_id: int,
        logit_cap: float = -1,
    ):
        super().__init__()
        self.tp_q_head_num = num_heads
        self.tp_k_head_num = num_kv_heads
        self.tp_v_head_num = num_kv_heads
        self.head_dim = head_dim
        self.scaling = scaling
        self.layer_id = layer_id

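        # Pick the kernel backends: FlashInfer wrappers by default, Triton
        # kernels when FlashInfer is disabled via the server arguments.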
        if not global_server_args_dict.get("disable_flashinfer", False):
            self.extend_forward = self.extend_forward_flashinfer
            self.decode_forward = self.decode_forward_flashinfer
        else:
            self.extend_forward = self.extend_forward_triton
            self.decode_forward = self.decode_forward_triton

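        # A missing or non-positive logit_cap disables logit soft-capping.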
        self.logit_cap = logit_cap if logit_cap is not None and logit_cap > 0 else 0

    def extend_forward_triton(self, q, k, v, input_metadata: InputMetadata):
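        """Extend (prefill) attention using the Triton kernels.

        The new k/v are first written to the KV cache, then the fused extend
        kernel attends over both the cached prefix and the extended tokens.
        """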
        o = torch.empty_like(q)
        self.store_kv_cache(k, v, input_metadata)
        extend_attention_fwd(
            q.view(-1, self.tp_q_head_num, self.head_dim),
            k.contiguous(),
            v.contiguous(),
            o.view(-1, self.tp_q_head_num, self.head_dim),
            input_metadata.token_to_kv_pool.get_key_buffer(self.layer_id),
            input_metadata.token_to_kv_pool.get_value_buffer(self.layer_id),
            input_metadata.req_to_token_pool.req_to_token,
            input_metadata.req_pool_indices,
            input_metadata.triton_start_loc,
            input_metadata.seq_lens,
            input_metadata.triton_prefix_lens,
            input_metadata.extend_start_loc,
            input_metadata.extend_seq_lens,
            input_metadata.triton_max_seq_len,
            input_metadata.triton_max_extend_len,
            sm_scale=self.scaling,
            logit_cap=self.logit_cap,
        )

        return o

    def decode_forward_triton(self, q, k, v, input_metadata: InputMetadata):
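        """Decode (one new token per request) attention using the Triton kernel.

        The new k/v are appended to the KV cache before attending over all
        cached tokens of each request.
        """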
        o = torch.empty_like(q)
        self.store_kv_cache(k, v, input_metadata)

        token_attention_fwd(
            q.view(-1, self.tp_q_head_num, self.head_dim),
            input_metadata.token_to_kv_pool.get_key_buffer(self.layer_id),
            input_metadata.token_to_kv_pool.get_value_buffer(self.layer_id),
            o.view(-1, self.tp_q_head_num, self.head_dim),
            input_metadata.req_to_token_pool.req_to_token,
            input_metadata.req_pool_indices,
            input_metadata.triton_start_loc,
            input_metadata.seq_lens,
            input_metadata.triton_max_seq_len,
            input_metadata.total_num_tokens,
            sm_scale=self.scaling,
            logit_cap=self.logit_cap,
        )

        return o

    def extend_forward_flashinfer(self, q, k, v, input_metadata: InputMetadata):
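        """Extend (prefill) attention using FlashInfer.

        On the paged path, k/v are cached first and a single causal paged
        prefill call produces the output. On the ragged path, attention over
        the new tokens and over the cached prefix are computed separately and
        merged afterwards.
        """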
        if not input_metadata.flashinfer_use_ragged:
            self.store_kv_cache(k, v, input_metadata)

            o = input_metadata.flashinfer_prefill_wrapper_paged.forward(
                q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
                input_metadata.token_to_kv_pool.get_kv_buffer(self.layer_id),
                causal=True,
                sm_scale=self.scaling,
                logits_soft_cap=self.logit_cap,
            )
        else:
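            # Ragged path: attend over the newly extended tokens with the
            # ragged kernel; the cached prefix (if any) is handled by the
            # paged kernel below, and the two partial results are merged.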
            o1, s1 = (
                input_metadata.flashinfer_prefill_wrapper_ragged.forward_return_lse(
                    q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
                    k.contiguous().view(-1, self.tp_k_head_num, self.head_dim),
                    v.contiguous().view(-1, self.tp_v_head_num, self.head_dim),
                    causal=True,
                    sm_scale=self.scaling,
                    logits_soft_cap=self.logit_cap,
                )
            )

            if input_metadata.extend_no_prefix:
                o = o1
            else:
                o2, s2 = (
                    input_metadata.flashinfer_prefill_wrapper_paged.forward_return_lse(
                        q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
                        input_metadata.token_to_kv_pool.get_kv_buffer(self.layer_id),
                        causal=False,
                        sm_scale=self.scaling,
                        logits_soft_cap=self.logit_cap,
                    )
                )

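                # Merge the two partial attention results using their
                # log-sum-exp (LSE) states.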
                o, _ = merge_state(o1, s1, o2, s2)

            self.store_kv_cache(k, v, input_metadata)

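            # For very large batches, synchronize after this layer (threshold
            # set in global_config), presumably to keep kernel launches from
            # running too far ahead of execution.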
            if input_metadata.total_num_tokens >= global_config.layer_sync_threshold:
                torch.cuda.synchronize()

        return o.view(-1, self.tp_q_head_num * self.head_dim)

    def decode_forward_flashinfer(self, q, k, v, input_metadata: InputMetadata):
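        """Decode attention using the FlashInfer decode wrapper over the paged KV cache."""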
        self.store_kv_cache(k, v, input_metadata)

        o = input_metadata.flashinfer_decode_wrapper.forward(
            q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
            input_metadata.token_to_kv_pool.get_kv_buffer(self.layer_id),
            sm_scale=self.scaling,
            logits_soft_cap=self.logit_cap,
        )

        return o.view(-1, self.tp_q_head_num * self.head_dim)

    def forward(self, q, k, v, input_metadata: InputMetadata):
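        """Reshape k/v to per-head layout and dispatch by forward mode (extend or decode)."""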
        k = k.view(-1, self.tp_k_head_num, self.head_dim)
        v = v.view(-1, self.tp_v_head_num, self.head_dim)

        if input_metadata.forward_mode == ForwardMode.EXTEND:
            return self.extend_forward(q, k, v, input_metadata)
        elif input_metadata.forward_mode == ForwardMode.DECODE:
            return self.decode_forward(q, k, v, input_metadata)

    def store_kv_cache(self, cache_k, cache_v, input_metadata: InputMetadata):
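        """Scatter the new key/value tensors into this layer's KV cache buffers
        at the slots given by `input_metadata.out_cache_loc`."""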
        k_cache = input_metadata.token_to_kv_pool.get_key_buffer(self.layer_id)
        v_cache = input_metadata.token_to_kv_pool.get_value_buffer(self.layer_id)
        k_cache[input_metadata.out_cache_loc] = cache_k
        v_cache[input_metadata.out_cache_loc] = cache_v