Unverified commit 6d0fa73e authored by Lianmin Zheng, committed by GitHub

Simplify flashinfer utilities (#1704)

parent 9e0dac1a
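
For orientation: this commit deletes the standalone update_flashinfer_indices / FlashinferUpdater helpers (the removed flashinfer_utils module appears to be the large deleted block near the end of the diff) and moves their logic into FlashInferAttnBackend, which now owns per-mode updater objects. A minimal sketch of the dispatch-at-construction pattern those updaters use follows; IndicesUpdater is an illustrative stand-in, not a class added by this commit:

from enum import Enum, auto


class WrapperDispatch(Enum):
    SLIDING_WINDOW = auto()
    CROSS_ATTENTION = auto()


class IndicesUpdater:
    """Toy stand-in for FlashInferIndicesUpdaterDecode/Prefill (sketch only)."""

    def __init__(self, dispatch_reason):
        # Bind the right update routine once at construction time,
        # instead of re-checking the model type on every forward pass.
        if dispatch_reason == WrapperDispatch.SLIDING_WINDOW:
            self.update = self.update_sliding_window
        elif dispatch_reason == WrapperDispatch.CROSS_ATTENTION:
            self.update = self.update_cross_attention
        else:
            self.update = self.update_single_wrapper

    def update_single_wrapper(self, req_pool_indices, seq_lens):
        print("single-wrapper update")

    def update_sliding_window(self, req_pool_indices, seq_lens):
        print("two wrappers: sliding-window + full attention")

    def update_cross_attention(self, req_pool_indices, seq_lens):
        raise NotImplementedError()


IndicesUpdater(None).update(req_pool_indices=[0, 1], seq_lens=[7, 3])
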
from abc import ABC, abstractmethod

+import torch
from torch import nn

from sglang.srt.model_executor.forward_batch_info import ForwardBatch

@@ -18,13 +19,13 @@ class AttentionBackend(ABC):
        raise NotImplementedError()

    def init_forward_metadata_capture_cuda_graph(
-        self, bs: int, req_pool_indices, seq_lens
+        self, bs: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor
    ):
        """Init the metadata for a forward pass for capturing a cuda graph."""
        raise NotImplementedError()

    def init_forward_metadata_replay_cuda_graph(
-        self, bs: int, req_pool_indices, seq_lens
+        self, bs: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor
    ):
        """Init the metadata for a forward pass for replaying a cuda graph."""
        raise NotImplementedError()

@@ -33,17 +34,38 @@ class AttentionBackend(ABC):
        """Get the fill value for padded seq lens. Typically, it is 0 or 1."""
        raise NotImplementedError()

-    def forward(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch):
+    def forward(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        layer: nn.Module,
+        forward_batch: ForwardBatch,
+    ):
        """Run forward on an attention layer."""
        if forward_batch.forward_mode.is_decode():
            return self.forward_decode(q, k, v, layer, forward_batch)
        else:
            return self.forward_extend(q, k, v, layer, forward_batch)

-    def forward_decode(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch):
+    def forward_decode(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        layer: nn.Module,
+        forward_batch: ForwardBatch,
+    ):
        """Run a forward for decode."""
        raise NotImplementedError()

-    def forward_extend(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch):
+    def forward_extend(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        layer: nn.Module,
+        forward_batch: ForwardBatch,
+    ):
        """Run a forward for extend."""
        raise NotImplementedError()
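
As the base-class docstrings above state, forward dispatches to forward_decode for decode batches and to forward_extend otherwise. A minimal, self-contained sketch of that dispatch (EchoAttnBackend, _Mode, and _Batch are hypothetical stand-ins; only the forward / forward_decode / forward_extend shape comes from the diff):

import torch
from torch import nn


class EchoAttnBackend:
    """Toy backend mirroring AttentionBackend.forward's dispatch (illustration only)."""

    def forward(self, q, k, v, layer, forward_batch):
        if forward_batch.forward_mode.is_decode():
            return self.forward_decode(q, k, v, layer, forward_batch)
        return self.forward_extend(q, k, v, layer, forward_batch)

    def forward_decode(self, q, k, v, layer, forward_batch):
        # Decode: one new query token per request, attending to the cached KV.
        return torch.zeros_like(q)

    def forward_extend(self, q, k, v, layer, forward_batch):
        # Extend: prefill of the new (non-cached) tokens, possibly with a cached prefix.
        return torch.zeros_like(q)


class _Mode:
    def __init__(self, decode: bool):
        self._decode = decode

    def is_decode(self) -> bool:
        return self._decode


class _Batch:
    def __init__(self, decode: bool):
        self.forward_mode = _Mode(decode)


q = k = v = torch.randn(4, 8)
out = EchoAttnBackend().forward(q, k, v, layer=nn.Identity(), forward_batch=_Batch(decode=True))
assert out.shape == q.shape
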
@@ -134,7 +134,7 @@ class DoubleSparseAttnBackend(AttentionBackend):
        )

    def init_forward_metadata_capture_cuda_graph(
-        self, bs: int, req_pool_indices, seq_lens
+        self, bs: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor
    ):
        self.forward_metadata = (
            self.cuda_graph_start_loc,

@@ -144,7 +144,7 @@ class DoubleSparseAttnBackend(AttentionBackend):
        )

    def init_forward_metadata_replay_cuda_graph(
-        self, bs: int, req_pool_indices, seq_lens
+        self, bs: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor
    ):
        self.cuda_graph_start_loc.zero_()
        self.cuda_graph_start_loc[1:bs] = torch.cumsum(seq_lens[: bs - 1], dim=0)
......
@@ -7,18 +7,17 @@ FlashInfer is faster and Triton is easier to customize.
Each backend supports two operators: extend (i.e. prefill with cached prefix) and decode.
"""

+from enum import Enum, auto
from typing import TYPE_CHECKING

import torch
import torch.nn as nn
+import triton
+import triton.language as tl

from sglang.global_config import global_config
from sglang.srt.layers.attention import AttentionBackend
-from sglang.srt.layers.attention.flashinfer_utils import (
-    WrapperDispatch,
-    update_flashinfer_indices,
-)
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.utils import is_flashinfer_available

if TYPE_CHECKING:

@@ -34,13 +33,18 @@ if is_flashinfer_available():
    from flashinfer.decode import _grouped_size_compiled_for_decode_kernels


+class WrapperDispatch(Enum):
+    SLIDING_WINDOW = auto()
+    CROSS_ATTENTION = auto()
+
+
class FlashInferAttnBackend(AttentionBackend):
    """Flashinfer attention kernels."""

    def __init__(self, model_runner: ModelRunner):
        super().__init__()
-        self.model_runner = model_runner

+        # Parse constants
        if not _grouped_size_compiled_for_decode_kernels(
            model_runner.model_config.num_attention_heads // model_runner.tp_size,
            model_runner.model_config.get_num_kv_heads(model_runner.tp_size),

@@ -48,27 +52,43 @@ class FlashInferAttnBackend(AttentionBackend):
            self.decode_use_tensor_cores = True
        else:
            self.decode_use_tensor_cores = False
+        self.max_context_len = model_runner.model_config.context_len

-        self.workspace_buffer = torch.empty(
-            global_config.flashinfer_workspace_size,
-            dtype=torch.uint8,
-            device="cuda",
-        )
        assert not (
            model_runner.sliding_window_size is not None
            and model_runner.has_cross_attention
        ), "Sliding window and cross attention are not supported together"

-        self.num_wrappers = 1
-        self.dispatch_reason = None
        if model_runner.sliding_window_size is not None:
            self.num_wrappers = 2
            self.dispatch_reason = WrapperDispatch.SLIDING_WINDOW
        elif model_runner.has_cross_attention:
            self.num_wrappers = 2
            self.dispatch_reason = WrapperDispatch.CROSS_ATTENTION
+        else:
+            self.num_wrappers = 1
+            self.dispatch_reason = None

+        # Allocate buffers
+        self.workspace_buffer = torch.empty(
+            global_config.flashinfer_workspace_size,
+            dtype=torch.uint8,
+            device=model_runner.device,
+        )
+        max_bs = model_runner.req_to_token_pool.size
+        self.kv_indptr = [
+            torch.zeros((max_bs + 1,), dtype=torch.int32, device=model_runner.device)
+            for _ in range(self.num_wrappers)
+        ]
+        self.kv_last_page_len = torch.ones(
+            (max_bs,), dtype=torch.int32, device=model_runner.device
+        )
+        self.qo_indptr = [
+            torch.zeros((max_bs + 1,), dtype=torch.int32, device=model_runner.device)
+            for _ in range(self.num_wrappers)
+        ]

+        # Create wrappers
        # NOTE: we do not use ragged attention when there are multiple wrappers
        self.prefill_wrapper_ragged = (
            BatchPrefillWithRaggedKVCacheWrapper(self.workspace_buffer, "NHD")

@@ -92,26 +112,23 @@ class FlashInferAttnBackend(AttentionBackend):
                )
            )

+        # Create indices updater
+        self.indices_updater_decode = FlashInferIndicesUpdaterDecode(model_runner, self)
+        self.indices_updater_prefill = FlashInferIndicesUpdaterPrefill(
+            model_runner, self
+        )

+        # Other metadata
        self.forward_metadata = None
        self.cuda_graph_metadata = {}

-    def _get_wrapper_idx(self, layer: nn.Module):
-        if self.num_wrappers == 1:
-            return 0
-        if self.dispatch_reason == WrapperDispatch.SLIDING_WINDOW:
-            return layer.sliding_window_size == -1
-        if self.dispatch_reason == WrapperDispatch.CROSS_ATTENTION:
-            return layer.is_cross_attention
-        raise ValueError(f"Unknown dispatch reason: {self.dispatch_reason}")
-
    def init_forward_metadata(self, forward_batch: ForwardBatch):
        if forward_batch.forward_mode.is_decode():
-            prefix_lens = None
-            use_ragged = False
-            extend_no_prefix = False
-            total_num_tokens = None
+            self.indices_updater_decode.update(
+                forward_batch.req_pool_indices,
+                forward_batch.seq_lens,
+            )
+            self.forward_metadata = (self.decode_wrappers,)
        else:
            prefix_lens = forward_batch.extend_prefix_lens

@@ -123,48 +140,32 @@ class FlashInferAttnBackend(AttentionBackend):
            ):
                use_ragged = True

-            total_num_tokens = torch.sum(forward_batch.seq_lens).item()
            extend_no_prefix = not torch.any(forward_batch.extend_prefix_lens).item()

-            update_flashinfer_indices(
-                forward_batch.forward_mode,
-                self.model_runner,
+            self.indices_updater_prefill.update(
                forward_batch.req_pool_indices,
                forward_batch.seq_lens,
                prefix_lens,
-                use_ragged=use_ragged,
+                use_ragged,
            )

            self.forward_metadata = (
                use_ragged,
                extend_no_prefix,
-                total_num_tokens,
-                self.decode_wrappers,
            )

    def init_cuda_graph_state(self, max_bs: int):
-        self.cuda_graph_kv_indptr = torch.zeros(
-            (max_bs + 1,), dtype=torch.int32, device="cuda"
-        )
-        self.cuda_graph_kv_indices = torch.zeros(
-            (max_bs * self.model_runner.model_config.context_len,),
+        cuda_graph_kv_indices = torch.zeros(
+            (max_bs * self.max_context_len,),
            dtype=torch.int32,
            device="cuda",
        )
-        self.cuda_graph_kv_last_page_len = torch.ones(
-            (max_bs,), dtype=torch.int32, device="cuda"
-        )
-
-        # NOTE: the buffers are always in the form of list
-        self.cuda_graph_kv_indptr = [self.cuda_graph_kv_indptr] + [
-            self.cuda_graph_kv_indptr.clone() for _ in range(self.num_wrappers - 1)
-        ]
-        self.cuda_graph_kv_indices = [self.cuda_graph_kv_indices] + [
-            self.cuda_graph_kv_indices.clone() for _ in range(self.num_wrappers - 1)
+        self.cuda_graph_kv_indices = [cuda_graph_kv_indices] + [
+            cuda_graph_kv_indices.clone() for _ in range(self.num_wrappers - 1)
        ]

    def init_forward_metadata_capture_cuda_graph(
-        self, bs: int, req_pool_indices, seq_lens
+        self, bs: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor
    ):
        decode_wrappers = []
        for i in range(self.num_wrappers):

@@ -174,35 +175,21 @@ class FlashInferAttnBackend(AttentionBackend):
                    "NHD",
                    use_cuda_graph=True,
                    use_tensor_cores=self.decode_use_tensor_cores,
-                    paged_kv_indptr_buffer=self.cuda_graph_kv_indptr[i][: bs + 1],
+                    paged_kv_indptr_buffer=self.kv_indptr[i][: bs + 1],
                    paged_kv_indices_buffer=self.cuda_graph_kv_indices[i],
-                    paged_kv_last_page_len_buffer=self.cuda_graph_kv_last_page_len[:bs],
+                    paged_kv_last_page_len_buffer=self.kv_last_page_len[:bs],
                )
            )

-        update_flashinfer_indices(
-            ForwardMode.DECODE,
-            self.model_runner,
-            req_pool_indices,
-            seq_lens,
-            None,
-            decode_wrappers,
-        )
+        self.indices_updater_decode.update(req_pool_indices, seq_lens, decode_wrappers)

        self.cuda_graph_metadata[bs] = decode_wrappers
-        self.forward_metadata = (False, False, None, decode_wrappers)
+        self.forward_metadata = (decode_wrappers,)

    def init_forward_metadata_replay_cuda_graph(
-        self, bs: int, req_pool_indices, seq_lens
+        self, bs: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor
    ):
-        update_flashinfer_indices(
-            ForwardMode.DECODE,
-            self.model_runner,
-            req_pool_indices[:bs],
-            seq_lens[:bs],
-            None,
-            self.cuda_graph_metadata[bs],
+        self.indices_updater_decode.update(
+            req_pool_indices[:bs], seq_lens[:bs], self.cuda_graph_metadata[bs]
        )

    def get_cuda_graph_seq_len_fill_value(self):

@@ -213,7 +200,7 @@ class FlashInferAttnBackend(AttentionBackend):
            self._get_wrapper_idx(layer)
        ]

-        use_ragged, extend_no_prefix, _, _ = self.forward_metadata
+        use_ragged, extend_no_prefix = self.forward_metadata

        if not use_ragged:
            if k is not None:

@@ -259,7 +246,7 @@ class FlashInferAttnBackend(AttentionBackend):
        return o.view(-1, layer.tp_q_head_num * layer.head_dim)

    def forward_decode(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch):
-        decode_wrapper = self.forward_metadata[-1][self._get_wrapper_idx(layer)]
+        decode_wrapper = self.forward_metadata[0][self._get_wrapper_idx(layer)]

        if k is not None:
            assert v is not None

@@ -275,3 +262,285 @@ class FlashInferAttnBackend(AttentionBackend):
            )

        return o.view(-1, layer.tp_q_head_num * layer.head_dim)
def _get_wrapper_idx(self, layer: nn.Module):
if self.num_wrappers == 1:
return 0
if self.dispatch_reason == WrapperDispatch.SLIDING_WINDOW:
return layer.sliding_window_size == -1
if self.dispatch_reason == WrapperDispatch.CROSS_ATTENTION:
return layer.is_cross_attention
raise ValueError(f"Unknown dispatch reason: {self.dispatch_reason}")
class FlashInferIndicesUpdaterDecode:
def __init__(self, model_runner: ModelRunner, attn_backend: AttentionBackend):
# Constants
self.num_qo_heads = (
model_runner.model_config.num_attention_heads // model_runner.tp_size
)
self.num_kv_heads = model_runner.model_config.get_num_kv_heads(
model_runner.tp_size
)
self.head_dim = model_runner.model_config.head_dim
self.data_type = model_runner.kv_cache_dtype
self.q_data_type = model_runner.dtype
self.max_context_len = model_runner.req_to_token_pool.req_to_token.size(1)
self.sliding_window_size = model_runner.sliding_window_size
# Buffers and wrappers
self.kv_indptr = attn_backend.kv_indptr
self.kv_last_page_len = attn_backend.kv_last_page_len
self.req_to_token = model_runner.req_to_token_pool.req_to_token
self.decode_wrappers = attn_backend.decode_wrappers
# Dispatch
if attn_backend.dispatch_reason == WrapperDispatch.SLIDING_WINDOW:
self.update = self.update_sliding_window
elif attn_backend.dispatch_reason == WrapperDispatch.CROSS_ATTENTION:
self.update = self.update_cross_attention
else:
assert attn_backend.num_wrappers == 1
self.update = self.update_single_wrapper
def update_single_wrapper(self, req_pool_indices, seq_lens, decode_wrappers=None):
decode_wrappers = decode_wrappers or self.decode_wrappers
self.call_begin_forward(
decode_wrappers[0], req_pool_indices, seq_lens, self.kv_indptr[0], None
)
def update_sliding_window(self, req_pool_indices, seq_lens, decode_wrappers=None):
decode_wrappers = decode_wrappers or self.decode_wrappers
for wrapper_id in range(2):
if wrapper_id == 0:
# Sliding window attention
paged_kernel_lens = torch.minimum( # TODO: replace this with clamp
seq_lens,
torch.tensor(self.sliding_window_size + 1),
)
else:
# Full attention
paged_kernel_lens = seq_lens
kv_start_idx = seq_lens - paged_kernel_lens
self.call_begin_forward(
decode_wrappers[wrapper_id],
req_pool_indices,
paged_kernel_lens,
self.kv_indptr[wrapper_id],
kv_start_idx,
)
def update_cross_attention(self):
raise NotImplementedError()
def call_begin_forward(
self, wrapper, req_pool_indices, paged_kernel_lens, kv_indptr, kv_start_idx
):
bs = len(req_pool_indices)
kv_indptr = kv_indptr[: bs + 1]
# TODO: optimize the blocking call on kv_indptr[-1]
kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
kv_indices = torch.empty(kv_indptr[-1], dtype=torch.int32, device="cuda")
create_flashinfer_kv_indices_triton[(bs,)](
self.req_to_token,
req_pool_indices,
paged_kernel_lens,
kv_indptr,
kv_start_idx,
kv_indices,
self.max_context_len,
)
wrapper.end_forward()
wrapper.begin_forward(
kv_indptr,
kv_indices,
self.kv_last_page_len[:bs],
self.num_qo_heads,
self.num_kv_heads,
self.head_dim,
1,
data_type=self.data_type,
q_data_type=self.q_data_type,
)
class FlashInferIndicesUpdaterPrefill:
def __init__(self, model_runner: ModelRunner, attn_backend: AttentionBackend):
# Constants
self.num_qo_heads = (
model_runner.model_config.num_attention_heads // model_runner.tp_size
)
self.num_kv_heads = model_runner.model_config.get_num_kv_heads(
model_runner.tp_size
)
self.head_dim = model_runner.model_config.head_dim
self.data_type = model_runner.kv_cache_dtype
self.q_data_type = model_runner.dtype
self.max_context_len = model_runner.req_to_token_pool.req_to_token.size(1)
self.sliding_window_size = model_runner.sliding_window_size
# Buffers and wrappers
self.kv_indptr = attn_backend.kv_indptr
self.kv_last_page_len = attn_backend.kv_last_page_len
self.qo_indptr = attn_backend.qo_indptr
self.req_to_token = model_runner.req_to_token_pool.req_to_token
self.wrapper_ragged = attn_backend.prefill_wrapper_ragged
self.wrappers_paged = attn_backend.prefill_wrappers_paged
# Dispatch
if attn_backend.dispatch_reason == WrapperDispatch.SLIDING_WINDOW:
self.update = self.update_sliding_window
elif attn_backend.dispatch_reason == WrapperDispatch.CROSS_ATTENTION:
self.update = self.update_cross_attention
else:
assert attn_backend.num_wrappers == 1
self.update = self.update_single_wrapper
def update_single_wrapper(
self, req_pool_indices, seq_lens, prefix_lens, use_ragged
):
if use_ragged:
paged_kernel_lens = prefix_lens
else:
paged_kernel_lens = seq_lens
self.call_begin_forward(
self.wrapper_ragged,
self.wrappers_paged[0],
req_pool_indices,
paged_kernel_lens,
seq_lens,
prefix_lens,
None,
self.kv_indptr[0],
self.qo_indptr[0],
use_ragged,
)
def update_sliding_window(
self, req_pool_indices, seq_lens, prefix_lens, use_ragged
):
for wrapper_id in range(2):
if wrapper_id == 0:
# window attention use paged only
paged_kernel_lens = torch.minimum(
seq_lens,
torch.tensor(self.sliding_window_size) + seq_lens - prefix_lens,
)
else:
# full attention
paged_kernel_lens = seq_lens
kv_start_idx = seq_lens - paged_kernel_lens
self.call_begin_forward(
self.wrapper_ragged,
self.wrappers_paged[wrapper_id],
req_pool_indices,
paged_kernel_lens,
seq_lens,
prefix_lens,
kv_start_idx,
self.kv_indptr[wrapper_id],
self.qo_indptr[wrapper_id],
use_ragged,
)
def update_cross_attention(self):
raise NotImplementedError()
def call_begin_forward(
self,
wrapper_ragged,
wrapper_paged,
req_pool_indices,
paged_kernel_lens,
seq_lens,
prefix_lens,
kv_start_idx,
kv_indptr,
qo_indptr,
use_ragged,
):
bs = len(req_pool_indices)
kv_indptr = kv_indptr[: bs + 1]
kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
kv_indices = torch.empty(kv_indptr[-1], dtype=torch.int32, device="cuda")
create_flashinfer_kv_indices_triton[(bs,)](
self.req_to_token,
req_pool_indices,
paged_kernel_lens,
kv_indptr,
kv_start_idx,
kv_indices,
self.max_context_len,
)
qo_indptr = qo_indptr[: bs + 1]
qo_indptr[1:] = torch.cumsum(seq_lens - prefix_lens, dim=0)
# extend part
if use_ragged:
wrapper_ragged.end_forward()
wrapper_ragged.begin_forward(
qo_indptr,
qo_indptr,
self.num_qo_heads,
self.num_kv_heads,
self.head_dim,
)
# cached part
wrapper_paged.end_forward()
wrapper_paged.begin_forward(
qo_indptr,
kv_indptr,
kv_indices,
self.kv_last_page_len[:bs],
self.num_qo_heads,
self.num_kv_heads,
self.head_dim,
1,
)
@triton.jit
def create_flashinfer_kv_indices_triton(
req_to_token_ptr, # [max_batch, max_context_len]
req_pool_indices_ptr,
page_kernel_lens_ptr,
kv_indptr,
kv_start_idx,
kv_indices_ptr,
max_context_len: tl.constexpr,
):
BLOCK_SIZE: tl.constexpr = 512
pid = tl.program_id(axis=0)
req_pool_index = tl.load(req_pool_indices_ptr + pid)
kv_indices_offset = tl.load(kv_indptr + pid)
kv_start = 0
kv_end = 0
if kv_start_idx:
kv_start = tl.load(kv_start_idx + pid).to(tl.int32)
kv_end = kv_start
kv_end += tl.load(page_kernel_lens_ptr + pid).to(tl.int32)
req_to_token_ptr += req_pool_index * max_context_len
kv_indices_ptr += kv_indices_offset
ld_offset = kv_start + tl.arange(0, BLOCK_SIZE)
st_offset = tl.arange(0, BLOCK_SIZE)
num_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE)
for _ in range(num_loop):
mask = ld_offset < kv_end
data = tl.load(req_to_token_ptr + ld_offset, mask=mask)
tl.store(kv_indices_ptr + st_offset, data, mask=mask)
ld_offset += BLOCK_SIZE
st_offset += BLOCK_SIZE
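
The create_flashinfer_kv_indices_triton kernel above gathers, for each request in the batch, the KV-cache slot ids of its (optionally windowed) tokens out of req_to_token and packs them into the flat kv_indices array at the offsets given by kv_indptr. Below is a plain-PyTorch reference of the same computation, useful for reading the kernel (a sketch for illustration, not code from this commit); the block that follows it appears to be the removed flashinfer_utils module whose contents were folded into the backend above.

import torch


def create_flashinfer_kv_indices_reference(
    req_to_token,       # [max_batch, max_context_len] KV slot ids per request row
    req_pool_indices,   # [bs] which row of req_to_token each request uses
    paged_kernel_lens,  # [bs] number of paged KV tokens per request
    kv_indptr,          # [bs + 1] exclusive prefix sum of paged_kernel_lens
    kv_start_idx=None,  # [bs] optional per-request start offset (sliding window)
):
    kv_indices = torch.empty(int(kv_indptr[-1]), dtype=torch.int32)
    for i in range(len(req_pool_indices)):
        row = req_to_token[int(req_pool_indices[i])]
        start = 0 if kv_start_idx is None else int(kv_start_idx[i])
        length = int(paged_kernel_lens[i])
        lo, hi = int(kv_indptr[i]), int(kv_indptr[i + 1])
        kv_indices[lo:hi] = row[start : start + length]
    return kv_indices


# Two requests using rows 2 and 0 of a toy req_to_token table.
req_to_token = torch.arange(12, dtype=torch.int32).reshape(3, 4)
kv_lens = torch.tensor([2, 3])
kv_indptr = torch.zeros(3, dtype=torch.int32)
kv_indptr[1:] = torch.cumsum(kv_lens, dim=0)
out = create_flashinfer_kv_indices_reference(
    req_to_token, torch.tensor([2, 0]), kv_lens, kv_indptr
)
assert out.tolist() == [8, 9, 0, 1, 2]
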
from enum import Enum, auto
import torch
import triton
import triton.language as tl
class WrapperDispatch(Enum):
SLIDING_WINDOW = auto()
CROSS_ATTENTION = auto()
@triton.jit
def create_flashinfer_kv_indices_triton(
req_to_token_ptr, # [max_batch, max_context_len]
req_pool_indices_ptr,
page_kernel_lens_ptr,
kv_indptr,
kv_start_idx,
kv_indices_ptr,
max_context_len: tl.constexpr,
):
BLOCK_SIZE: tl.constexpr = 512
pid = tl.program_id(axis=0)
req_pool_index = tl.load(req_pool_indices_ptr + pid)
kv_indices_offset = tl.load(kv_indptr + pid)
kv_start = 0
kv_end = 0
if kv_start_idx:
kv_start = tl.load(kv_start_idx + pid).to(tl.int32)
kv_end = kv_start
kv_end += tl.load(page_kernel_lens_ptr + pid).to(tl.int32)
req_to_token_ptr += req_pool_index * max_context_len
kv_indices_ptr += kv_indices_offset
ld_offset = kv_start + tl.arange(0, BLOCK_SIZE)
st_offset = tl.arange(0, BLOCK_SIZE)
num_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE)
for _ in range(num_loop):
mask = ld_offset < kv_end
data = tl.load(req_to_token_ptr + ld_offset, mask=mask)
tl.store(kv_indices_ptr + st_offset, data, mask=mask)
ld_offset += BLOCK_SIZE
st_offset += BLOCK_SIZE
class FlashinferUpdater:
def __init__(
self,
forward_mode,
model_runner,
req_pool_indices,
seq_lens,
prefix_lens,
decode_wrappers=None,
use_ragged=False,
):
self.forward_mode = forward_mode
self.model_runner = model_runner
self.req_pool_indices = req_pool_indices
self.seq_lens = seq_lens
self.prefix_lens = prefix_lens
self.use_ragged = use_ragged
self.num_qo_heads = (
model_runner.model_config.num_attention_heads // model_runner.tp_size
)
self.num_kv_heads = model_runner.model_config.get_num_kv_heads(
model_runner.tp_size
)
self.head_dim = model_runner.model_config.head_dim
self.batch_size = len(req_pool_indices)
self.decode_wrappers = (
decode_wrappers or self.model_runner.attn_backend.decode_wrappers
)
self.prefill_wrapper_ragged = (
self.model_runner.attn_backend.prefill_wrapper_ragged
)
self.prefill_wrappers_paged = (
self.model_runner.attn_backend.prefill_wrappers_paged
)
self.kv_last_page_len = torch.ones(
(self.batch_size,), dtype=torch.int32, device="cuda"
)
def _update_decode_indices(self, decode_wrapper):
assert not isinstance(decode_wrapper, list)
decode_wrapper.end_forward()
decode_wrapper.begin_forward(
self.kv_indptr,
self.kv_indices,
self.kv_last_page_len,
self.num_qo_heads,
self.num_kv_heads,
self.head_dim,
1,
data_type=self.model_runner.kv_cache_dtype,
q_data_type=self.model_runner.dtype,
)
def _update_extend_indices(self, ragged_wrapper, paged_wrapper):
assert not isinstance(paged_wrapper, list)
assert not isinstance(ragged_wrapper, list)
# extend part
qo_indptr = torch.zeros(
(self.batch_size + 1,), dtype=torch.int32, device="cuda"
)
qo_indptr[1:] = torch.cumsum(self.seq_lens - self.prefix_lens, dim=0)
if self.use_ragged:
ragged_wrapper.end_forward()
ragged_wrapper.begin_forward(
qo_indptr,
qo_indptr,
self.num_qo_heads,
self.num_kv_heads,
self.head_dim,
)
# cached part
paged_wrapper.end_forward()
paged_wrapper.begin_forward(
qo_indptr,
self.kv_indptr,
self.kv_indices,
self.kv_last_page_len,
self.num_qo_heads,
self.num_kv_heads,
self.head_dim,
1,
)
def _get_indices(self, dispatch_reason: WrapperDispatch = None, wrapper_id=0):
if dispatch_reason is None:
if self.use_ragged:
paged_kernel_lens = self.prefix_lens
else:
paged_kernel_lens = self.seq_lens
self.kv_start_idx = None
elif dispatch_reason == WrapperDispatch.SLIDING_WINDOW:
if wrapper_id == 0:
# window attention use paged only
if self.forward_mode.is_decode():
paged_kernel_lens = torch.minimum(
self.seq_lens,
torch.tensor(self.model_runner.sliding_window_size + 1),
)
else:
paged_kernel_lens = torch.minimum(
self.seq_lens,
torch.tensor(self.model_runner.sliding_window_size)
+ self.seq_lens
- self.prefix_lens,
)
else:
# full attention
paged_kernel_lens = self.seq_lens
self.kv_start_idx = self.seq_lens - paged_kernel_lens
self.kv_indptr = torch.zeros(
(self.batch_size + 1,), dtype=torch.int32, device="cuda"
)
self.kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
self.kv_indices = torch.empty(
self.kv_indptr[-1], dtype=torch.int32, device="cuda"
)
create_flashinfer_kv_indices_triton[(self.batch_size,)](
self.model_runner.req_to_token_pool.req_to_token,
self.req_pool_indices,
paged_kernel_lens,
self.kv_indptr,
self.kv_start_idx,
self.kv_indices,
self.model_runner.req_to_token_pool.req_to_token.size(1),
)
def _update_indicess_single_wrapper(self):
self._get_indices()
if self.forward_mode.is_decode():
self._update_decode_indices(self.decode_wrappers[0])
else:
self._update_extend_indices(
self.prefill_wrapper_ragged,
self.prefill_wrappers_paged[0],
)
def _update_indices_cross_attention(self):
pass
def _update_indices_sliding_window(self):
assert self.use_ragged is False
for wrapper_id in range(2):
self._get_indices(WrapperDispatch.SLIDING_WINDOW, wrapper_id)
if self.forward_mode.is_decode():
self._update_decode_indices(self.decode_wrappers[wrapper_id])
else:
self._update_extend_indices(
None,
self.prefill_wrappers_paged[wrapper_id],
)
def update_flashinfer_indices(
forward_mode,
model_runner,
req_pool_indices,
seq_lens,
prefix_lens,
decode_wrappers=None,
use_ragged=False,
):
updater = FlashinferUpdater(
forward_mode,
model_runner,
req_pool_indices,
seq_lens,
prefix_lens,
decode_wrappers,
use_ragged,
)
dispatch_reason = model_runner.attn_backend.dispatch_reason
if dispatch_reason == WrapperDispatch.SLIDING_WINDOW:
updater._update_indices_sliding_window()
elif dispatch_reason == WrapperDispatch.CROSS_ATTENTION:
updater._update_indices_cross_attention()
else:
assert model_runner.attn_backend.num_wrappers == 1
updater._update_indicess_single_wrapper()
@@ -81,7 +81,7 @@ class TritonAttnBackend(AttentionBackend):
        )

    def init_forward_metadata_capture_cuda_graph(
-        self, bs: int, req_pool_indices, seq_lens
+        self, bs: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor
    ):
        self.forward_metadata = (
            self.cuda_graph_start_loc,

@@ -91,7 +91,7 @@ class TritonAttnBackend(AttentionBackend):
        )

    def init_forward_metadata_replay_cuda_graph(
-        self, bs: int, req_pool_indices, seq_lens
+        self, bs: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor
    ):
        self.cuda_graph_start_loc.zero_()
        self.cuda_graph_start_loc[1:bs] = torch.cumsum(seq_lens[: bs - 1], dim=0)
......
@@ -744,7 +744,6 @@ class ScheduleBatch:
        self.forward_mode = ForwardMode.DECODE
        self.input_ids = self.output_ids
-        self.seq_lens.add_(1)
        self.output_ids = None
        if self.sampling_info.penalizer_orchestrator:
            self.sampling_info.penalizer_orchestrator.cumulate_output_tokens(

@@ -755,9 +754,10 @@ class ScheduleBatch:
        bs = len(self.reqs)
        self.out_cache_loc = self.alloc_token_slots(bs)

-        self.req_to_token_pool.req_to_token[
-            self.req_pool_indices, self.seq_lens - 1
-        ] = self.out_cache_loc
+        self.req_to_token_pool.req_to_token[self.req_pool_indices, self.seq_lens] = (
+            self.out_cache_loc
+        )
+        self.seq_lens.add_(1)

    def filter_batch(
        self,
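
Both orderings in the hunk above write the decode token's newly allocated cache slot at the same position; the new code writes at the pre-increment length and only then bumps seq_lens, dropping the "- 1". A toy sketch of the equivalence (the tensors here are illustrative, not the scheduler's real state):

import torch

req_to_token = torch.zeros(1, 8, dtype=torch.int64)  # one request, room for 8 tokens
req_pool_indices = torch.tensor([0])
seq_lens = torch.tensor([3])          # request currently holds 3 tokens
out_cache_loc = torch.tensor([42])    # newly allocated KV slot for the decode token

# New ordering: write at the current length, then increment.
req_to_token[req_pool_indices, seq_lens] = out_cache_loc
seq_lens.add_(1)

# The slot lands at position 3 (the 4th token), same as the old
# seq_lens.add_(1) followed by indexing with seq_lens - 1.
assert req_to_token[0, 3].item() == 42 and seq_lens.item() == 4
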
......
@@ -134,9 +134,7 @@ class ForwardBatch:
        )

        # Init position information
-        if ret.forward_mode.is_decode():
-            ret.positions = (ret.seq_lens - 1).to(torch.int64)
-        else:
+        if not ret.forward_mode.is_decode():
            ret.positions = torch.tensor(
                np.concatenate(
                    [

@@ -164,7 +162,6 @@ class ForwardBatch:
        ret.req_to_token_pool = model_runner.req_to_token_pool
        ret.token_to_kv_pool = model_runner.token_to_kv_pool
        ret.attn_backend = model_runner.attn_backend
-        model_runner.attn_backend.init_forward_metadata(ret)

        # Init lora information
        if model_runner.server_args.lora_paths is not None:
......
@@ -551,11 +551,14 @@ class ModelRunner:
        ):
            return self.cuda_graph_runner.replay(forward_batch)

+        forward_batch.positions = (forward_batch.seq_lens - 1).to(torch.int64)
+        self.attn_backend.init_forward_metadata(forward_batch)
        return self.model.forward(
            forward_batch.input_ids, forward_batch.positions, forward_batch
        )

    def forward_extend(self, forward_batch: ForwardBatch):
+        self.attn_backend.init_forward_metadata(forward_batch)
        if self.is_generation:
            return self.model.forward(
                forward_batch.input_ids, forward_batch.positions, forward_batch
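
Taken together with the forward_batch_info.py hunk above, this moves decode position computation and attention-metadata initialization out of ForwardBatch.init_new and into the model runner, right before the model call. A sketch of the resulting decode-step ordering (a standalone function for illustration; the real code lives in ModelRunner.forward_decode):

import torch


def run_decode_step(model, attn_backend, forward_batch):
    # 1) Decode positions are derived here now, not in ForwardBatch.init_new.
    forward_batch.positions = (forward_batch.seq_lens - 1).to(torch.int64)
    # 2) Attention metadata is initialized lazily, just before the model runs
    #    (CUDA-graph replays return earlier and skip this path).
    attn_backend.init_forward_metadata(forward_batch)
    # 3) Run the model with the prepared batch.
    return model.forward(
        forward_batch.input_ids, forward_batch.positions, forward_batch
    )
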
......