Unverified Commit fec185ce authored by Lianmin Zheng, committed by GitHub

Refactor attention backend (#1381)

parent c03cece4
from __future__ import annotations

"""
Support different attention backends.
Now there are two backends: FlashInfer and Triton.
FlashInfer is faster and Triton is easier to customize.
Each backend supports two operators: extend (i.e. prefill with cached prefix) and decode.
"""

from abc import ABC, abstractmethod
from typing import TYPE_CHECKING

import torch
import torch.nn as nn
from flashinfer import (
    BatchDecodeWithPagedKVCacheWrapper,
    BatchPrefillWithPagedKVCacheWrapper,
    BatchPrefillWithRaggedKVCacheWrapper,
)
from flashinfer.cascade import merge_state
from flashinfer.decode import _grouped_size_compiled_for_decode_kernels

from sglang.global_config import global_config
from sglang.srt.layers.flashinfer_utils import update_flashinfer_indices
from sglang.srt.managers.schedule_batch import ScheduleBatch
from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata

if TYPE_CHECKING:
    from sglang.srt.model_executor.model_runner import ModelRunner


class AttentionBackend(ABC):
    """The base class of attention backends"""

    @abstractmethod
    def init_forward_metadata(
        self, batch: ScheduleBatch, input_metadata: InputMetadata
    ):
        pass

    def forward(self, q, k, v, layer, input_metadata: InputMetadata):
        if input_metadata.forward_mode.is_decode():
            return self.forward_decode(q, k, v, layer, input_metadata)
        else:
            return self.forward_extend(q, k, v, layer, input_metadata)
class FlashInferAttnBackend(AttentionBackend):
    """Flashinfer attention kernels."""

    def __init__(self, model_runner: ModelRunner):
        super().__init__()
        self.model_runner = model_runner

        if not _grouped_size_compiled_for_decode_kernels(
            model_runner.model_config.num_attention_heads // model_runner.tp_size,
            model_runner.model_config.get_num_kv_heads(model_runner.tp_size),
        ):
            self.decode_use_tensor_cores = True
        else:
            self.decode_use_tensor_cores = False

        self.workspace_buffer = torch.empty(
            global_config.flashinfer_workspace_size,
            dtype=torch.uint8,
            device="cuda",
        )

        if model_runner.sliding_window_size is None:
            self.prefill_wrapper_ragged = BatchPrefillWithRaggedKVCacheWrapper(
                self.workspace_buffer, "NHD"
            )
            self.prefill_wrapper_paged = BatchPrefillWithPagedKVCacheWrapper(
                self.workspace_buffer, "NHD"
            )
            self.decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
                self.workspace_buffer,
                "NHD",
                use_tensor_cores=self.decode_use_tensor_cores,
            )
        else:
            # Two wrappers: one for sliding window attention and one for full attention.
            # Using two wrappers is unnecessary in the current PR, but they are prepared for future PRs.
            self.prefill_wrapper_ragged = None
            self.prefill_wrapper_paged = []
            self.decode_wrapper = []
            for _ in range(2):
                self.prefill_wrapper_paged.append(
                    BatchPrefillWithPagedKVCacheWrapper(self.workspace_buffer, "NHD")
                )
                self.decode_wrapper.append(
                    BatchDecodeWithPagedKVCacheWrapper(
                        self.workspace_buffer,
                        "NHD",
                        use_tensor_cores=self.decode_use_tensor_cores,
                    )
                )

        self.forward_metadata = None
        self.cuda_graph_metadata = {}

    def init_forward_metadata(
        self, batch: ScheduleBatch, input_metadata: InputMetadata
    ):
        if input_metadata.forward_mode.is_decode():
            prefix_lens = None
            use_ragged = False
            total_num_tokens = None
        else:
            prefix_lens = input_metadata.extend_prefix_lens

            # A simple heuristic to decide whether to use the ragged forward kernel
            use_ragged = False
            if (
                int(torch.sum(input_metadata.seq_lens)) > 4096
                and self.model_runner.sliding_window_size is None
            ):
                use_ragged = True

            total_num_tokens = torch.sum(input_metadata.seq_lens).item()

        update_flashinfer_indices(
            input_metadata.forward_mode,
            self.model_runner,
            input_metadata.req_pool_indices,
            input_metadata.seq_lens,
            prefix_lens,
            use_ragged=use_ragged,
        )

        self.forward_metadata = (use_ragged, total_num_tokens, self.decode_wrapper)
    def init_cuda_graph_state(self, max_bs: int):
        self.cuda_graph_kv_indptr = torch.zeros(
            (max_bs + 1,), dtype=torch.int32, device="cuda"
        )
        self.cuda_graph_kv_indices = torch.zeros(
            (max_bs * self.model_runner.model_config.context_len,),
            dtype=torch.int32,
            device="cuda",
        )
        self.cuda_graph_kv_last_page_len = torch.ones(
            (max_bs,), dtype=torch.int32, device="cuda"
        )

        if self.model_runner.sliding_window_size is not None:
            self.cuda_graph_kv_indptr = [
                self.cuda_graph_kv_indptr,
                self.cuda_graph_kv_indptr.clone(),
            ]
            self.cuda_graph_kv_indices = [
                self.cuda_graph_kv_indices,
                self.cuda_graph_kv_indices.clone(),
            ]

    def capture_cuda_graph_init(self, bs: int, req_pool_indices, seq_lens):
        if self.model_runner.sliding_window_size is None:
            decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
                self.workspace_buffer,
                "NHD",
                use_cuda_graph=True,
                use_tensor_cores=self.decode_use_tensor_cores,
                paged_kv_indptr_buffer=self.cuda_graph_kv_indptr[: bs + 1],
                paged_kv_indices_buffer=self.cuda_graph_kv_indices,
                paged_kv_last_page_len_buffer=self.cuda_graph_kv_last_page_len[:bs],
            )
        else:
            decode_wrapper = []
            for i in range(2):
                decode_wrapper.append(
                    BatchDecodeWithPagedKVCacheWrapper(
                        self.workspace_buffer,
                        "NHD",
                        use_cuda_graph=True,
                        use_tensor_cores=self.decode_use_tensor_cores,
                        paged_kv_indptr_buffer=self.cuda_graph_kv_indptr[i][: bs + 1],
                        paged_kv_indices_buffer=self.cuda_graph_kv_indices[i],
                        paged_kv_last_page_len_buffer=self.cuda_graph_kv_last_page_len[
                            :bs
                        ],
                    )
                )

        update_flashinfer_indices(
            ForwardMode.DECODE,
            self.model_runner,
            req_pool_indices,
            seq_lens,
            None,
            decode_wrapper,
        )

        self.cuda_graph_metadata[bs] = decode_wrapper
        self.forward_metadata = (False, None, decode_wrapper)

    def replay_cuda_graph_init(self, bs: int, req_pool_indices, seq_lens):
        update_flashinfer_indices(
            ForwardMode.DECODE,
            self.model_runner,
            req_pool_indices[:bs],
            seq_lens[:bs],
            None,
            self.cuda_graph_metadata[bs],
        )
    def forward_extend(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata):
        if not isinstance(self.prefill_wrapper_paged, list):
            prefill_wrapper_paged = self.prefill_wrapper_paged
        else:
            if layer.sliding_window_size != -1:
                prefill_wrapper_paged = self.prefill_wrapper_paged[0]
            else:
                prefill_wrapper_paged = self.prefill_wrapper_paged[1]

        use_ragged, total_num_tokens, decode_wrapper = self.forward_metadata

        if not use_ragged:
            if k is not None:
                assert v is not None
                input_metadata.token_to_kv_pool.set_kv_buffer(
                    layer.layer_id, input_metadata.out_cache_loc, k, v
                )
            o = prefill_wrapper_paged.forward(
                q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
                input_metadata.token_to_kv_pool.get_kv_buffer(layer.layer_id),
                causal=True,
                sm_scale=layer.scaling,
                window_left=layer.sliding_window_size,
                logits_soft_cap=layer.logit_cap,
            )
        else:
            o1, s1 = self.prefill_wrapper_ragged.forward_return_lse(
                q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
                k.contiguous().view(-1, layer.tp_k_head_num, layer.head_dim),
                v.contiguous().view(-1, layer.tp_v_head_num, layer.head_dim),
                causal=True,
                sm_scale=layer.scaling,
                logits_soft_cap=layer.logit_cap,
            )

            if input_metadata.extend_no_prefix:
                o = o1
            else:
                o2, s2 = prefill_wrapper_paged.forward_return_lse(
                    q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
                    input_metadata.token_to_kv_pool.get_kv_buffer(layer.layer_id),
                    causal=False,
                    sm_scale=layer.scaling,
                    logits_soft_cap=layer.logit_cap,
                )

                o, _ = merge_state(o1, s1, o2, s2)

            input_metadata.token_to_kv_pool.set_kv_buffer(
                layer.layer_id, input_metadata.out_cache_loc, k, v
            )

            if total_num_tokens >= global_config.layer_sync_threshold:
                torch.cuda.synchronize()

        return o.view(-1, layer.tp_q_head_num * layer.head_dim)

    def forward_decode(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata):
        use_ragged, total_num_tokens, decode_wrapper = self.forward_metadata

        if isinstance(decode_wrapper, list):
            if layer.sliding_window_size != -1:
                decode_wrapper = decode_wrapper[0]
            else:
                decode_wrapper = decode_wrapper[1]

        if k is not None:
            assert v is not None
            input_metadata.token_to_kv_pool.set_kv_buffer(
                layer.layer_id, input_metadata.out_cache_loc, k, v
            )

        o = decode_wrapper.forward(
            q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
            input_metadata.token_to_kv_pool.get_kv_buffer(layer.layer_id),
            sm_scale=layer.scaling,
            logits_soft_cap=layer.logit_cap,
        )

        return o.view(-1, layer.tp_q_head_num * layer.head_dim)
class TritonAttnBackend(AttentionBackend):
    def __init__(self, model_runner: ModelRunner):
        # Lazy import to avoid the initialization of cuda context
        from sglang.srt.layers.triton_attention.decode_attention import (
            decode_attention_fwd,
        )
        from sglang.srt.layers.triton_attention.extend_attention import (
            extend_attention_fwd,
        )

        super().__init__()

        self.decode_attention_fwd = decode_attention_fwd
        self.extend_attention_fwd = extend_attention_fwd

        self.forward_metadata = None

    def init_forward_metadata(
        self, batch: ScheduleBatch, input_metadata: InputMetadata
    ):
        """Init auxiliary variables for triton attention backend."""

        if input_metadata.forward_mode.is_decode():
            max_seq_len = torch.max(input_metadata.seq_lens).item()
            start_loc = torch.zeros_like(input_metadata.seq_lens, dtype=torch.int32)
            start_loc[1:] = torch.cumsum(input_metadata.seq_lens[:-1], dim=0)
            total_num_tokens = torch.sum(input_metadata.seq_lens).item()
            max_extend_len = None
        else:
            start_loc = max_seq_len = total_num_tokens = None
            prefix_lens = torch.tensor(batch.prefix_lens_cpu, device="cuda")
            max_extend_len = torch.max(input_metadata.seq_lens - prefix_lens).item()

        self.forward_metadata = start_loc, max_seq_len, max_extend_len, total_num_tokens

    def forward_extend(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata):
        if layer.qk_head_dim != layer.v_head_dim:
            o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
        else:
            o = torch.empty_like(q)

        input_metadata.token_to_kv_pool.set_kv_buffer(
            layer.layer_id, input_metadata.out_cache_loc, k, v
        )

        start_loc, max_seq_len, max_extend_len, total_num_tokens = self.forward_metadata
        self.extend_attention_fwd(
            q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
            k.contiguous(),
            v.contiguous(),
            o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
            input_metadata.token_to_kv_pool.get_key_buffer(layer.layer_id),
            input_metadata.token_to_kv_pool.get_value_buffer(layer.layer_id),
            input_metadata.req_to_token_pool.req_to_token,
            input_metadata.req_pool_indices,
            input_metadata.seq_lens,
            input_metadata.extend_seq_lens,
            input_metadata.extend_start_loc,
            max_extend_len,
            layer.scaling,
            layer.logit_cap,
        )
        return o

    def forward_decode(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata):
        if layer.qk_head_dim != layer.v_head_dim:
            o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
        else:
            o = torch.empty_like(q)

        start_loc, max_seq_len, max_extend_len, total_num_tokens = self.forward_metadata

        input_metadata.token_to_kv_pool.set_kv_buffer(
            layer.layer_id, input_metadata.out_cache_loc, k, v
        )

        self.decode_attention_fwd(
            q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
            input_metadata.token_to_kv_pool.get_key_buffer(layer.layer_id),
            input_metadata.token_to_kv_pool.get_value_buffer(layer.layer_id),
            o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
            input_metadata.req_to_token_pool.req_to_token,
            input_metadata.req_pool_indices,
            start_loc,
            input_metadata.seq_lens,
            max_seq_len,
            total_num_tokens,
            layer.scaling,
            layer.logit_cap,
        )
        return o
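
Reviewer note: the contract a backend must satisfy is intentionally small. A backend implements init_forward_metadata(), forward_extend(), and forward_decode(); the base class dispatches on the forward mode. The sketch below is illustrative only and is not part of this commit. It mirrors that contract with plain PyTorch scaled_dot_product_attention and assumes equal Q/KV head counts, a single contiguous sequence, and that K/V are passed in directly instead of being read from the paged KV pool.

import torch.nn.functional as F


class NaiveSDPABackend:
    """Hypothetical toy backend (illustration only, not part of this commit)."""

    def init_forward_metadata(self, batch, input_metadata):
        # Nothing to precompute for this toy backend.
        pass

    def forward(self, q, k, v, layer, input_metadata):
        # Same dispatch as AttentionBackend.forward above.
        if input_metadata.forward_mode.is_decode():
            return self.forward_decode(q, k, v, layer, input_metadata)
        else:
            return self.forward_extend(q, k, v, layer, input_metadata)

    def _sdpa(self, q, k, v, layer, causal):
        # [num_tokens, num_heads * head_dim] -> [num_heads, num_tokens, head_dim]
        q = q.view(-1, layer.tp_q_head_num, layer.head_dim).transpose(0, 1)
        k = k.view(-1, layer.tp_k_head_num, layer.head_dim).transpose(0, 1)
        v = v.view(-1, layer.tp_v_head_num, layer.head_dim).transpose(0, 1)
        # SDPA's default scale (1/sqrt(head_dim)) matches the usual layer.scaling.
        o = F.scaled_dot_product_attention(q, k, v, is_causal=causal)
        return o.transpose(0, 1).reshape(-1, layer.tp_q_head_num * layer.head_dim)

    def forward_extend(self, q, k, v, layer, input_metadata):
        return self._sdpa(q, k, v, layer, causal=True)

    def forward_decode(self, q, k, v, layer, input_metadata):
        # A single new query token attends to the whole sequence; no causal mask needed.
        return self._sdpa(q, k, v, layer, causal=False)
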
...@@ -10,8 +10,8 @@ def create_flashinfer_kv_indices_triton( ...@@ -10,8 +10,8 @@ def create_flashinfer_kv_indices_triton(
page_kernel_lens_ptr, page_kernel_lens_ptr,
kv_indptr, kv_indptr,
kv_start_idx, kv_start_idx,
max_context_len,
kv_indices_ptr, kv_indices_ptr,
max_context_len: tl.constexpr,
): ):
BLOCK_SIZE: tl.constexpr = 512 BLOCK_SIZE: tl.constexpr = 512
pid = tl.program_id(axis=0) pid = tl.program_id(axis=0)
...@@ -47,15 +47,15 @@ class FlashinferUpdater: ...@@ -47,15 +47,15 @@ class FlashinferUpdater:
req_pool_indices, req_pool_indices,
seq_lens, seq_lens,
prefix_lens, prefix_lens,
flashinfer_decode_wrapper=None, decode_wrapper=None,
flashinfer_use_ragged=False, use_ragged=False,
): ):
self.forward_mode = forward_mode self.forward_mode = forward_mode
self.model_runner = model_runner self.model_runner = model_runner
self.req_pool_indices = req_pool_indices self.req_pool_indices = req_pool_indices
self.seq_lens = seq_lens self.seq_lens = seq_lens
self.prefix_lens = prefix_lens self.prefix_lens = prefix_lens
self.flashinfer_use_ragged = flashinfer_use_ragged self.use_ragged = use_ragged
self.num_qo_heads = ( self.num_qo_heads = (
model_runner.model_config.num_attention_heads // model_runner.tp_size model_runner.model_config.num_attention_heads // model_runner.tp_size
...@@ -71,20 +71,17 @@ class FlashinferUpdater: ...@@ -71,20 +71,17 @@ class FlashinferUpdater:
) )
( (
self.flashinfer_decode_wrapper, self.decode_wrapper,
self.flashinfer_prefill_wrapper_ragged, self.prefill_wrapper_ragged,
self.flashinfer_prefill_wrapper_paged, self.prefill_wrapper_paged,
) = ( ) = (
flashinfer_decode_wrapper, decode_wrapper or self.model_runner.attn_backend.decode_wrapper,
self.model_runner.flashinfer_prefill_wrapper_ragged, self.model_runner.attn_backend.prefill_wrapper_ragged,
self.model_runner.flashinfer_prefill_wrapper_paged, self.model_runner.attn_backend.prefill_wrapper_paged,
) )
# CUDA graph uses different flashinfer_decode_wrapper
if self.flashinfer_decode_wrapper is None:
self.flashinfer_decode_wrapper = self.model_runner.flashinfer_decode_wrapper
def _init_indices_no_window(self): def _init_indices_no_sliding_window(self):
if self.flashinfer_use_ragged: if self.use_ragged:
paged_kernel_lens = self.prefix_lens paged_kernel_lens = self.prefix_lens
else: else:
paged_kernel_lens = self.seq_lens paged_kernel_lens = self.seq_lens
...@@ -103,13 +100,13 @@ class FlashinferUpdater: ...@@ -103,13 +100,13 @@ class FlashinferUpdater:
paged_kernel_lens, paged_kernel_lens,
self.kv_indptr, self.kv_indptr,
None, None,
self.model_runner.req_to_token_pool.req_to_token.size(1),
self.kv_indices, self.kv_indices,
self.model_runner.req_to_token_pool.req_to_token.size(1),
) )
def _init_indices_window(self, wrapper_id): def _init_indices_sliding_window(self, wrapper_id):
# window attention use paged only
if wrapper_id == 0: if wrapper_id == 0:
# window attention uses the paged wrapper only
if self.forward_mode.is_decode(): if self.forward_mode.is_decode():
paged_kernel_lens = torch.minimum( paged_kernel_lens = torch.minimum(
self.seq_lens, self.seq_lens,
...@@ -123,6 +120,7 @@ class FlashinferUpdater: ...@@ -123,6 +120,7 @@ class FlashinferUpdater:
- self.prefix_lens, - self.prefix_lens,
) )
else: else:
# full attention
paged_kernel_lens = self.seq_lens paged_kernel_lens = self.seq_lens
kv_start_idx = self.seq_lens - paged_kernel_lens kv_start_idx = self.seq_lens - paged_kernel_lens
...@@ -139,8 +137,8 @@ class FlashinferUpdater: ...@@ -139,8 +137,8 @@ class FlashinferUpdater:
paged_kernel_lens, paged_kernel_lens,
self.kv_indptr, self.kv_indptr,
kv_start_idx, kv_start_idx,
self.model_runner.req_to_token_pool.req_to_token.size(1),
self.kv_indices, self.kv_indices,
self.model_runner.req_to_token_pool.req_to_token.size(1),
) )
def _update_decode_indices(self, decode_wrapper): def _update_decode_indices(self, decode_wrapper):
...@@ -164,7 +162,7 @@ class FlashinferUpdater: ...@@ -164,7 +162,7 @@ class FlashinferUpdater:
) )
qo_indptr[1:] = torch.cumsum(self.seq_lens - self.prefix_lens, dim=0) qo_indptr[1:] = torch.cumsum(self.seq_lens - self.prefix_lens, dim=0)
if self.flashinfer_use_ragged: if self.use_ragged:
ragged_wrapper.end_forward() ragged_wrapper.end_forward()
ragged_wrapper.begin_forward( ragged_wrapper.begin_forward(
qo_indptr, qo_indptr,
...@@ -187,28 +185,28 @@ class FlashinferUpdater: ...@@ -187,28 +185,28 @@ class FlashinferUpdater:
1, 1,
) )
def update_indices_no_window(self): def update_indices_no_sliding_window(self):
self._init_indices_no_window() self._init_indices_no_sliding_window()
if self.forward_mode.is_decode(): if self.forward_mode.is_decode():
self._update_decode_indices(self.flashinfer_decode_wrapper) self._update_decode_indices(self.decode_wrapper)
else: else:
self._update_extend_indices( self._update_extend_indices(
self.flashinfer_prefill_wrapper_ragged, self.prefill_wrapper_ragged,
self.flashinfer_prefill_wrapper_paged, self.prefill_wrapper_paged,
) )
def update_indices_window(self): def update_indices_sliding_window(self):
assert self.flashinfer_use_ragged is False assert self.use_ragged is False
for wrapper_id in range(2): for wrapper_id in range(2):
self._init_indices_window(wrapper_id) self._init_indices_sliding_window(wrapper_id)
if self.forward_mode.is_decode(): if self.forward_mode.is_decode():
self._update_decode_indices(self.flashinfer_decode_wrapper[wrapper_id]) self._update_decode_indices(self.decode_wrapper[wrapper_id])
else: else:
self._update_extend_indices( self._update_extend_indices(
None, None,
self.flashinfer_prefill_wrapper_paged[wrapper_id], self.prefill_wrapper_paged[wrapper_id],
) )
...@@ -218,20 +216,20 @@ def update_flashinfer_indices( ...@@ -218,20 +216,20 @@ def update_flashinfer_indices(
req_pool_indices, req_pool_indices,
seq_lens, seq_lens,
prefix_lens, prefix_lens,
flashinfer_decode_wrapper=None, decode_wrapper=None,
flashinfer_use_ragged=False, use_ragged=False,
): ):
flashinfer_updater = FlashinferUpdater( updater = FlashinferUpdater(
forward_mode, forward_mode,
model_runner, model_runner,
req_pool_indices, req_pool_indices,
seq_lens, seq_lens,
prefix_lens, prefix_lens,
flashinfer_decode_wrapper, decode_wrapper,
flashinfer_use_ragged, use_ragged,
) )
if model_runner.sliding_window_size is None: if model_runner.sliding_window_size is None:
flashinfer_updater.update_indices_no_window() updater.update_indices_no_sliding_window()
else: else:
flashinfer_updater.update_indices_window() updater.update_indices_sliding_window()
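
The hunk above drops the flashinfer_ prefixes and lets the updater fall back to the attention backend's own decode wrapper. A sketch of a call with the renamed arguments is below; the variable values are placeholders and it assumes an already-initialized ModelRunner with a FlashInfer attention backend.

import torch
from sglang.srt.layers.flashinfer_utils import update_flashinfer_indices
from sglang.srt.model_executor.forward_batch_info import ForwardMode

req_pool_indices = torch.arange(4, dtype=torch.int32, device="cuda")
seq_lens = torch.full((4,), 128, dtype=torch.int32, device="cuda")

update_flashinfer_indices(
    ForwardMode.DECODE,
    model_runner,         # placeholder: an existing ModelRunner instance
    req_pool_indices,
    seq_lens,
    None,                 # prefix_lens is unused for decode
    decode_wrapper=None,  # None falls back to model_runner.attn_backend.decode_wrapper
    use_ragged=False,
)
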
...@@ -15,25 +15,14 @@ limitations under the License. ...@@ -15,25 +15,14 @@ limitations under the License.
"""Radix attention.""" """Radix attention."""
from typing import Optional
import torch
from flashinfer.cascade import merge_state
from torch import nn from torch import nn
from sglang.global_config import global_config from sglang.srt.model_executor.forward_batch_info import InputMetadata
from sglang.srt.layers.triton_attention.decode_attention import decode_attention_fwd
from sglang.srt.layers.triton_attention.extend_attention import extend_attention_fwd
from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
from sglang.srt.model_executor.model_runner import global_server_args_dict
class RadixAttention(nn.Module): class RadixAttention(nn.Module):
""" """
The attention layer implementation. The attention layer implementation.
Now it has two backends: FlashInfer and Triton.
FlashInfer is faster and Triton is easier to customize.
It supports two operators: extend (i.e. prefill with cached prefix) and decode.
""" """
def __init__( def __init__(
...@@ -43,8 +32,8 @@ class RadixAttention(nn.Module): ...@@ -43,8 +32,8 @@ class RadixAttention(nn.Module):
scaling: float, scaling: float,
num_kv_heads: int, num_kv_heads: int,
layer_id: int, layer_id: int,
sliding_window_size: Optional[int] = None, sliding_window_size: int = -1,
logit_cap: int = -1, logit_cap: float = 0.0,
v_head_dim: int = -1, v_head_dim: int = -1,
): ):
super().__init__() super().__init__()
...@@ -56,164 +45,14 @@ class RadixAttention(nn.Module): ...@@ -56,164 +45,14 @@ class RadixAttention(nn.Module):
self.v_head_dim = v_head_dim if v_head_dim != -1 else head_dim self.v_head_dim = v_head_dim if v_head_dim != -1 else head_dim
self.scaling = scaling self.scaling = scaling
self.layer_id = layer_id self.layer_id = layer_id
self.logit_cap = logit_cap if logit_cap is not None and logit_cap > 0 else 0 self.logit_cap = logit_cap
self.sliding_window_size = sliding_window_size if sliding_window_size else -1 self.sliding_window_size = sliding_window_size or -1
# Choose backend
if (
global_server_args_dict["attention_backend"] == "flashinfer"
and self.qk_head_dim == self.v_head_dim
):
self.extend_forward = self.extend_forward_flashinfer
self.decode_forward = self.decode_forward_flashinfer
elif global_server_args_dict["attention_backend"] == "triton":
self.extend_forward = self.extend_forward_triton
self.decode_forward = self.decode_forward_triton
else:
raise ValueError(
f"Invalid attention backend: {global_server_args_dict['attention_backend']}"
)
def extend_forward_triton(self, q, k, v, input_metadata: InputMetadata):
if self.qk_head_dim != self.v_head_dim:
o = q.new_empty((q.shape[0], self.tp_q_head_num * self.v_head_dim))
else:
o = torch.empty_like(q)
self.store_kv_cache(k, v, input_metadata)
extend_attention_fwd(
q.view(-1, self.tp_q_head_num, self.qk_head_dim),
k.contiguous(),
v.contiguous(),
o.view(-1, self.tp_q_head_num, self.v_head_dim),
input_metadata.token_to_kv_pool.get_key_buffer(self.layer_id),
input_metadata.token_to_kv_pool.get_value_buffer(self.layer_id),
input_metadata.req_to_token_pool.req_to_token,
input_metadata.req_pool_indices,
input_metadata.triton_start_loc,
input_metadata.seq_lens,
input_metadata.triton_prefix_lens,
input_metadata.extend_start_loc,
input_metadata.extend_seq_lens,
input_metadata.triton_max_seq_len,
input_metadata.triton_max_extend_len,
sm_scale=self.scaling,
logit_cap=self.logit_cap,
)
return o
def decode_forward_triton(self, q, k, v, input_metadata: InputMetadata):
if self.qk_head_dim != self.v_head_dim:
o = q.new_empty((q.shape[0], self.tp_q_head_num * self.v_head_dim))
else:
o = torch.empty_like(q)
self.store_kv_cache(k, v, input_metadata)
decode_attention_fwd(
q.view(-1, self.tp_q_head_num, self.qk_head_dim),
input_metadata.token_to_kv_pool.get_key_buffer(self.layer_id),
input_metadata.token_to_kv_pool.get_value_buffer(self.layer_id),
o.view(-1, self.tp_q_head_num, self.v_head_dim),
input_metadata.req_to_token_pool.req_to_token,
input_metadata.req_pool_indices,
input_metadata.triton_start_loc,
input_metadata.seq_lens,
input_metadata.triton_max_seq_len,
input_metadata.total_num_tokens,
sm_scale=self.scaling,
logit_cap=self.logit_cap,
)
return o
def extend_forward_flashinfer(self, q, k, v, input_metadata: InputMetadata):
# using two wrappers is unnecessary in the current PR, but are prepared for future PRs
prefill_wrapper_paged = input_metadata.flashinfer_prefill_wrapper_paged
if self.sliding_window_size != -1:
prefill_wrapper_paged = prefill_wrapper_paged[0]
else:
if isinstance(prefill_wrapper_paged, list):
prefill_wrapper_paged = prefill_wrapper_paged[1]
if not input_metadata.flashinfer_use_ragged:
if k is not None:
assert v is not None
self.store_kv_cache(k, v, input_metadata)
o = prefill_wrapper_paged.forward(
q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
input_metadata.token_to_kv_pool.get_kv_buffer(self.layer_id),
causal=True,
sm_scale=self.scaling,
window_left=self.sliding_window_size,
logits_soft_cap=self.logit_cap,
)
else:
o1, s1 = (
input_metadata.flashinfer_prefill_wrapper_ragged.forward_return_lse(
q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
k.contiguous().view(-1, self.tp_k_head_num, self.head_dim),
v.contiguous().view(-1, self.tp_v_head_num, self.head_dim),
causal=True,
sm_scale=self.scaling,
logits_soft_cap=self.logit_cap,
)
)
if input_metadata.extend_no_prefix:
o = o1
else:
o2, s2 = prefill_wrapper_paged.forward_return_lse(
q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
input_metadata.token_to_kv_pool.get_kv_buffer(self.layer_id),
causal=False,
sm_scale=self.scaling,
logits_soft_cap=self.logit_cap,
)
o, _ = merge_state(o1, s1, o2, s2)
self.store_kv_cache(k, v, input_metadata)
if input_metadata.total_num_tokens >= global_config.layer_sync_threshold:
torch.cuda.synchronize()
return o.view(-1, self.tp_q_head_num * self.head_dim)
def decode_forward_flashinfer(self, q, k, v, input_metadata: InputMetadata):
decode_wrapper = input_metadata.flashinfer_decode_wrapper
if self.sliding_window_size != -1:
decode_wrapper = decode_wrapper[0]
else:
if isinstance(decode_wrapper, list):
decode_wrapper = decode_wrapper[1]
if k is not None:
assert v is not None
self.store_kv_cache(k, v, input_metadata)
o = decode_wrapper.forward(
q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
input_metadata.token_to_kv_pool.get_kv_buffer(self.layer_id),
sm_scale=self.scaling,
logits_soft_cap=self.logit_cap,
)
return o.view(-1, self.tp_q_head_num * self.head_dim)
def forward(self, q, k, v, input_metadata: InputMetadata): def forward(self, q, k, v, input_metadata: InputMetadata):
if k is not None: if k is not None:
# For cross-layer sharing, kv can be None
assert v is not None assert v is not None
k = k.view(-1, self.tp_k_head_num, self.qk_head_dim) k = k.view(-1, self.tp_k_head_num, self.qk_head_dim)
v = v.view(-1, self.tp_v_head_num, self.v_head_dim) v = v.view(-1, self.tp_v_head_num, self.v_head_dim)
if input_metadata.forward_mode.is_extend(): return input_metadata.attn_backend.forward(q, k, v, self, input_metadata)
return self.extend_forward(q, k, v, input_metadata)
elif input_metadata.forward_mode.is_decode():
return self.decode_forward(q, k, v, input_metadata)
def store_kv_cache(self, cache_k, cache_v, input_metadata: InputMetadata):
input_metadata.token_to_kv_pool.set_kv_buffer(
self.layer_id, input_metadata.out_cache_loc, cache_k, cache_v
)
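
Note for call sites of RadixAttention: sliding_window_size now defaults to -1 (disabled) instead of None, and logit_cap is a float defaulting to 0.0, which is what both backends expect. A minimal construction sketch under the new defaults follows; the import path and the leading num_heads/head_dim parameters are not shown in this hunk and are assumed here.

from sglang.srt.layers.radix_attention import RadixAttention  # assumed import path

head_dim = 128
attn = RadixAttention(
    num_heads=32,          # assumed parameter name (not shown in the hunk)
    head_dim=head_dim,     # assumed parameter name (not shown in the hunk)
    scaling=head_dim**-0.5,
    num_kv_heads=8,
    layer_id=0,
    # sliding_window_size defaults to -1 (no sliding window) and logit_cap
    # defaults to 0.0 (no soft cap); pass them only to override.
)
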
...@@ -15,6 +15,7 @@ limitations under the License. ...@@ -15,6 +15,7 @@ limitations under the License.
""" """
Memory-efficient attention for decoding. Memory-efficient attention for decoding.
It supports page size = 1.
""" """
# Adapted from # Adapted from
...@@ -197,7 +198,6 @@ def _decode_att_m_fwd( ...@@ -197,7 +198,6 @@ def _decode_att_m_fwd(
logit_cap, logit_cap,
): ):
BLOCK = 32 BLOCK = 32
# shape constraints
Lq, Lk = q.shape[-1], k_buffer.shape[-1] Lq, Lk = q.shape[-1], k_buffer.shape[-1]
batch, head_num = B_req_idx.shape[0], q.shape[1] batch, head_num = B_req_idx.shape[0], q.shape[1]
...@@ -478,7 +478,6 @@ def _decode_grouped_att_m_fwd( ...@@ -478,7 +478,6 @@ def _decode_grouped_att_m_fwd(
logit_cap, logit_cap,
): ):
BLOCK = 32 BLOCK = 32
# shape constraints
Lq, Lk = q.shape[-1], k_buffer.shape[-1] Lq, Lk = q.shape[-1], k_buffer.shape[-1]
if Lk == 576: if Lk == 576:
...@@ -570,9 +569,9 @@ def _decode_grouped_softmax_reducev_fwd( ...@@ -570,9 +569,9 @@ def _decode_grouped_softmax_reducev_fwd(
BLOCK_DMODEL=BLOCK_DMODEL, BLOCK_DMODEL=BLOCK_DMODEL,
BLOCK_N=BLOCK, BLOCK_N=BLOCK,
BLOCK_H=BLOCK_H, BLOCK_H=BLOCK_H,
Lv=Lv,
num_warps=num_warps, num_warps=num_warps,
num_stages=1, num_stages=1,
Lv=Lv,
) )
...@@ -588,7 +587,7 @@ def decode_attention_fwd( ...@@ -588,7 +587,7 @@ def decode_attention_fwd(
max_len_in_batch, max_len_in_batch,
total_num_tokens, total_num_tokens,
sm_scale, sm_scale,
logit_cap=-1, logit_cap=0.0,
att_m=None, att_m=None,
): ):
if att_m is None: if att_m is None:
......
...@@ -61,14 +61,14 @@ def _fwd_kernel( ...@@ -61,14 +61,14 @@ def _fwd_kernel(
stride_buf_vbs, stride_buf_vbs,
stride_buf_vh, stride_buf_vh,
stride_req_to_tokens_b, stride_req_to_tokens_b,
logit_cap: tl.constexpr,
Lq: tl.constexpr,
Lv: tl.constexpr,
BLOCK_DMODEL: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
BLOCK_DPE: tl.constexpr, BLOCK_DPE: tl.constexpr,
BLOCK_DV: tl.constexpr, BLOCK_DV: tl.constexpr,
BLOCK_M: tl.constexpr, BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr, BLOCK_N: tl.constexpr,
logit_cap: tl.constexpr,
Lq: tl.constexpr,
Lv: tl.constexpr,
): ):
cur_seq = tl.program_id(0) cur_seq = tl.program_id(0)
cur_head = tl.program_id(1) cur_head = tl.program_id(1)
...@@ -111,7 +111,7 @@ def _fwd_kernel( ...@@ -111,7 +111,7 @@ def _fwd_kernel(
) )
qpe = tl.load(Q_Extend + offs_qpe, mask=mask_m[:, None], other=0.0) qpe = tl.load(Q_Extend + offs_qpe, mask=mask_m[:, None], other=0.0)
# stage1: compute scores with prefix # stage 1: compute scores with prefix
offs_n = tl.arange(0, BLOCK_N) offs_n = tl.arange(0, BLOCK_N)
acc = tl.zeros([BLOCK_M, BLOCK_DV], dtype=tl.float32) acc = tl.zeros([BLOCK_M, BLOCK_DV], dtype=tl.float32)
...@@ -174,7 +174,7 @@ def _fwd_kernel( ...@@ -174,7 +174,7 @@ def _fwd_kernel(
e_max = n_e_max e_max = n_e_max
# stage2: compute the trianlge part # stage 2: compute the triangle part
cur_block_m_end = tl.minimum(cur_seq_len_extend, (cur_block_m + 1) * BLOCK_M) cur_block_m_end = tl.minimum(cur_seq_len_extend, (cur_block_m + 1) * BLOCK_M)
for start_n in range(0, cur_block_m_end, BLOCK_N): for start_n in range(0, cur_block_m_end, BLOCK_N):
...@@ -255,26 +255,22 @@ def extend_attention_fwd( ...@@ -255,26 +255,22 @@ def extend_attention_fwd(
v_buffer, v_buffer,
req_to_tokens, req_to_tokens,
b_req_idx, b_req_idx,
b_start_loc,
b_seq_len, b_seq_len,
b_seq_len_prefix,
b_start_loc_extend,
b_seq_len_extend, b_seq_len_extend,
max_len_in_batch, b_start_loc_extend,
max_len_extend, max_len_extend,
sm_scale=None, sm_scale=None,
logit_cap=-1, logit_cap=0.0,
): ):
""" """
q_extend, k_extend, v_extend, o_extend: contiguous tensors q_extend, k_extend, v_extend, o_extend: contiguous tensors
k_buffer, v_buffer: (prefix + extend) tensors in mem_manager k_buffer, v_buffer: (prefix + extend) tensors in mem_manager
""" """
Lq, Lk, Lv, Lo = ( Lq, Lk, Lv = (
q_extend.shape[-1], q_extend.shape[-1],
k_extend.shape[-1], k_extend.shape[-1],
v_extend.shape[-1], v_extend.shape[-1],
o_extend.shape[-1],
) )
if Lq == 576: if Lq == 576:
...@@ -303,7 +299,7 @@ def extend_attention_fwd( ...@@ -303,7 +299,7 @@ def extend_attention_fwd(
else: else:
BLOCK_M, BLOCK_N = (64, 64) if Lq <= 128 else (32, 32) BLOCK_M, BLOCK_N = (64, 64) if Lq <= 128 else (32, 32)
sm_scale = 1.0 / (Lq**0.5) if sm_scale is None else sm_scale sm_scale = sm_scale or 1.0 / (Lq**0.5)
batch_size, head_num = b_seq_len.shape[0], q_extend.shape[1] batch_size, head_num = b_seq_len.shape[0], q_extend.shape[1]
kv_group_num = q_extend.shape[1] // k_extend.shape[1] kv_group_num = q_extend.shape[1] // k_extend.shape[1]
...@@ -338,27 +334,24 @@ def extend_attention_fwd( ...@@ -338,27 +334,24 @@ def extend_attention_fwd(
v_buffer.stride(0), v_buffer.stride(0),
v_buffer.stride(1), v_buffer.stride(1),
req_to_tokens.stride(0), req_to_tokens.stride(0),
logit_cap=logit_cap,
BLOCK_DMODEL=BLOCK_DMODEL, BLOCK_DMODEL=BLOCK_DMODEL,
BLOCK_DPE=BLOCK_DPE, BLOCK_DPE=BLOCK_DPE,
BLOCK_DV=BLOCK_DV, BLOCK_DV=BLOCK_DV,
BLOCK_M=BLOCK_M, BLOCK_M=BLOCK_M,
BLOCK_N=BLOCK_N, BLOCK_N=BLOCK_N,
num_warps=num_warps,
num_stages=num_stages,
logit_cap=logit_cap,
Lq=Lq, Lq=Lq,
Lv=Lv, Lv=Lv,
num_warps=num_warps,
num_stages=num_stages,
) )
def redundant_attention( def redundant_attention(
q_extend, q_extend,
k_extend,
v_extend,
o_extend, o_extend,
k_buffer, k_buffer,
v_buffer, v_buffer,
req_to_tokens,
b_req_idx, b_req_idx,
b_start_loc, b_start_loc,
b_seq_len, b_seq_len,
......
...@@ -368,7 +368,7 @@ class ScheduleBatch: ...@@ -368,7 +368,7 @@ class ScheduleBatch:
) )
def batch_size(self): def batch_size(self):
return len(self.reqs) if self.reqs is not None else 0 return len(self.reqs) if self.reqs else 0
def is_empty(self): def is_empty(self):
return len(self.reqs) == 0 return len(self.reqs) == 0
......
...@@ -13,15 +13,13 @@ See the License for the specific language governing permissions and ...@@ -13,15 +13,13 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
""" """
"""Run the model with cuda graph.""" """Run the model with cuda graph and torch.compile."""
import bisect import bisect
from contextlib import contextmanager from contextlib import contextmanager
from typing import Callable, List from typing import Callable
import torch import torch
from flashinfer import BatchDecodeWithPagedKVCacheWrapper
from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
from vllm.distributed.parallel_state import graph_capture from vllm.distributed.parallel_state import graph_capture
from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.custom_op import CustomOp
...@@ -55,6 +53,7 @@ def _to_torch(model: torch.nn.Module, reverse: bool = False): ...@@ -55,6 +53,7 @@ def _to_torch(model: torch.nn.Module, reverse: bool = False):
def patch_model( def patch_model(
model: torch.nn.Module, enable_compile: bool, tp_group: "GroupCoordinator" model: torch.nn.Module, enable_compile: bool, tp_group: "GroupCoordinator"
): ):
"""Patch the model to make it compatible with with torch.compile"""
backup_ca_comm = None backup_ca_comm = None
try: try:
...@@ -86,23 +85,28 @@ def set_torch_compile_config(): ...@@ -86,23 +85,28 @@ def set_torch_compile_config():
class CudaGraphRunner: class CudaGraphRunner:
def __init__( """A CudaGraphRunner runs the forward pass of a model with cuda graph and torch.compile."""
self,
model_runner: "ModelRunner", def __init__(self, model_runner: "ModelRunner"):
max_batch_size_to_capture: int, # Parse args
use_torch_compile: bool,
disable_padding: bool,
):
self.model_runner = model_runner self.model_runner = model_runner
self.graphs = {} self.graphs = {}
self.input_buffers = {} self.input_buffers = {}
self.output_buffers = {} self.output_buffers = {}
self.flashinfer_handlers = {} self.flashinfer_handlers = {}
self.graph_memory_pool = None self.graph_memory_pool = None
self.disable_padding = disable_padding self.use_torch_compile = model_runner.server_args.enable_torch_compile
self.disable_padding = model_runner.server_args.disable_cuda_graph_padding
# Batch sizes to capture
if self.model_runner.server_args.disable_cuda_graph_padding:
self.capture_bs = list(range(1, 32)) + [64, 128]
else:
self.capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)]
self.compile_bs = [1, 2, 4, 8, 16, 24, 32] if self.use_torch_compile else []
# Common inputs # Common inputs
self.max_bs = max_batch_size_to_capture self.max_bs = max(self.capture_bs)
self.input_ids = torch.zeros((self.max_bs,), dtype=torch.int32, device="cuda") self.input_ids = torch.zeros((self.max_bs,), dtype=torch.int32, device="cuda")
self.req_pool_indices = torch.zeros( self.req_pool_indices = torch.zeros(
(self.max_bs,), dtype=torch.int32, device="cuda" (self.max_bs,), dtype=torch.int32, device="cuda"
...@@ -115,56 +119,39 @@ class CudaGraphRunner: ...@@ -115,56 +119,39 @@ class CudaGraphRunner:
(self.max_bs,), dtype=torch.int32, device="cuda" (self.max_bs,), dtype=torch.int32, device="cuda"
) )
# FlashInfer inputs # Attention backend
self.flashinfer_kv_indptr = torch.zeros( self.model_runner.attn_backend.init_cuda_graph_state(self.max_bs)
(self.max_bs + 1,), dtype=torch.int32, device="cuda"
)
self.flashinfer_kv_indices = torch.zeros(
(self.max_bs * model_runner.model_config.context_len,),
dtype=torch.int32,
device="cuda",
)
self.flashinfer_kv_last_page_len = torch.ones(
(self.max_bs,), dtype=torch.int32, device="cuda"
)
if model_runner.sliding_window_size is None:
self.flashinfer_workspace_buffer = (
self.model_runner.flashinfer_workspace_buffer
)
else:
self.flashinfer_workspace_buffer = (
self.model_runner.flashinfer_workspace_buffer
)
self.flashinfer_kv_indptr = [
self.flashinfer_kv_indptr,
self.flashinfer_kv_indptr.clone(),
]
self.flashinfer_kv_indices = [
self.flashinfer_kv_indices,
self.flashinfer_kv_indices.clone(),
]
# Sampling inputs # Sampling info
vocab_size = model_runner.model_config.vocab_size vocab_size = model_runner.model_config.vocab_size
self.sampling_info = SamplingBatchInfo.dummy_one(self.max_bs, vocab_size) self.sampling_info = SamplingBatchInfo.dummy_one(self.max_bs, vocab_size)
self.compile_bs = [1, 2, 4, 8, 16, 24, 32] if use_torch_compile else [] if self.use_torch_compile:
if use_torch_compile:
set_torch_compile_config() set_torch_compile_config()
# Capture
try:
self.capture()
except RuntimeError as e:
raise Exception(
f"Capture cuda graph failed: {e}\n"
"Possible solutions:\n"
"1. disable cuda graph by --disable-cuda-graph\n"
"2. set --mem-fraction-static to a smaller value\n"
"3. disable torch compile by not using --enable-torch-compile\n"
"Open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose \n"
)
def can_run(self, batch_size: int): def can_run(self, batch_size: int):
if self.disable_padding: if self.disable_padding:
return batch_size in self.graphs return batch_size in self.graphs
else: else:
return batch_size <= self.max_bs return batch_size <= self.max_bs
def capture(self, batch_size_list: List[int]): def capture(self):
self.batch_size_list = batch_size_list
with graph_capture() as graph_capture_context: with graph_capture() as graph_capture_context:
self.stream = graph_capture_context.stream self.stream = graph_capture_context.stream
for bs in batch_size_list: for bs in self.capture_bs:
with patch_model( with patch_model(
self.model_runner.model, self.model_runner.model,
bs in self.compile_bs, bs in self.compile_bs,
...@@ -172,14 +159,10 @@ class CudaGraphRunner: ...@@ -172,14 +159,10 @@ class CudaGraphRunner:
) as forward: ) as forward:
( (
graph, graph,
input_buffers,
output_buffers, output_buffers,
flashinfer_handler,
) = self.capture_one_batch_size(bs, forward) ) = self.capture_one_batch_size(bs, forward)
self.graphs[bs] = graph self.graphs[bs] = graph
self.input_buffers[bs] = input_buffers
self.output_buffers[bs] = output_buffers self.output_buffers[bs] = output_buffers
self.flashinfer_handlers[bs] = flashinfer_handler
def capture_one_batch_size(self, bs: int, forward: Callable): def capture_one_batch_size(self, bs: int, forward: Callable):
graph = torch.cuda.CUDAGraph() graph = torch.cuda.CUDAGraph()
...@@ -192,48 +175,9 @@ class CudaGraphRunner: ...@@ -192,48 +175,9 @@ class CudaGraphRunner:
position_ids_offsets = self.position_ids_offsets[:bs] position_ids_offsets = self.position_ids_offsets[:bs]
out_cache_loc = self.out_cache_loc[:bs] out_cache_loc = self.out_cache_loc[:bs]
# FlashInfer inputs # Attention backend
if not _grouped_size_compiled_for_decode_kernels( self.model_runner.attn_backend.capture_cuda_graph_init(
self.model_runner.model_config.num_attention_heads bs, req_pool_indices, seq_lens
// self.model_runner.tp_size,
self.model_runner.model_config.get_num_kv_heads(self.model_runner.tp_size),
):
use_tensor_cores = True
else:
use_tensor_cores = False
if self.model_runner.sliding_window_size is None:
flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
self.flashinfer_workspace_buffer,
"NHD",
use_cuda_graph=True,
use_tensor_cores=use_tensor_cores,
paged_kv_indptr_buffer=self.flashinfer_kv_indptr[: bs + 1],
paged_kv_indices_buffer=self.flashinfer_kv_indices,
paged_kv_last_page_len_buffer=self.flashinfer_kv_last_page_len[:bs],
)
else:
flashinfer_decode_wrapper = []
for i in range(2):
flashinfer_decode_wrapper.append(
BatchDecodeWithPagedKVCacheWrapper(
self.flashinfer_workspace_buffer,
"NHD",
use_cuda_graph=True,
use_tensor_cores=use_tensor_cores,
paged_kv_indptr_buffer=self.flashinfer_kv_indptr[i][: bs + 1],
paged_kv_indices_buffer=self.flashinfer_kv_indices[i],
paged_kv_last_page_len_buffer=self.flashinfer_kv_last_page_len[
:bs
],
)
)
update_flashinfer_indices(
ForwardMode.DECODE,
self.model_runner,
req_pool_indices,
seq_lens,
None,
flashinfer_decode_wrapper,
) )
# Run and capture # Run and capture
...@@ -246,13 +190,12 @@ class CudaGraphRunner: ...@@ -246,13 +190,12 @@ class CudaGraphRunner:
seq_lens=seq_lens, seq_lens=seq_lens,
req_to_token_pool=self.model_runner.req_to_token_pool, req_to_token_pool=self.model_runner.req_to_token_pool,
token_to_kv_pool=self.model_runner.token_to_kv_pool, token_to_kv_pool=self.model_runner.token_to_kv_pool,
attn_backend=self.model_runner.attn_backend,
out_cache_loc=out_cache_loc, out_cache_loc=out_cache_loc,
return_logprob=False, return_logprob=False,
top_logprobs_nums=0, top_logprobs_nums=0,
positions=(seq_lens - 1 + position_ids_offsets).to(torch.int64), positions=(seq_lens - 1 + position_ids_offsets).to(torch.int64),
flashinfer_decode_wrapper=flashinfer_decode_wrapper,
) )
return forward(input_ids, input_metadata.positions, input_metadata) return forward(input_ids, input_metadata.positions, input_metadata)
for _ in range(2): for _ in range(2):
...@@ -274,15 +217,15 @@ class CudaGraphRunner: ...@@ -274,15 +217,15 @@ class CudaGraphRunner:
self.model_runner.tp_group.barrier() self.model_runner.tp_group.barrier()
self.graph_memory_pool = graph.pool() self.graph_memory_pool = graph.pool()
return graph, None, out, flashinfer_decode_wrapper return graph, out
def replay(self, batch: ScheduleBatch): def replay(self, batch: ScheduleBatch):
assert batch.out_cache_loc is not None assert batch.out_cache_loc is not None
raw_bs = len(batch.reqs) raw_bs = len(batch.reqs)
# Pad # Pad
index = bisect.bisect_left(self.batch_size_list, raw_bs) index = bisect.bisect_left(self.capture_bs, raw_bs)
bs = self.batch_size_list[index] bs = self.capture_bs[index]
if bs != raw_bs: if bs != raw_bs:
self.seq_lens.zero_() self.seq_lens.zero_()
self.position_ids_offsets.fill_(1) self.position_ids_offsets.fill_(1)
...@@ -295,14 +238,9 @@ class CudaGraphRunner: ...@@ -295,14 +238,9 @@ class CudaGraphRunner:
self.position_ids_offsets[:raw_bs] = batch.position_ids_offsets self.position_ids_offsets[:raw_bs] = batch.position_ids_offsets
self.out_cache_loc[:raw_bs] = batch.out_cache_loc self.out_cache_loc[:raw_bs] = batch.out_cache_loc
# FlashInfer inputs # Attention backend
update_flashinfer_indices( self.model_runner.attn_backend.replay_cuda_graph_init(
ForwardMode.DECODE, bs, self.req_pool_indices, self.seq_lens
self.model_runner,
self.req_pool_indices[:bs],
self.seq_lens[:bs],
None,
self.flashinfer_handlers[bs],
) )
# Sampling inputs # Sampling inputs
......
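
With the attention state owned by the backend, CudaGraphRunner now takes only the ModelRunner, derives capture_bs and compile_bs from server_args, and captures graphs inside its constructor. A rough sketch of the decode fast path is below; the surrounding control flow is truncated in this hunk, so the exact placement of this check is an assumption.

# Sketch only: how a decode batch would use the runner after this refactor.
runner = model_runner.cuda_graph_runner          # built as CudaGraphRunner(model_runner)
if runner is not None and runner.can_run(len(batch.reqs)):
    output = runner.replay(batch)                # padded up to the nearest captured batch size
else:
    output = model_runner.forward_decode(batch)  # eager fallback (assumed call site)
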
...@@ -23,9 +23,8 @@ from typing import TYPE_CHECKING, List ...@@ -23,9 +23,8 @@ from typing import TYPE_CHECKING, List
import numpy as np import numpy as np
import torch import torch
from sglang.srt.layers.flashinfer_utils import update_flashinfer_indices
if TYPE_CHECKING: if TYPE_CHECKING:
from sglang.srt.layers.attention_backend import AttentionBackend
from sglang.srt.managers.schedule_batch import ScheduleBatch from sglang.srt.managers.schedule_batch import ScheduleBatch
from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
from sglang.srt.model_executor.model_runner import ModelRunner from sglang.srt.model_executor.model_runner import ModelRunner
...@@ -66,12 +65,11 @@ class InputMetadata: ...@@ -66,12 +65,11 @@ class InputMetadata:
seq_lens: torch.Tensor seq_lens: torch.Tensor
req_to_token_pool: ReqToTokenPool req_to_token_pool: ReqToTokenPool
token_to_kv_pool: BaseTokenToKVPool token_to_kv_pool: BaseTokenToKVPool
attn_backend: AttentionBackend
# Output location of the KV cache # Output location of the KV cache
out_cache_loc: torch.Tensor out_cache_loc: torch.Tensor
total_num_tokens: int = None
# Position information # Position information
positions: torch.Tensor = None positions: torch.Tensor = None
...@@ -93,18 +91,6 @@ class InputMetadata: ...@@ -93,18 +91,6 @@ class InputMetadata:
image_offsets: List[List[int]] = None image_offsets: List[List[int]] = None
modalities: List[List[str]] = None modalities: List[List[str]] = None
# Trition attention backend
triton_max_seq_len: int = 0
triton_max_extend_len: int = 0
triton_start_loc: torch.Tensor = None
triton_prefix_lens: torch.Tensor = None
# FlashInfer attention backend
flashinfer_prefill_wrapper_ragged: "BatchPrefillWithRaggedKVCacheWrapper" = None
flashinfer_prefill_wrapper_paged: "BatchPrefillWithPagedKVCacheWrapper" = None
flashinfer_decode_wrapper: "BatchDecodeWithPagedKVCacheWrapper" = None
flashinfer_use_ragged: bool = False
def init_multimuldal_info(self, batch: ScheduleBatch): def init_multimuldal_info(self, batch: ScheduleBatch):
reqs = batch.reqs reqs = batch.reqs
self.pixel_values = [r.pixel_values for r in reqs] self.pixel_values = [r.pixel_values for r in reqs]
...@@ -154,32 +140,27 @@ class InputMetadata: ...@@ -154,32 +140,27 @@ class InputMetadata:
self.positions = self.positions.to(torch.int64) self.positions = self.positions.to(torch.int64)
def compute_extend_infos(self, batch: ScheduleBatch): def compute_extend_infos(self, batch: ScheduleBatch):
if self.forward_mode.is_decode(): extend_lens_cpu = [
self.extend_seq_lens = self.extend_start_loc = self.extend_no_prefix = None len(r.fill_ids) - batch.prefix_lens_cpu[i] for i, r in enumerate(batch.reqs)
self.extend_seq_lens_cpu = self.logprob_start_lens_cpu = None ]
else: self.extend_seq_lens = torch.tensor(extend_lens_cpu, device="cuda")
extend_lens_cpu = [ self.extend_prefix_lens = torch.tensor(batch.prefix_lens_cpu, device="cuda")
len(r.fill_ids) - batch.prefix_lens_cpu[i] self.extend_start_loc = torch.zeros_like(self.seq_lens)
for i, r in enumerate(batch.reqs) self.extend_start_loc[1:] = torch.cumsum(self.extend_seq_lens[:-1], dim=0)
] self.extend_no_prefix = all(l == 0 for l in batch.prefix_lens_cpu)
self.extend_seq_lens = torch.tensor(extend_lens_cpu, device="cuda")
self.extend_prefix_lens = torch.tensor(batch.prefix_lens_cpu, device="cuda") self.extend_seq_lens_cpu = extend_lens_cpu
self.extend_start_loc = torch.zeros_like(self.seq_lens) self.logprob_start_lens_cpu = [
self.extend_start_loc[1:] = torch.cumsum(self.extend_seq_lens[:-1], dim=0) (
self.extend_no_prefix = all(l == 0 for l in batch.prefix_lens_cpu) min(
req.logprob_start_len - batch.prefix_lens_cpu[i],
self.extend_seq_lens_cpu = extend_lens_cpu extend_lens_cpu[i] - 1,
self.logprob_start_lens_cpu = [
(
min(
req.logprob_start_len - batch.prefix_lens_cpu[i],
extend_lens_cpu[i] - 1,
)
if req.logprob_start_len >= batch.prefix_lens_cpu[i]
else extend_lens_cpu[i] - 1 # Fake extend, actually decode
) )
for i, req in enumerate(batch.reqs) if req.logprob_start_len >= batch.prefix_lens_cpu[i]
] else extend_lens_cpu[i] - 1 # Fake extend, actually decode
)
for i, req in enumerate(batch.reqs)
]
@classmethod @classmethod
def from_schedule_batch( def from_schedule_batch(
...@@ -195,6 +176,7 @@ class InputMetadata: ...@@ -195,6 +176,7 @@ class InputMetadata:
seq_lens=batch.seq_lens, seq_lens=batch.seq_lens,
req_to_token_pool=model_runner.req_to_token_pool, req_to_token_pool=model_runner.req_to_token_pool,
token_to_kv_pool=model_runner.token_to_kv_pool, token_to_kv_pool=model_runner.token_to_kv_pool,
attn_backend=model_runner.attn_backend,
out_cache_loc=batch.out_cache_loc, out_cache_loc=batch.out_cache_loc,
return_logprob=batch.return_logprob, return_logprob=batch.return_logprob,
top_logprobs_nums=batch.top_logprobs_nums, top_logprobs_nums=batch.top_logprobs_nums,
...@@ -202,76 +184,12 @@ class InputMetadata: ...@@ -202,76 +184,12 @@ class InputMetadata:
ret.sampling_info.update_penalties() ret.sampling_info.update_penalties()
ret.sampling_info.update_regex_vocab_mask(batch) ret.sampling_info.update_regex_vocab_mask(batch)
ret.compute_positions(batch) ret.compute_positions(batch)
ret.compute_extend_infos(batch) if not batch.forward_mode.is_decode():
fm = batch.forward_mode
if not fm.is_decode() or model_runner.server_args.attention_backend == "triton":
ret.total_num_tokens = int(torch.sum(ret.seq_lens))
if not fm.is_decode():
ret.init_multimuldal_info(batch) ret.init_multimuldal_info(batch)
ret.compute_extend_infos(batch)
if model_runner.server_args.attention_backend == "triton": model_runner.attn_backend.init_forward_metadata(batch, ret)
ret.init_triton_args(batch)
flashinfer_use_ragged = False
if model_runner.server_args.attention_backend == "flashinfer":
if (
not fm.is_decode()
and int(torch.sum(ret.seq_lens)) > 4096
and model_runner.sliding_window_size is None
):
flashinfer_use_ragged = True
ret.init_flashinfer_handlers(
model_runner, batch.prefix_lens_cpu, flashinfer_use_ragged
)
return ret return ret
def init_triton_args(self, batch: ScheduleBatch):
"""Init auxiliary variables for triton attention backend."""
self.triton_max_seq_len = int(torch.max(self.seq_lens))
self.triton_start_loc = torch.zeros_like(self.seq_lens, dtype=torch.int32)
self.triton_start_loc[1:] = torch.cumsum(self.seq_lens[:-1], dim=0)
if self.forward_mode.is_decode():
self.triton_max_extend_len = None
else:
self.triton_prefix_lens = torch.tensor(batch.prefix_lens_cpu, device="cuda")
extend_seq_lens = self.seq_lens - self.triton_prefix_lens
self.triton_max_extend_len = int(torch.max(extend_seq_lens))
def init_flashinfer_handlers(
self,
model_runner,
prefix_lens_cpu,
flashinfer_use_ragged,
):
if self.forward_mode.is_decode():
prefix_lens = None
else:
prefix_lens = self.extend_prefix_lens
update_flashinfer_indices(
self.forward_mode,
model_runner,
self.req_pool_indices,
self.seq_lens,
prefix_lens,
flashinfer_use_ragged=flashinfer_use_ragged,
)
(
self.flashinfer_prefill_wrapper_ragged,
self.flashinfer_prefill_wrapper_paged,
self.flashinfer_decode_wrapper,
self.flashinfer_use_ragged,
) = (
model_runner.flashinfer_prefill_wrapper_ragged,
model_runner.flashinfer_prefill_wrapper_paged,
model_runner.flashinfer_decode_wrapper,
flashinfer_use_ragged,
)
...@@ -25,12 +25,6 @@ from typing import Optional, Tuple, Type ...@@ -25,12 +25,6 @@ from typing import Optional, Tuple, Type
import torch import torch
import torch.nn as nn import torch.nn as nn
from flashinfer import (
BatchDecodeWithPagedKVCacheWrapper,
BatchPrefillWithPagedKVCacheWrapper,
BatchPrefillWithRaggedKVCacheWrapper,
)
from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
from vllm.config import DeviceConfig, LoadConfig from vllm.config import DeviceConfig, LoadConfig
from vllm.config import ModelConfig as VllmModelConfig from vllm.config import ModelConfig as VllmModelConfig
from vllm.distributed import ( from vllm.distributed import (
...@@ -43,8 +37,8 @@ from vllm.distributed.parallel_state import in_the_same_node_as ...@@ -43,8 +37,8 @@ from vllm.distributed.parallel_state import in_the_same_node_as
from vllm.model_executor.model_loader import get_model from vllm.model_executor.model_loader import get_model
from vllm.model_executor.models import ModelRegistry from vllm.model_executor.models import ModelRegistry
from sglang.global_config import global_config
from sglang.srt.configs.model_config import AttentionArch, ModelConfig from sglang.srt.configs.model_config import AttentionArch, ModelConfig
from sglang.srt.layers.attention_backend import FlashInferAttnBackend, TritonAttnBackend
from sglang.srt.layers.logits_processor import LogitsProcessorOutput from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.layers.sampler import SampleOutput from sglang.srt.layers.sampler import SampleOutput
from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict
...@@ -69,6 +63,8 @@ logger = logging.getLogger(__name__) ...@@ -69,6 +63,8 @@ logger = logging.getLogger(__name__)
class ModelRunner: class ModelRunner:
"""ModelRunner runs the forward passes of the models."""
def __init__( def __init__(
self, self,
model_config: ModelConfig, model_config: ModelConfig,
...@@ -100,6 +96,7 @@ class ModelRunner: ...@@ -100,6 +96,7 @@ class ModelRunner:
} }
) )
# Model-specific adjustment
if self.is_multimodal_model: if self.is_multimodal_model:
logger.info( logger.info(
"Automatically turn off --chunked-prefill-size and adjust --mem-fraction-static for multimodal models." "Automatically turn off --chunked-prefill-size and adjust --mem-fraction-static for multimodal models."
...@@ -107,6 +104,7 @@ class ModelRunner: ...@@ -107,6 +104,7 @@ class ModelRunner:
server_args.chunked_prefill_size = None server_args.chunked_prefill_size = None
server_args.mem_fraction_static *= 0.95 server_args.mem_fraction_static *= 0.95
# Init components
min_per_gpu_memory = self.init_torch_distributed() min_per_gpu_memory = self.init_torch_distributed()
self.load_model() self.load_model()
self.init_memory_pool( self.init_memory_pool(
...@@ -115,7 +113,7 @@ class ModelRunner: ...@@ -115,7 +113,7 @@ class ModelRunner:
server_args.max_total_tokens, server_args.max_total_tokens,
) )
self.init_cublas() self.init_cublas()
self.init_flashinfer() self.init_attention_backend()
self.init_cuda_graphs() self.init_cuda_graphs()
def init_torch_distributed(self): def init_torch_distributed(self):
...@@ -397,9 +395,6 @@ class ModelRunner: ...@@ -397,9 +395,6 @@ class ModelRunner:
qk_rope_head_dim=self.model_config.qk_rope_head_dim, qk_rope_head_dim=self.model_config.qk_rope_head_dim,
layer_num=self.model_config.num_hidden_layers, layer_num=self.model_config.num_hidden_layers,
) )
logger.info("using MLA Triton implementaion, flashinfer is disabled")
# FIXME: temporarily only Triton MLA is supported
self.server_args.attention_backend = "triton"
else: else:
self.token_to_kv_pool = MHATokenToKVPool( self.token_to_kv_pool = MHATokenToKVPool(
self.max_total_num_tokens, self.max_total_num_tokens,
...@@ -422,106 +417,42 @@ class ModelRunner: ...@@ -422,106 +417,42 @@ class ModelRunner:
c = a @ b c = a @ b
return c return c
def init_flashinfer(self): def init_attention_backend(self):
"""Init flashinfer attention kernel wrappers.""" """Init attention kernel backend."""
if self.server_args.attention_backend != "flashinfer": if self.server_args.attention_backend == "flashinfer":
assert ( self.attn_backend = FlashInferAttnBackend(self)
self.sliding_window_size is None elif self.server_args.attention_backend == "triton":
), "turn on flashinfer to support window attention" assert self.sliding_window_size is None, (
self.flashinfer_prefill_wrapper_ragged = None "Window attention is not supported in the triton attention backend. "
self.flashinfer_prefill_wrapper_paged = None "Please use `--attention-backend flashinfer`."
self.flashinfer_decode_wrapper = None
return
if not _grouped_size_compiled_for_decode_kernels(
self.model_config.num_attention_heads // self.tp_size,
self.model_config.get_num_kv_heads(self.tp_size),
):
use_tensor_cores = True
else:
use_tensor_cores = False
if self.sliding_window_size is None:
self.flashinfer_workspace_buffer = torch.empty(
global_config.flashinfer_workspace_size,
dtype=torch.uint8,
device="cuda",
)
self.flashinfer_prefill_wrapper_ragged = (
BatchPrefillWithRaggedKVCacheWrapper(
self.flashinfer_workspace_buffer, "NHD"
)
)
self.flashinfer_prefill_wrapper_paged = BatchPrefillWithPagedKVCacheWrapper(
self.flashinfer_workspace_buffer, "NHD"
)
self.flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
self.flashinfer_workspace_buffer,
"NHD",
use_tensor_cores=use_tensor_cores,
) )
self.attn_backend = TritonAttnBackend(self)
else: else:
self.flashinfer_workspace_buffer = torch.empty( raise ValueError(
global_config.flashinfer_workspace_size, f"Invalid attention backend: {self.server_args.attention_backend}"
dtype=torch.uint8,
device="cuda",
) )
self.flashinfer_prefill_wrapper_ragged = None
self.flashinfer_prefill_wrapper_paged = []
self.flashinfer_decode_wrapper = []
for i in range(2):
self.flashinfer_prefill_wrapper_paged.append(
BatchPrefillWithPagedKVCacheWrapper(
self.flashinfer_workspace_buffer, "NHD"
)
)
self.flashinfer_decode_wrapper.append(
BatchDecodeWithPagedKVCacheWrapper(
self.flashinfer_workspace_buffer,
"NHD",
use_tensor_cores=use_tensor_cores,
)
)
def init_cuda_graphs(self): def init_cuda_graphs(self):
"""Capture cuda graphs.""" """Capture cuda graphs."""
from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
self.cuda_graph_runner = None
if not self.is_generation: if not self.is_generation:
# TODO: Currently, cuda graph only captures decode steps, which only exists for generation models # TODO: Currently, cuda graph only captures decode steps, which only exists for generation models
return return
from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner if self.server_args.disable_cuda_graph:
return
if ( if self.server_args.attention_backend != "flashinfer":
self.server_args.disable_cuda_graph logger.warning(
or self.server_args.attention_backend != "flashinfer" f"Cuda graph is not supported for attention backend: {self.server_args.attention_backend}"
): )
self.cuda_graph_runner = None
return return
logger.info("Capture cuda graph begin. This can take up to several minutes.") logger.info("Capture cuda graph begin. This can take up to several minutes.")
self.cuda_graph_runner = CudaGraphRunner(self)
if self.server_args.disable_cuda_graph_padding:
batch_size_list = list(range(1, 32)) + [64, 128]
else:
batch_size_list = [1, 2, 4] + [i * 8 for i in range(1, 21)]
self.cuda_graph_runner = CudaGraphRunner(
self,
max_batch_size_to_capture=max(batch_size_list),
use_torch_compile=self.server_args.enable_torch_compile,
disable_padding=self.server_args.disable_cuda_graph_padding,
)
try:
self.cuda_graph_runner.capture(batch_size_list)
except RuntimeError as e:
raise Exception(
f"Capture cuda graph failed: {e}\n"
"Possible solutions:\n"
"1. disable cuda graph by --disable-cuda-graph\n"
"2. set --mem-fraction-static to a smaller value\n"
"3. disable torch compile by not using --enable-torch-compile\n"
"Open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose \n"
)
@torch.inference_mode() @torch.inference_mode()
def forward_decode(self, batch: ScheduleBatch): def forward_decode(self, batch: ScheduleBatch):
......
...@@ -143,18 +143,16 @@ class SamplingBatchInfo: ...@@ -143,18 +143,16 @@ class SamplingBatchInfo:
self.linear_penalties = penalizer.apply(self.linear_penalties) self.linear_penalties = penalizer.apply(self.linear_penalties)
def update_regex_vocab_mask(self, batch: ScheduleBatch): def update_regex_vocab_mask(self, batch: ScheduleBatch):
bs, reqs = batch.batch_size(), batch.reqs has_regex = any(req.regex_fsm is not None for req in batch.reqs)
device = "cuda"
has_regex = any(req.regex_fsm is not None for req in reqs)
# Reset the vocab mask # Reset the vocab mask
self.vocab_mask = None self.vocab_mask = None
if has_regex: if has_regex:
self.vocab_mask = torch.zeros( self.vocab_mask = torch.zeros(
bs, self.vocab_size, dtype=torch.bool, device=device batch.batch_size(), self.vocab_size, dtype=torch.bool, device="cuda"
) )
for i, req in enumerate(reqs): for i, req in enumerate(batch.reqs):
if req.regex_fsm is not None: if req.regex_fsm is not None:
self.vocab_mask[i].fill_(1) self.vocab_mask[i].fill_(1)
self.vocab_mask[i][ self.vocab_mask[i][
......
...@@ -335,23 +335,19 @@ def launch_server( ...@@ -335,23 +335,19 @@ def launch_server(
return return
# Launch processes # Launch processes
tokenizer_manager = TokenizerManager(server_args, port_args)
if server_args.chat_template:
load_chat_template_for_openai_api(tokenizer_manager, server_args.chat_template)
pipe_controller_reader, pipe_controller_writer = mp.Pipe(duplex=False) pipe_controller_reader, pipe_controller_writer = mp.Pipe(duplex=False)
pipe_detoken_reader, pipe_detoken_writer = mp.Pipe(duplex=False)
if server_args.dp_size == 1: if server_args.dp_size == 1:
start_controller_process = start_controller_process_single start_controller_process = start_controller_process_single
else: else:
start_controller_process = start_controller_process_multi start_controller_process = start_controller_process_multi
proc_controller = mp.Process( proc_controller = mp.Process(
target=start_controller_process, target=start_controller_process,
args=(server_args, port_args, pipe_controller_writer), args=(server_args, port_args, pipe_controller_writer),
) )
proc_controller.start() proc_controller.start()
pipe_detoken_reader, pipe_detoken_writer = mp.Pipe(duplex=False)
proc_detoken = mp.Process( proc_detoken = mp.Process(
target=start_detokenizer_process, target=start_detokenizer_process,
args=( args=(
...@@ -362,6 +358,10 @@ def launch_server( ...@@ -362,6 +358,10 @@ def launch_server(
) )
proc_detoken.start() proc_detoken.start()
tokenizer_manager = TokenizerManager(server_args, port_args)
if server_args.chat_template:
load_chat_template_for_openai_api(tokenizer_manager, server_args.chat_template)
# Wait for the model to finish loading # Wait for the model to finish loading
controller_init_state = pipe_controller_reader.recv() controller_init_state = pipe_controller_reader.recv()
detoken_init_state = pipe_detoken_reader.recv() detoken_init_state = pipe_detoken_reader.recv()
......
...@@ -83,8 +83,8 @@ class ServerArgs: ...@@ -83,8 +83,8 @@ class ServerArgs:
json_model_override_args: str = "{}" json_model_override_args: str = "{}"
# Optimization/debug options # Optimization/debug options
attention_backend: str = "flashinfer" attention_backend: Optional[str] = None
sampling_backend: str = "flashinfer" sampling_backend: Optional[str] = None
disable_flashinfer: bool = False disable_flashinfer: bool = False
disable_flashinfer_sampling: bool = False disable_flashinfer_sampling: bool = False
...@@ -148,6 +148,17 @@ class ServerArgs: ...@@ -148,6 +148,17 @@ class ServerArgs:
) )
self.sampling_backend = "pytorch" self.sampling_backend = "pytorch"
# Default kernel backends
if self.enable_mla:
logger.info("MLA optimization is tunred on. Use triton backend.")
self.attention_backend = "triton"
if self.attention_backend is None:
self.attention_backend = "flashinfer"
if self.sampling_backend is None:
self.sampling_backend = "flashinfer"
# Model-specific patches # Model-specific patches
if "Alibaba-NLP/gte-Qwen2-1.5B-instruct" == self.model_path: if "Alibaba-NLP/gte-Qwen2-1.5B-instruct" == self.model_path:
logger.info( logger.info(
......
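
The two backend flags are now Optional and resolved once at startup: --enable-mla forces the Triton attention backend (currently the only MLA path), and anything left unset falls back to FlashInfer. A condensed sketch of that resolution order, mirroring the hunk above rather than copying ServerArgs verbatim:

def resolve_default_backends(attention_backend, sampling_backend, enable_mla):
    """Return the effective backends given possibly-None CLI values."""
    if enable_mla:
        attention_backend = "triton"  # MLA currently has only a Triton implementation
    attention_backend = attention_backend or "flashinfer"
    sampling_backend = sampling_backend or "flashinfer"
    return attention_backend, sampling_backend

# e.g. resolve_default_backends(None, None, enable_mla=False) -> ("flashinfer", "flashinfer")
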
...@@ -55,8 +55,8 @@ class TestCreateKvIndices(unittest.TestCase): ...@@ -55,8 +55,8 @@ class TestCreateKvIndices(unittest.TestCase):
paged_kernel_lens, paged_kernel_lens,
kv_indptr, kv_indptr,
None, None,
req_to_token.size(1),
kv_indices_triton, kv_indices_triton,
req_to_token.size(1),
) )
# Check # Check
......
...@@ -19,7 +19,8 @@ class TestServingThroughput(unittest.TestCase): ...@@ -19,7 +19,8 @@ class TestServingThroughput(unittest.TestCase):
other_args = [] other_args = []
if disable_radix_cache: if disable_radix_cache:
other_args.append("--disable-radix-cache") other_args.append("--disable-radix-cache")
other_args.extend(["--attention-backend", attention_backend]) if attention_backend:
other_args.extend(["--attention-backend", attention_backend])
other_args.extend(["--chunked-prefill-size", str(chunked_prefill_size)]) other_args.extend(["--chunked-prefill-size", str(chunked_prefill_size)])
other_args.extend(["--tensor-parallel-size", "2"]) other_args.extend(["--tensor-parallel-size", "2"])
......
...@@ -19,7 +19,8 @@ class TestServingThroughput(unittest.TestCase): ...@@ -19,7 +19,8 @@ class TestServingThroughput(unittest.TestCase):
other_args = [] other_args = []
if disable_radix_cache: if disable_radix_cache:
other_args.append("--disable-radix-cache") other_args.append("--disable-radix-cache")
other_args.extend(["--attention-backend", attention_backend]) if attention_backend:
other_args.extend(["--attention-backend", attention_backend])
other_args.extend(["--chunked-prefill-size", str(chunked_prefill_size)]) other_args.extend(["--chunked-prefill-size", str(chunked_prefill_size)])
model = DEFAULT_MODEL_NAME_FOR_TEST model = DEFAULT_MODEL_NAME_FOR_TEST
......
...@@ -96,23 +96,17 @@ class TestExtendAttention(unittest.TestCase): ...@@ -96,23 +96,17 @@ class TestExtendAttention(unittest.TestCase):
v_buffer, v_buffer,
req_to_tokens, req_to_tokens,
b_req_idx, b_req_idx,
b_start_loc,
b_seq_len, b_seq_len,
b_seq_len_prefix,
b_start_loc_extend,
b_seq_len_extend, b_seq_len_extend,
max_len_in_batch, b_start_loc_extend,
max_len_extend, max_len_extend,
) )
redundant_attention( redundant_attention(
q_extend, q_extend,
k_extend,
v_extend,
o_redundant, o_redundant,
k_buffer, k_buffer,
v_buffer, v_buffer,
req_to_tokens,
b_req_idx, b_req_idx,
b_start_loc, b_start_loc,
b_seq_len, b_seq_len,
......