Unverified Commit fec185ce authored by Lianmin Zheng, committed by GitHub

Refactor attention backend (#1381)

parent c03cece4
from __future__ import annotations
"""
Support different attention backends.
Now there are two backends: FlashInfer and Triton.
FlashInfer is faster and Triton is easier to customize.
Each backend supports two operators: extend (i.e. prefill with cached prefix) and decode.
"""
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING
import torch
import torch.nn as nn
from flashinfer import (
BatchDecodeWithPagedKVCacheWrapper,
BatchPrefillWithPagedKVCacheWrapper,
BatchPrefillWithRaggedKVCacheWrapper,
)
from flashinfer.cascade import merge_state
from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
from sglang.global_config import global_config
from sglang.srt.layers.flashinfer_utils import update_flashinfer_indices
from sglang.srt.managers.schedule_batch import ScheduleBatch
from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
if TYPE_CHECKING:
from sglang.srt.model_executor.model_runner import ModelRunner
class AttentionBackend(ABC):
"""The base class of attention backends"""
@abstractmethod
def init_forward_metadata(
self, batch: ScheduleBatch, input_metadata: InputMetadata
):
pass
def forward(self, q, k, v, layer, input_metadata: InputMetadata):
if input_metadata.forward_mode.is_decode():
return self.forward_decode(q, k, v, layer, input_metadata)
else:
return self.forward_extend(q, k, v, layer, input_metadata)
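# Illustrative sketch (not part of this PR) of the contract a new backend must
# satisfy: build per-batch metadata in init_forward_metadata and implement
# forward_extend / forward_decode; the base-class forward() above handles the
# dispatch. The class name below is a hypothetical placeholder.
#
# class MyAttnBackend(AttentionBackend):
#     def init_forward_metadata(self, batch, input_metadata):
#         self.forward_metadata = None  # nothing to precompute in this sketch
#
#     def forward_extend(self, q, k, v, layer, input_metadata):
#         raise NotImplementedError
#
#     def forward_decode(self, q, k, v, layer, input_metadata):
#         raise NotImplementedError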
class FlashInferAttnBackend(AttentionBackend):
"""Flashinfer attention kernels."""
def __init__(self, model_runner: ModelRunner):
super().__init__()
self.model_runner = model_runner
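        # If FlashInfer has no specialized decode kernel compiled for this GQA
        # group size, fall back to the tensor-core (prefill-based) decode kernel.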
if not _grouped_size_compiled_for_decode_kernels(
model_runner.model_config.num_attention_heads // model_runner.tp_size,
model_runner.model_config.get_num_kv_heads(model_runner.tp_size),
):
self.decode_use_tensor_cores = True
else:
self.decode_use_tensor_cores = False
self.workspace_buffer = torch.empty(
global_config.flashinfer_workspace_size,
dtype=torch.uint8,
device="cuda",
)
if model_runner.sliding_window_size is None:
self.prefill_wrapper_ragged = BatchPrefillWithRaggedKVCacheWrapper(
self.workspace_buffer, "NHD"
)
self.prefill_wrapper_paged = BatchPrefillWithPagedKVCacheWrapper(
self.workspace_buffer, "NHD"
)
self.decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
self.workspace_buffer,
"NHD",
use_tensor_cores=self.decode_use_tensor_cores,
)
else:
# Two wrappers: one for sliding window attention and one for full attention.
# Using two wrappers is unnecessary in the current PR, but they are kept here in preparation for future PRs.
self.prefill_wrapper_ragged = None
self.prefill_wrapper_paged = []
self.decode_wrapper = []
for _ in range(2):
self.prefill_wrapper_paged.append(
BatchPrefillWithPagedKVCacheWrapper(self.workspace_buffer, "NHD")
)
self.decode_wrapper.append(
BatchDecodeWithPagedKVCacheWrapper(
self.workspace_buffer,
"NHD",
use_tensor_cores=self.decode_use_tensor_cores,
)
)
self.forward_metadata = None
self.cuda_graph_metadata = {}
def init_forward_metadata(
self, batch: ScheduleBatch, input_metadata: InputMetadata
):
if input_metadata.forward_mode.is_decode():
prefix_lens = None
use_ragged = False
total_num_tokens = None
else:
prefix_lens = input_metadata.extend_prefix_lens
# A simple heuristic to decide whether to use the ragged prefill forward
use_ragged = False
if (
int(torch.sum(input_metadata.seq_lens)) > 4096
and self.model_runner.sliding_window_size is None
):
use_ragged = True
total_num_tokens = torch.sum(input_metadata.seq_lens).item()
update_flashinfer_indices(
input_metadata.forward_mode,
self.model_runner,
input_metadata.req_pool_indices,
input_metadata.seq_lens,
prefix_lens,
use_ragged=use_ragged,
)
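        # Cache what forward_extend / forward_decode need for this batch: the ragged
        # flag, the total token count (for the layer-sync heuristic), and the decode
        # wrapper(s).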
self.forward_metadata = (use_ragged, total_num_tokens, self.decode_wrapper)
def init_cuda_graph_state(self, max_bs: int):
self.cuda_graph_kv_indptr = torch.zeros(
(max_bs + 1,), dtype=torch.int32, device="cuda"
)
self.cuda_graph_kv_indices = torch.zeros(
(max_bs * self.model_runner.model_config.context_len,),
dtype=torch.int32,
device="cuda",
)
self.cuda_graph_kv_last_page_len = torch.ones(
(max_bs,), dtype=torch.int32, device="cuda"
)
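        # With sliding-window models, keep two sets of buffers: index 0 for the
        # sliding-window wrapper and index 1 for the full-attention wrapper,
        # matching the wrapper selection in forward_extend / forward_decode.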
if self.model_runner.sliding_window_size is not None:
self.cuda_graph_kv_indptr = [
self.cuda_graph_kv_indptr,
self.cuda_graph_kv_indptr.clone(),
]
self.cuda_graph_kv_indices = [
self.cuda_graph_kv_indices,
self.cuda_graph_kv_indices.clone(),
]
def capture_cuda_graph_init(self, bs: int, req_pool_indices, seq_lens):
if self.model_runner.sliding_window_size is None:
decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
self.workspace_buffer,
"NHD",
use_cuda_graph=True,
use_tensor_cores=self.decode_use_tensor_cores,
paged_kv_indptr_buffer=self.cuda_graph_kv_indptr[: bs + 1],
paged_kv_indices_buffer=self.cuda_graph_kv_indices,
paged_kv_last_page_len_buffer=self.cuda_graph_kv_last_page_len[:bs],
)
else:
decode_wrapper = []
for i in range(2):
decode_wrapper.append(
BatchDecodeWithPagedKVCacheWrapper(
self.workspace_buffer,
"NHD",
use_cuda_graph=True,
use_tensor_cores=self.decode_use_tensor_cores,
paged_kv_indptr_buffer=self.cuda_graph_kv_indptr[i][: bs + 1],
paged_kv_indices_buffer=self.cuda_graph_kv_indices[i],
paged_kv_last_page_len_buffer=self.cuda_graph_kv_last_page_len[
:bs
],
)
)
update_flashinfer_indices(
ForwardMode.DECODE,
self.model_runner,
req_pool_indices,
seq_lens,
None,
decode_wrapper,
)
self.cuda_graph_metadata[bs] = decode_wrapper
self.forward_metadata = (False, None, decode_wrapper)
def replay_cuda_graph_init(self, bs: int, req_pool_indices, seq_lens):
update_flashinfer_indices(
ForwardMode.DECODE,
self.model_runner,
req_pool_indices[:bs],
seq_lens[:bs],
None,
self.cuda_graph_metadata[bs],
)
def forward_extend(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata):
if not isinstance(self.prefill_wrapper_paged, list):
prefill_wrapper_paged = self.prefill_wrapper_paged
else:
if layer.sliding_window_size != -1:
prefill_wrapper_paged = self.prefill_wrapper_paged[0]
else:
prefill_wrapper_paged = self.prefill_wrapper_paged[1]
use_ragged, total_num_tokens, decode_wrapper = self.forward_metadata
if not use_ragged:
if k is not None:
assert v is not None
input_metadata.token_to_kv_pool.set_kv_buffer(
layer.layer_id, input_metadata.out_cache_loc, k, v
)
o = prefill_wrapper_paged.forward(
q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
input_metadata.token_to_kv_pool.get_kv_buffer(layer.layer_id),
causal=True,
sm_scale=layer.scaling,
window_left=layer.sliding_window_size,
logits_soft_cap=layer.logit_cap,
)
else:
o1, s1 = self.prefill_wrapper_ragged.forward_return_lse(
q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
k.contiguous().view(-1, layer.tp_k_head_num, layer.head_dim),
v.contiguous().view(-1, layer.tp_v_head_num, layer.head_dim),
causal=True,
sm_scale=layer.scaling,
logits_soft_cap=layer.logit_cap,
)
if input_metadata.extend_no_prefix:
o = o1
else:
o2, s2 = prefill_wrapper_paged.forward_return_lse(
q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
input_metadata.token_to_kv_pool.get_kv_buffer(layer.layer_id),
causal=False,
sm_scale=layer.scaling,
logits_soft_cap=layer.logit_cap,
)
o, _ = merge_state(o1, s1, o2, s2)
input_metadata.token_to_kv_pool.set_kv_buffer(
layer.layer_id, input_metadata.out_cache_loc, k, v
)
if total_num_tokens >= global_config.layer_sync_threshold:
torch.cuda.synchronize()
return o.view(-1, layer.tp_q_head_num * layer.head_dim)
def forward_decode(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata):
use_ragged, total_num_tokens, decode_wrapper = self.forward_metadata
if isinstance(decode_wrapper, list):
if layer.sliding_window_size != -1:
decode_wrapper = decode_wrapper[0]
else:
decode_wrapper = decode_wrapper[1]
if k is not None:
assert v is not None
input_metadata.token_to_kv_pool.set_kv_buffer(
layer.layer_id, input_metadata.out_cache_loc, k, v
)
o = decode_wrapper.forward(
q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
input_metadata.token_to_kv_pool.get_kv_buffer(layer.layer_id),
sm_scale=layer.scaling,
logits_soft_cap=layer.logit_cap,
)
return o.view(-1, layer.tp_q_head_num * layer.head_dim)
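# A pure-PyTorch reference (an assumption about the semantics, not flashinfer's
# implementation) of what merge_state does in forward_extend above: combine two
# partial attention outputs computed over disjoint key sets, weighting each by its
# per-query log-sum-exp. The function name is a hypothetical placeholder.
def _merge_state_reference(o1, s1, o2, s2):
    # o1, o2: [num_tokens, num_heads, head_dim]; s1, s2: [num_tokens, num_heads]
    s_max = torch.maximum(s1, s2)
    w1 = torch.exp(s1 - s_max).unsqueeze(-1)
    w2 = torch.exp(s2 - s_max).unsqueeze(-1)
    o = (o1 * w1 + o2 * w2) / (w1 + w2)
    s = s_max + torch.log((w1 + w2).squeeze(-1))
    return o, s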
class TritonAttnBackend(AttentionBackend):
def __init__(self, model_runner: ModelRunner):
# Lazy import to avoid initializing the CUDA context at import time
from sglang.srt.layers.triton_attention.decode_attention import (
decode_attention_fwd,
)
from sglang.srt.layers.triton_attention.extend_attention import (
extend_attention_fwd,
)
super().__init__()
self.decode_attention_fwd = decode_attention_fwd
self.extend_attention_fwd = extend_attention_fwd
self.forward_metadata = None
def init_forward_metadata(
self, batch: ScheduleBatch, input_metadata: InputMetadata
):
"""Init auxiliary variables for triton attention backend."""
if input_metadata.forward_mode.is_decode():
max_seq_len = torch.max(input_metadata.seq_lens).item()
start_loc = torch.zeros_like(input_metadata.seq_lens, dtype=torch.int32)
start_loc[1:] = torch.cumsum(input_metadata.seq_lens[:-1], dim=0)
total_num_tokens = torch.sum(input_metadata.seq_lens).item()
max_extend_len = None
else:
start_loc = max_seq_len = total_num_tokens = None
prefix_lens = torch.tensor(batch.prefix_lens_cpu, device="cuda")
max_extend_len = torch.max(input_metadata.seq_lens - prefix_lens).item()
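        # Only the decode path needs start_loc / max_seq_len / total_num_tokens;
        # only the extend path needs max_extend_len.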
self.forward_metadata = start_loc, max_seq_len, max_extend_len, total_num_tokens
def forward_extend(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata):
if layer.qk_head_dim != layer.v_head_dim:
o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
else:
o = torch.empty_like(q)
input_metadata.token_to_kv_pool.set_kv_buffer(
layer.layer_id, input_metadata.out_cache_loc, k, v
)
start_loc, max_seq_len, max_extend_len, total_num_tokens = self.forward_metadata
self.extend_attention_fwd(
q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
k.contiguous(),
v.contiguous(),
o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
input_metadata.token_to_kv_pool.get_key_buffer(layer.layer_id),
input_metadata.token_to_kv_pool.get_value_buffer(layer.layer_id),
input_metadata.req_to_token_pool.req_to_token,
input_metadata.req_pool_indices,
input_metadata.seq_lens,
input_metadata.extend_seq_lens,
input_metadata.extend_start_loc,
max_extend_len,
layer.scaling,
layer.logit_cap,
)
return o
def forward_decode(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata):
if layer.qk_head_dim != layer.v_head_dim:
o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
else:
o = torch.empty_like(q)
start_loc, max_seq_len, max_extend_len, total_num_tokens = self.forward_metadata
input_metadata.token_to_kv_pool.set_kv_buffer(
layer.layer_id, input_metadata.out_cache_loc, k, v
)
self.decode_attention_fwd(
q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
input_metadata.token_to_kv_pool.get_key_buffer(layer.layer_id),
input_metadata.token_to_kv_pool.get_value_buffer(layer.layer_id),
o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
input_metadata.req_to_token_pool.req_to_token,
input_metadata.req_pool_indices,
start_loc,
input_metadata.seq_lens,
max_seq_len,
total_num_tokens,
layer.scaling,
layer.logit_cap,
)
return o
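# Rough call flow after this refactor (illustrative summary, not code from the PR):
#   1. ModelRunner.init_attention_backend() picks FlashInferAttnBackend or
#      TritonAttnBackend and stores it as model_runner.attn_backend.
#   2. InputMetadata.from_schedule_batch() stores the backend on the metadata and
#      calls attn_backend.init_forward_metadata(batch, ret).
#   3. RadixAttention.forward() delegates to
#      input_metadata.attn_backend.forward(q, k, v, self, input_metadata),
#      which dispatches to forward_decode() or forward_extend() by forward_mode.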
......@@ -10,8 +10,8 @@ def create_flashinfer_kv_indices_triton(
page_kernel_lens_ptr,
kv_indptr,
kv_start_idx,
max_context_len,
kv_indices_ptr,
max_context_len: tl.constexpr,
):
BLOCK_SIZE: tl.constexpr = 512
pid = tl.program_id(axis=0)
......@@ -47,15 +47,15 @@ class FlashinferUpdater:
req_pool_indices,
seq_lens,
prefix_lens,
flashinfer_decode_wrapper=None,
flashinfer_use_ragged=False,
decode_wrapper=None,
use_ragged=False,
):
self.forward_mode = forward_mode
self.model_runner = model_runner
self.req_pool_indices = req_pool_indices
self.seq_lens = seq_lens
self.prefix_lens = prefix_lens
self.flashinfer_use_ragged = flashinfer_use_ragged
self.use_ragged = use_ragged
self.num_qo_heads = (
model_runner.model_config.num_attention_heads // model_runner.tp_size
......@@ -71,20 +71,17 @@ class FlashinferUpdater:
)
(
self.flashinfer_decode_wrapper,
self.flashinfer_prefill_wrapper_ragged,
self.flashinfer_prefill_wrapper_paged,
self.decode_wrapper,
self.prefill_wrapper_ragged,
self.prefill_wrapper_paged,
) = (
flashinfer_decode_wrapper,
self.model_runner.flashinfer_prefill_wrapper_ragged,
self.model_runner.flashinfer_prefill_wrapper_paged,
decode_wrapper or self.model_runner.attn_backend.decode_wrapper,
self.model_runner.attn_backend.prefill_wrapper_ragged,
self.model_runner.attn_backend.prefill_wrapper_paged,
)
# CUDA graph uses different flashinfer_decode_wrapper
if self.flashinfer_decode_wrapper is None:
self.flashinfer_decode_wrapper = self.model_runner.flashinfer_decode_wrapper
def _init_indices_no_window(self):
if self.flashinfer_use_ragged:
def _init_indices_no_sliding_window(self):
if self.use_ragged:
paged_kernel_lens = self.prefix_lens
else:
paged_kernel_lens = self.seq_lens
......@@ -103,13 +100,13 @@ class FlashinferUpdater:
paged_kernel_lens,
self.kv_indptr,
None,
self.model_runner.req_to_token_pool.req_to_token.size(1),
self.kv_indices,
self.model_runner.req_to_token_pool.req_to_token.size(1),
)
def _init_indices_window(self, wrapper_id):
# window attention use paged only
def _init_indices_sliding_window(self, wrapper_id):
if wrapper_id == 0:
# Sliding window attention uses the paged KV cache only.
if self.forward_mode.is_decode():
paged_kernel_lens = torch.minimum(
self.seq_lens,
......@@ -123,6 +120,7 @@ class FlashinferUpdater:
- self.prefix_lens,
)
else:
# full attention
paged_kernel_lens = self.seq_lens
kv_start_idx = self.seq_lens - paged_kernel_lens
......@@ -139,8 +137,8 @@ class FlashinferUpdater:
paged_kernel_lens,
self.kv_indptr,
kv_start_idx,
self.model_runner.req_to_token_pool.req_to_token.size(1),
self.kv_indices,
self.model_runner.req_to_token_pool.req_to_token.size(1),
)
def _update_decode_indices(self, decode_wrapper):
......@@ -164,7 +162,7 @@ class FlashinferUpdater:
)
qo_indptr[1:] = torch.cumsum(self.seq_lens - self.prefix_lens, dim=0)
if self.flashinfer_use_ragged:
if self.use_ragged:
ragged_wrapper.end_forward()
ragged_wrapper.begin_forward(
qo_indptr,
......@@ -187,28 +185,28 @@ class FlashinferUpdater:
1,
)
def update_indices_no_window(self):
self._init_indices_no_window()
def update_indices_no_sliding_window(self):
self._init_indices_no_sliding_window()
if self.forward_mode.is_decode():
self._update_decode_indices(self.flashinfer_decode_wrapper)
self._update_decode_indices(self.decode_wrapper)
else:
self._update_extend_indices(
self.flashinfer_prefill_wrapper_ragged,
self.flashinfer_prefill_wrapper_paged,
self.prefill_wrapper_ragged,
self.prefill_wrapper_paged,
)
def update_indices_window(self):
assert self.flashinfer_use_ragged is False
def update_indices_sliding_window(self):
assert self.use_ragged is False
for wrapper_id in range(2):
self._init_indices_window(wrapper_id)
self._init_indices_sliding_window(wrapper_id)
if self.forward_mode.is_decode():
self._update_decode_indices(self.flashinfer_decode_wrapper[wrapper_id])
self._update_decode_indices(self.decode_wrapper[wrapper_id])
else:
self._update_extend_indices(
None,
self.flashinfer_prefill_wrapper_paged[wrapper_id],
self.prefill_wrapper_paged[wrapper_id],
)
......@@ -218,20 +216,20 @@ def update_flashinfer_indices(
req_pool_indices,
seq_lens,
prefix_lens,
flashinfer_decode_wrapper=None,
flashinfer_use_ragged=False,
decode_wrapper=None,
use_ragged=False,
):
flashinfer_updater = FlashinferUpdater(
updater = FlashinferUpdater(
forward_mode,
model_runner,
req_pool_indices,
seq_lens,
prefix_lens,
flashinfer_decode_wrapper,
flashinfer_use_ragged,
decode_wrapper,
use_ragged,
)
if model_runner.sliding_window_size is None:
flashinfer_updater.update_indices_no_window()
updater.update_indices_no_sliding_window()
else:
flashinfer_updater.update_indices_window()
updater.update_indices_sliding_window()
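# Callers pass decode_wrapper explicitly only on the CUDA-graph capture/replay path
# (see capture_cuda_graph_init / replay_cuda_graph_init in the attention backend);
# otherwise the wrappers are taken from model_runner.attn_backend.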
......@@ -15,25 +15,14 @@ limitations under the License.
"""Radix attention."""
from typing import Optional
import torch
from flashinfer.cascade import merge_state
from torch import nn
from sglang.global_config import global_config
from sglang.srt.layers.triton_attention.decode_attention import decode_attention_fwd
from sglang.srt.layers.triton_attention.extend_attention import extend_attention_fwd
from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
from sglang.srt.model_executor.model_runner import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import InputMetadata
class RadixAttention(nn.Module):
"""
The attention layer implementation.
Now it has two backends: FlashInfer and Triton.
FlashInfer is faster and Triton is easier to customize.
It supports two operators: extend (i.e. prefill with cached prefix) and decode.
"""
def __init__(
......@@ -43,8 +32,8 @@ class RadixAttention(nn.Module):
scaling: float,
num_kv_heads: int,
layer_id: int,
sliding_window_size: Optional[int] = None,
logit_cap: int = -1,
sliding_window_size: int = -1,
logit_cap: float = 0.0,
v_head_dim: int = -1,
):
super().__init__()
......@@ -56,164 +45,14 @@ class RadixAttention(nn.Module):
self.v_head_dim = v_head_dim if v_head_dim != -1 else head_dim
self.scaling = scaling
self.layer_id = layer_id
self.logit_cap = logit_cap if logit_cap is not None and logit_cap > 0 else 0
self.sliding_window_size = sliding_window_size if sliding_window_size else -1
# Choose backend
if (
global_server_args_dict["attention_backend"] == "flashinfer"
and self.qk_head_dim == self.v_head_dim
):
self.extend_forward = self.extend_forward_flashinfer
self.decode_forward = self.decode_forward_flashinfer
elif global_server_args_dict["attention_backend"] == "triton":
self.extend_forward = self.extend_forward_triton
self.decode_forward = self.decode_forward_triton
else:
raise ValueError(
f"Invalid attention backend: {global_server_args_dict['attention_backend']}"
)
def extend_forward_triton(self, q, k, v, input_metadata: InputMetadata):
if self.qk_head_dim != self.v_head_dim:
o = q.new_empty((q.shape[0], self.tp_q_head_num * self.v_head_dim))
else:
o = torch.empty_like(q)
self.store_kv_cache(k, v, input_metadata)
extend_attention_fwd(
q.view(-1, self.tp_q_head_num, self.qk_head_dim),
k.contiguous(),
v.contiguous(),
o.view(-1, self.tp_q_head_num, self.v_head_dim),
input_metadata.token_to_kv_pool.get_key_buffer(self.layer_id),
input_metadata.token_to_kv_pool.get_value_buffer(self.layer_id),
input_metadata.req_to_token_pool.req_to_token,
input_metadata.req_pool_indices,
input_metadata.triton_start_loc,
input_metadata.seq_lens,
input_metadata.triton_prefix_lens,
input_metadata.extend_start_loc,
input_metadata.extend_seq_lens,
input_metadata.triton_max_seq_len,
input_metadata.triton_max_extend_len,
sm_scale=self.scaling,
logit_cap=self.logit_cap,
)
return o
def decode_forward_triton(self, q, k, v, input_metadata: InputMetadata):
if self.qk_head_dim != self.v_head_dim:
o = q.new_empty((q.shape[0], self.tp_q_head_num * self.v_head_dim))
else:
o = torch.empty_like(q)
self.store_kv_cache(k, v, input_metadata)
decode_attention_fwd(
q.view(-1, self.tp_q_head_num, self.qk_head_dim),
input_metadata.token_to_kv_pool.get_key_buffer(self.layer_id),
input_metadata.token_to_kv_pool.get_value_buffer(self.layer_id),
o.view(-1, self.tp_q_head_num, self.v_head_dim),
input_metadata.req_to_token_pool.req_to_token,
input_metadata.req_pool_indices,
input_metadata.triton_start_loc,
input_metadata.seq_lens,
input_metadata.triton_max_seq_len,
input_metadata.total_num_tokens,
sm_scale=self.scaling,
logit_cap=self.logit_cap,
)
return o
def extend_forward_flashinfer(self, q, k, v, input_metadata: InputMetadata):
# using two wrappers is unnecessary in the current PR, but are prepared for future PRs
prefill_wrapper_paged = input_metadata.flashinfer_prefill_wrapper_paged
if self.sliding_window_size != -1:
prefill_wrapper_paged = prefill_wrapper_paged[0]
else:
if isinstance(prefill_wrapper_paged, list):
prefill_wrapper_paged = prefill_wrapper_paged[1]
if not input_metadata.flashinfer_use_ragged:
if k is not None:
assert v is not None
self.store_kv_cache(k, v, input_metadata)
o = prefill_wrapper_paged.forward(
q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
input_metadata.token_to_kv_pool.get_kv_buffer(self.layer_id),
causal=True,
sm_scale=self.scaling,
window_left=self.sliding_window_size,
logits_soft_cap=self.logit_cap,
)
else:
o1, s1 = (
input_metadata.flashinfer_prefill_wrapper_ragged.forward_return_lse(
q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
k.contiguous().view(-1, self.tp_k_head_num, self.head_dim),
v.contiguous().view(-1, self.tp_v_head_num, self.head_dim),
causal=True,
sm_scale=self.scaling,
logits_soft_cap=self.logit_cap,
)
)
if input_metadata.extend_no_prefix:
o = o1
else:
o2, s2 = prefill_wrapper_paged.forward_return_lse(
q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
input_metadata.token_to_kv_pool.get_kv_buffer(self.layer_id),
causal=False,
sm_scale=self.scaling,
logits_soft_cap=self.logit_cap,
)
o, _ = merge_state(o1, s1, o2, s2)
self.store_kv_cache(k, v, input_metadata)
if input_metadata.total_num_tokens >= global_config.layer_sync_threshold:
torch.cuda.synchronize()
return o.view(-1, self.tp_q_head_num * self.head_dim)
def decode_forward_flashinfer(self, q, k, v, input_metadata: InputMetadata):
decode_wrapper = input_metadata.flashinfer_decode_wrapper
if self.sliding_window_size != -1:
decode_wrapper = decode_wrapper[0]
else:
if isinstance(decode_wrapper, list):
decode_wrapper = decode_wrapper[1]
if k is not None:
assert v is not None
self.store_kv_cache(k, v, input_metadata)
o = decode_wrapper.forward(
q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
input_metadata.token_to_kv_pool.get_kv_buffer(self.layer_id),
sm_scale=self.scaling,
logits_soft_cap=self.logit_cap,
)
return o.view(-1, self.tp_q_head_num * self.head_dim)
self.logit_cap = logit_cap
self.sliding_window_size = sliding_window_size or -1
def forward(self, q, k, v, input_metadata: InputMetadata):
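        # Kernel-specific logic now lives in the attention backend; this layer only
        # reshapes k/v and delegates to input_metadata.attn_backend.forward().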
if k is not None:
# For cross-layer sharing, kv can be None
assert v is not None
k = k.view(-1, self.tp_k_head_num, self.qk_head_dim)
v = v.view(-1, self.tp_v_head_num, self.v_head_dim)
if input_metadata.forward_mode.is_extend():
return self.extend_forward(q, k, v, input_metadata)
elif input_metadata.forward_mode.is_decode():
return self.decode_forward(q, k, v, input_metadata)
def store_kv_cache(self, cache_k, cache_v, input_metadata: InputMetadata):
input_metadata.token_to_kv_pool.set_kv_buffer(
self.layer_id, input_metadata.out_cache_loc, cache_k, cache_v
)
return input_metadata.attn_backend.forward(q, k, v, self, input_metadata)
......@@ -15,6 +15,7 @@ limitations under the License.
"""
Memory-efficient attention for decoding.
It supports page size = 1.
"""
# Adapted from
......@@ -197,7 +198,6 @@ def _decode_att_m_fwd(
logit_cap,
):
BLOCK = 32
# shape constraints
Lq, Lk = q.shape[-1], k_buffer.shape[-1]
batch, head_num = B_req_idx.shape[0], q.shape[1]
......@@ -478,7 +478,6 @@ def _decode_grouped_att_m_fwd(
logit_cap,
):
BLOCK = 32
# shape constraints
Lq, Lk = q.shape[-1], k_buffer.shape[-1]
if Lk == 576:
......@@ -570,9 +569,9 @@ def _decode_grouped_softmax_reducev_fwd(
BLOCK_DMODEL=BLOCK_DMODEL,
BLOCK_N=BLOCK,
BLOCK_H=BLOCK_H,
Lv=Lv,
num_warps=num_warps,
num_stages=1,
Lv=Lv,
)
......@@ -588,7 +587,7 @@ def decode_attention_fwd(
max_len_in_batch,
total_num_tokens,
sm_scale,
logit_cap=-1,
logit_cap=0.0,
att_m=None,
):
if att_m is None:
......
......@@ -61,14 +61,14 @@ def _fwd_kernel(
stride_buf_vbs,
stride_buf_vh,
stride_req_to_tokens_b,
logit_cap: tl.constexpr,
Lq: tl.constexpr,
Lv: tl.constexpr,
BLOCK_DMODEL: tl.constexpr,
BLOCK_DPE: tl.constexpr,
BLOCK_DV: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
logit_cap: tl.constexpr,
Lq: tl.constexpr,
Lv: tl.constexpr,
):
cur_seq = tl.program_id(0)
cur_head = tl.program_id(1)
......@@ -111,7 +111,7 @@ def _fwd_kernel(
)
qpe = tl.load(Q_Extend + offs_qpe, mask=mask_m[:, None], other=0.0)
# stage1: compute scores with prefix
# stage 1: compute scores with prefix
offs_n = tl.arange(0, BLOCK_N)
acc = tl.zeros([BLOCK_M, BLOCK_DV], dtype=tl.float32)
......@@ -174,7 +174,7 @@ def _fwd_kernel(
e_max = n_e_max
# stage2: compute the trianlge part
# stage 2: compute the triangle part
cur_block_m_end = tl.minimum(cur_seq_len_extend, (cur_block_m + 1) * BLOCK_M)
for start_n in range(0, cur_block_m_end, BLOCK_N):
......@@ -255,26 +255,22 @@ def extend_attention_fwd(
v_buffer,
req_to_tokens,
b_req_idx,
b_start_loc,
b_seq_len,
b_seq_len_prefix,
b_start_loc_extend,
b_seq_len_extend,
max_len_in_batch,
b_start_loc_extend,
max_len_extend,
sm_scale=None,
logit_cap=-1,
logit_cap=0.0,
):
"""
q_extend, k_extend, v_extend, o_extend: contiguous tensors
k_buffer, v_buffer: (prefix + extend) tensors in mem_manager
"""
Lq, Lk, Lv, Lo = (
Lq, Lk, Lv = (
q_extend.shape[-1],
k_extend.shape[-1],
v_extend.shape[-1],
o_extend.shape[-1],
)
if Lq == 576:
......@@ -303,7 +299,7 @@ def extend_attention_fwd(
else:
BLOCK_M, BLOCK_N = (64, 64) if Lq <= 128 else (32, 32)
sm_scale = 1.0 / (Lq**0.5) if sm_scale is None else sm_scale
sm_scale = sm_scale or 1.0 / (Lq**0.5)
batch_size, head_num = b_seq_len.shape[0], q_extend.shape[1]
kv_group_num = q_extend.shape[1] // k_extend.shape[1]
......@@ -338,27 +334,24 @@ def extend_attention_fwd(
v_buffer.stride(0),
v_buffer.stride(1),
req_to_tokens.stride(0),
logit_cap=logit_cap,
BLOCK_DMODEL=BLOCK_DMODEL,
BLOCK_DPE=BLOCK_DPE,
BLOCK_DV=BLOCK_DV,
BLOCK_M=BLOCK_M,
BLOCK_N=BLOCK_N,
num_warps=num_warps,
num_stages=num_stages,
logit_cap=logit_cap,
Lq=Lq,
Lv=Lv,
num_warps=num_warps,
num_stages=num_stages,
)
def redundant_attention(
q_extend,
k_extend,
v_extend,
o_extend,
k_buffer,
v_buffer,
req_to_tokens,
b_req_idx,
b_start_loc,
b_seq_len,
......
......@@ -368,7 +368,7 @@ class ScheduleBatch:
)
def batch_size(self):
return len(self.reqs) if self.reqs is not None else 0
return len(self.reqs) if self.reqs else 0
def is_empty(self):
return len(self.reqs) == 0
......
......@@ -13,15 +13,13 @@ See the License for the specific language governing permissions and
limitations under the License.
"""
"""Run the model with cuda graph."""
"""Run the model with cuda graph and torch.compile."""
import bisect
from contextlib import contextmanager
from typing import Callable, List
from typing import Callable
import torch
from flashinfer import BatchDecodeWithPagedKVCacheWrapper
from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
from vllm.distributed.parallel_state import graph_capture
from vllm.model_executor.custom_op import CustomOp
......@@ -55,6 +53,7 @@ def _to_torch(model: torch.nn.Module, reverse: bool = False):
def patch_model(
model: torch.nn.Module, enable_compile: bool, tp_group: "GroupCoordinator"
):
"""Patch the model to make it compatible with with torch.compile"""
backup_ca_comm = None
try:
......@@ -86,23 +85,28 @@ def set_torch_compile_config():
class CudaGraphRunner:
def __init__(
self,
model_runner: "ModelRunner",
max_batch_size_to_capture: int,
use_torch_compile: bool,
disable_padding: bool,
):
"""A CudaGraphRunner runs the forward pass of a model with cuda graph and torch.compile."""
def __init__(self, model_runner: "ModelRunner"):
# Parse args
self.model_runner = model_runner
self.graphs = {}
self.input_buffers = {}
self.output_buffers = {}
self.flashinfer_handlers = {}
self.graph_memory_pool = None
self.disable_padding = disable_padding
self.use_torch_compile = model_runner.server_args.enable_torch_compile
self.disable_padding = model_runner.server_args.disable_cuda_graph_padding
# Batch sizes to capture
if self.model_runner.server_args.disable_cuda_graph_padding:
self.capture_bs = list(range(1, 32)) + [64, 128]
else:
self.capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)]
self.compile_bs = [1, 2, 4, 8, 16, 24, 32] if self.use_torch_compile else []
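        # With padding enabled, replay() rounds a raw batch size up to the nearest
        # captured size (bisect over capture_bs); with --disable-cuda-graph-padding,
        # only batch sizes that were captured exactly can use the graph.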
# Common inputs
self.max_bs = max_batch_size_to_capture
self.max_bs = max(self.capture_bs)
self.input_ids = torch.zeros((self.max_bs,), dtype=torch.int32, device="cuda")
self.req_pool_indices = torch.zeros(
(self.max_bs,), dtype=torch.int32, device="cuda"
......@@ -115,56 +119,39 @@ class CudaGraphRunner:
(self.max_bs,), dtype=torch.int32, device="cuda"
)
# FlashInfer inputs
self.flashinfer_kv_indptr = torch.zeros(
(self.max_bs + 1,), dtype=torch.int32, device="cuda"
)
self.flashinfer_kv_indices = torch.zeros(
(self.max_bs * model_runner.model_config.context_len,),
dtype=torch.int32,
device="cuda",
)
self.flashinfer_kv_last_page_len = torch.ones(
(self.max_bs,), dtype=torch.int32, device="cuda"
)
if model_runner.sliding_window_size is None:
self.flashinfer_workspace_buffer = (
self.model_runner.flashinfer_workspace_buffer
)
else:
self.flashinfer_workspace_buffer = (
self.model_runner.flashinfer_workspace_buffer
)
self.flashinfer_kv_indptr = [
self.flashinfer_kv_indptr,
self.flashinfer_kv_indptr.clone(),
]
self.flashinfer_kv_indices = [
self.flashinfer_kv_indices,
self.flashinfer_kv_indices.clone(),
]
# Attention backend
self.model_runner.attn_backend.init_cuda_graph_state(self.max_bs)
# Sampling inputs
# Sampling info
vocab_size = model_runner.model_config.vocab_size
self.sampling_info = SamplingBatchInfo.dummy_one(self.max_bs, vocab_size)
self.compile_bs = [1, 2, 4, 8, 16, 24, 32] if use_torch_compile else []
if use_torch_compile:
if self.use_torch_compile:
set_torch_compile_config()
# Capture
try:
self.capture()
except RuntimeError as e:
raise Exception(
f"Capture cuda graph failed: {e}\n"
"Possible solutions:\n"
"1. disable cuda graph by --disable-cuda-graph\n"
"2. set --mem-fraction-static to a smaller value\n"
"3. disable torch compile by not using --enable-torch-compile\n"
"Open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose \n"
)
def can_run(self, batch_size: int):
if self.disable_padding:
return batch_size in self.graphs
else:
return batch_size <= self.max_bs
def capture(self, batch_size_list: List[int]):
self.batch_size_list = batch_size_list
def capture(self):
with graph_capture() as graph_capture_context:
self.stream = graph_capture_context.stream
for bs in batch_size_list:
for bs in self.capture_bs:
with patch_model(
self.model_runner.model,
bs in self.compile_bs,
......@@ -172,14 +159,10 @@ class CudaGraphRunner:
) as forward:
(
graph,
input_buffers,
output_buffers,
flashinfer_handler,
) = self.capture_one_batch_size(bs, forward)
self.graphs[bs] = graph
self.input_buffers[bs] = input_buffers
self.output_buffers[bs] = output_buffers
self.flashinfer_handlers[bs] = flashinfer_handler
def capture_one_batch_size(self, bs: int, forward: Callable):
graph = torch.cuda.CUDAGraph()
......@@ -192,48 +175,9 @@ class CudaGraphRunner:
position_ids_offsets = self.position_ids_offsets[:bs]
out_cache_loc = self.out_cache_loc[:bs]
# FlashInfer inputs
if not _grouped_size_compiled_for_decode_kernels(
self.model_runner.model_config.num_attention_heads
// self.model_runner.tp_size,
self.model_runner.model_config.get_num_kv_heads(self.model_runner.tp_size),
):
use_tensor_cores = True
else:
use_tensor_cores = False
if self.model_runner.sliding_window_size is None:
flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
self.flashinfer_workspace_buffer,
"NHD",
use_cuda_graph=True,
use_tensor_cores=use_tensor_cores,
paged_kv_indptr_buffer=self.flashinfer_kv_indptr[: bs + 1],
paged_kv_indices_buffer=self.flashinfer_kv_indices,
paged_kv_last_page_len_buffer=self.flashinfer_kv_last_page_len[:bs],
)
else:
flashinfer_decode_wrapper = []
for i in range(2):
flashinfer_decode_wrapper.append(
BatchDecodeWithPagedKVCacheWrapper(
self.flashinfer_workspace_buffer,
"NHD",
use_cuda_graph=True,
use_tensor_cores=use_tensor_cores,
paged_kv_indptr_buffer=self.flashinfer_kv_indptr[i][: bs + 1],
paged_kv_indices_buffer=self.flashinfer_kv_indices[i],
paged_kv_last_page_len_buffer=self.flashinfer_kv_last_page_len[
:bs
],
)
)
update_flashinfer_indices(
ForwardMode.DECODE,
self.model_runner,
req_pool_indices,
seq_lens,
None,
flashinfer_decode_wrapper,
# Attention backend
self.model_runner.attn_backend.capture_cuda_graph_init(
bs, req_pool_indices, seq_lens
)
# Run and capture
......@@ -246,13 +190,12 @@ class CudaGraphRunner:
seq_lens=seq_lens,
req_to_token_pool=self.model_runner.req_to_token_pool,
token_to_kv_pool=self.model_runner.token_to_kv_pool,
attn_backend=self.model_runner.attn_backend,
out_cache_loc=out_cache_loc,
return_logprob=False,
top_logprobs_nums=0,
positions=(seq_lens - 1 + position_ids_offsets).to(torch.int64),
flashinfer_decode_wrapper=flashinfer_decode_wrapper,
)
return forward(input_ids, input_metadata.positions, input_metadata)
for _ in range(2):
......@@ -274,15 +217,15 @@ class CudaGraphRunner:
self.model_runner.tp_group.barrier()
self.graph_memory_pool = graph.pool()
return graph, None, out, flashinfer_decode_wrapper
return graph, out
def replay(self, batch: ScheduleBatch):
assert batch.out_cache_loc is not None
raw_bs = len(batch.reqs)
# Pad
index = bisect.bisect_left(self.batch_size_list, raw_bs)
bs = self.batch_size_list[index]
index = bisect.bisect_left(self.capture_bs, raw_bs)
bs = self.capture_bs[index]
if bs != raw_bs:
self.seq_lens.zero_()
self.position_ids_offsets.fill_(1)
......@@ -295,14 +238,9 @@ class CudaGraphRunner:
self.position_ids_offsets[:raw_bs] = batch.position_ids_offsets
self.out_cache_loc[:raw_bs] = batch.out_cache_loc
# FlashInfer inputs
update_flashinfer_indices(
ForwardMode.DECODE,
self.model_runner,
self.req_pool_indices[:bs],
self.seq_lens[:bs],
None,
self.flashinfer_handlers[bs],
# Attention backend
self.model_runner.attn_backend.replay_cuda_graph_init(
bs, self.req_pool_indices, self.seq_lens
)
# Sampling inputs
......
......@@ -23,9 +23,8 @@ from typing import TYPE_CHECKING, List
import numpy as np
import torch
from sglang.srt.layers.flashinfer_utils import update_flashinfer_indices
if TYPE_CHECKING:
from sglang.srt.layers.attention_backend import AttentionBackend
from sglang.srt.managers.schedule_batch import ScheduleBatch
from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
from sglang.srt.model_executor.model_runner import ModelRunner
......@@ -66,12 +65,11 @@ class InputMetadata:
seq_lens: torch.Tensor
req_to_token_pool: ReqToTokenPool
token_to_kv_pool: BaseTokenToKVPool
attn_backend: AttentionBackend
# Output location of the KV cache
out_cache_loc: torch.Tensor
total_num_tokens: int = None
# Position information
positions: torch.Tensor = None
......@@ -93,18 +91,6 @@ class InputMetadata:
image_offsets: List[List[int]] = None
modalities: List[List[str]] = None
# Trition attention backend
triton_max_seq_len: int = 0
triton_max_extend_len: int = 0
triton_start_loc: torch.Tensor = None
triton_prefix_lens: torch.Tensor = None
# FlashInfer attention backend
flashinfer_prefill_wrapper_ragged: "BatchPrefillWithRaggedKVCacheWrapper" = None
flashinfer_prefill_wrapper_paged: "BatchPrefillWithPagedKVCacheWrapper" = None
flashinfer_decode_wrapper: "BatchDecodeWithPagedKVCacheWrapper" = None
flashinfer_use_ragged: bool = False
def init_multimuldal_info(self, batch: ScheduleBatch):
reqs = batch.reqs
self.pixel_values = [r.pixel_values for r in reqs]
......@@ -154,32 +140,27 @@ class InputMetadata:
self.positions = self.positions.to(torch.int64)
def compute_extend_infos(self, batch: ScheduleBatch):
if self.forward_mode.is_decode():
self.extend_seq_lens = self.extend_start_loc = self.extend_no_prefix = None
self.extend_seq_lens_cpu = self.logprob_start_lens_cpu = None
else:
extend_lens_cpu = [
len(r.fill_ids) - batch.prefix_lens_cpu[i]
for i, r in enumerate(batch.reqs)
]
self.extend_seq_lens = torch.tensor(extend_lens_cpu, device="cuda")
self.extend_prefix_lens = torch.tensor(batch.prefix_lens_cpu, device="cuda")
self.extend_start_loc = torch.zeros_like(self.seq_lens)
self.extend_start_loc[1:] = torch.cumsum(self.extend_seq_lens[:-1], dim=0)
self.extend_no_prefix = all(l == 0 for l in batch.prefix_lens_cpu)
self.extend_seq_lens_cpu = extend_lens_cpu
self.logprob_start_lens_cpu = [
(
min(
req.logprob_start_len - batch.prefix_lens_cpu[i],
extend_lens_cpu[i] - 1,
)
if req.logprob_start_len >= batch.prefix_lens_cpu[i]
else extend_lens_cpu[i] - 1 # Fake extend, actually decode
extend_lens_cpu = [
len(r.fill_ids) - batch.prefix_lens_cpu[i] for i, r in enumerate(batch.reqs)
]
self.extend_seq_lens = torch.tensor(extend_lens_cpu, device="cuda")
self.extend_prefix_lens = torch.tensor(batch.prefix_lens_cpu, device="cuda")
self.extend_start_loc = torch.zeros_like(self.seq_lens)
self.extend_start_loc[1:] = torch.cumsum(self.extend_seq_lens[:-1], dim=0)
self.extend_no_prefix = all(l == 0 for l in batch.prefix_lens_cpu)
self.extend_seq_lens_cpu = extend_lens_cpu
self.logprob_start_lens_cpu = [
(
min(
req.logprob_start_len - batch.prefix_lens_cpu[i],
extend_lens_cpu[i] - 1,
)
for i, req in enumerate(batch.reqs)
]
if req.logprob_start_len >= batch.prefix_lens_cpu[i]
else extend_lens_cpu[i] - 1 # Fake extend, actually decode
)
for i, req in enumerate(batch.reqs)
]
@classmethod
def from_schedule_batch(
......@@ -195,6 +176,7 @@ class InputMetadata:
seq_lens=batch.seq_lens,
req_to_token_pool=model_runner.req_to_token_pool,
token_to_kv_pool=model_runner.token_to_kv_pool,
attn_backend=model_runner.attn_backend,
out_cache_loc=batch.out_cache_loc,
return_logprob=batch.return_logprob,
top_logprobs_nums=batch.top_logprobs_nums,
......@@ -202,76 +184,12 @@ class InputMetadata:
ret.sampling_info.update_penalties()
ret.sampling_info.update_regex_vocab_mask(batch)
ret.compute_positions(batch)
ret.compute_extend_infos(batch)
fm = batch.forward_mode
if not fm.is_decode() or model_runner.server_args.attention_backend == "triton":
ret.total_num_tokens = int(torch.sum(ret.seq_lens))
if not fm.is_decode():
if not batch.forward_mode.is_decode():
ret.init_multimuldal_info(batch)
ret.compute_extend_infos(batch)
if model_runner.server_args.attention_backend == "triton":
ret.init_triton_args(batch)
flashinfer_use_ragged = False
if model_runner.server_args.attention_backend == "flashinfer":
if (
not fm.is_decode()
and int(torch.sum(ret.seq_lens)) > 4096
and model_runner.sliding_window_size is None
):
flashinfer_use_ragged = True
ret.init_flashinfer_handlers(
model_runner, batch.prefix_lens_cpu, flashinfer_use_ragged
)
model_runner.attn_backend.init_forward_metadata(batch, ret)
return ret
def init_triton_args(self, batch: ScheduleBatch):
"""Init auxiliary variables for triton attention backend."""
self.triton_max_seq_len = int(torch.max(self.seq_lens))
self.triton_start_loc = torch.zeros_like(self.seq_lens, dtype=torch.int32)
self.triton_start_loc[1:] = torch.cumsum(self.seq_lens[:-1], dim=0)
if self.forward_mode.is_decode():
self.triton_max_extend_len = None
else:
self.triton_prefix_lens = torch.tensor(batch.prefix_lens_cpu, device="cuda")
extend_seq_lens = self.seq_lens - self.triton_prefix_lens
self.triton_max_extend_len = int(torch.max(extend_seq_lens))
def init_flashinfer_handlers(
self,
model_runner,
prefix_lens_cpu,
flashinfer_use_ragged,
):
if self.forward_mode.is_decode():
prefix_lens = None
else:
prefix_lens = self.extend_prefix_lens
update_flashinfer_indices(
self.forward_mode,
model_runner,
self.req_pool_indices,
self.seq_lens,
prefix_lens,
flashinfer_use_ragged=flashinfer_use_ragged,
)
(
self.flashinfer_prefill_wrapper_ragged,
self.flashinfer_prefill_wrapper_paged,
self.flashinfer_decode_wrapper,
self.flashinfer_use_ragged,
) = (
model_runner.flashinfer_prefill_wrapper_ragged,
model_runner.flashinfer_prefill_wrapper_paged,
model_runner.flashinfer_decode_wrapper,
flashinfer_use_ragged,
)
......@@ -25,12 +25,6 @@ from typing import Optional, Tuple, Type
import torch
import torch.nn as nn
from flashinfer import (
BatchDecodeWithPagedKVCacheWrapper,
BatchPrefillWithPagedKVCacheWrapper,
BatchPrefillWithRaggedKVCacheWrapper,
)
from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
from vllm.config import DeviceConfig, LoadConfig
from vllm.config import ModelConfig as VllmModelConfig
from vllm.distributed import (
......@@ -43,8 +37,8 @@ from vllm.distributed.parallel_state import in_the_same_node_as
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.models import ModelRegistry
from sglang.global_config import global_config
from sglang.srt.configs.model_config import AttentionArch, ModelConfig
from sglang.srt.layers.attention_backend import FlashInferAttnBackend, TritonAttnBackend
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.layers.sampler import SampleOutput
from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict
......@@ -69,6 +63,8 @@ logger = logging.getLogger(__name__)
class ModelRunner:
"""ModelRunner runs the forward passes of the models."""
def __init__(
self,
model_config: ModelConfig,
......@@ -100,6 +96,7 @@ class ModelRunner:
}
)
# Model-specific adjustment
if self.is_multimodal_model:
logger.info(
"Automatically turn off --chunked-prefill-size and adjust --mem-fraction-static for multimodal models."
......@@ -107,6 +104,7 @@ class ModelRunner:
server_args.chunked_prefill_size = None
server_args.mem_fraction_static *= 0.95
# Init components
min_per_gpu_memory = self.init_torch_distributed()
self.load_model()
self.init_memory_pool(
......@@ -115,7 +113,7 @@ class ModelRunner:
server_args.max_total_tokens,
)
self.init_cublas()
self.init_flashinfer()
self.init_attention_backend()
self.init_cuda_graphs()
def init_torch_distributed(self):
......@@ -397,9 +395,6 @@ class ModelRunner:
qk_rope_head_dim=self.model_config.qk_rope_head_dim,
layer_num=self.model_config.num_hidden_layers,
)
logger.info("using MLA Triton implementaion, flashinfer is disabled")
# FIXME: temporarily only Triton MLA is supported
self.server_args.attention_backend = "triton"
else:
self.token_to_kv_pool = MHATokenToKVPool(
self.max_total_num_tokens,
......@@ -422,106 +417,42 @@ class ModelRunner:
c = a @ b
return c
def init_flashinfer(self):
"""Init flashinfer attention kernel wrappers."""
if self.server_args.attention_backend != "flashinfer":
assert (
self.sliding_window_size is None
), "turn on flashinfer to support window attention"
self.flashinfer_prefill_wrapper_ragged = None
self.flashinfer_prefill_wrapper_paged = None
self.flashinfer_decode_wrapper = None
return
if not _grouped_size_compiled_for_decode_kernels(
self.model_config.num_attention_heads // self.tp_size,
self.model_config.get_num_kv_heads(self.tp_size),
):
use_tensor_cores = True
else:
use_tensor_cores = False
if self.sliding_window_size is None:
self.flashinfer_workspace_buffer = torch.empty(
global_config.flashinfer_workspace_size,
dtype=torch.uint8,
device="cuda",
)
self.flashinfer_prefill_wrapper_ragged = (
BatchPrefillWithRaggedKVCacheWrapper(
self.flashinfer_workspace_buffer, "NHD"
)
)
self.flashinfer_prefill_wrapper_paged = BatchPrefillWithPagedKVCacheWrapper(
self.flashinfer_workspace_buffer, "NHD"
)
self.flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
self.flashinfer_workspace_buffer,
"NHD",
use_tensor_cores=use_tensor_cores,
def init_attention_backend(self):
"""Init attention kernel backend."""
if self.server_args.attention_backend == "flashinfer":
self.attn_backend = FlashInferAttnBackend(self)
elif self.server_args.attention_backend == "triton":
assert self.sliding_window_size is None, (
"Window attention is not supported in the triton attention backend. "
"Please use `--attention-backend flashinfer`."
)
self.attn_backend = TritonAttnBackend(self)
else:
self.flashinfer_workspace_buffer = torch.empty(
global_config.flashinfer_workspace_size,
dtype=torch.uint8,
device="cuda",
raise ValueError(
f"Invalid attention backend: {self.server_args.attention_backend}"
)
self.flashinfer_prefill_wrapper_ragged = None
self.flashinfer_prefill_wrapper_paged = []
self.flashinfer_decode_wrapper = []
for i in range(2):
self.flashinfer_prefill_wrapper_paged.append(
BatchPrefillWithPagedKVCacheWrapper(
self.flashinfer_workspace_buffer, "NHD"
)
)
self.flashinfer_decode_wrapper.append(
BatchDecodeWithPagedKVCacheWrapper(
self.flashinfer_workspace_buffer,
"NHD",
use_tensor_cores=use_tensor_cores,
)
)
def init_cuda_graphs(self):
"""Capture cuda graphs."""
from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
self.cuda_graph_runner = None
if not self.is_generation:
# TODO: Currently, cuda graph only captures decode steps, which only exist for generation models
return
from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
if self.server_args.disable_cuda_graph:
return
if (
self.server_args.disable_cuda_graph
or self.server_args.attention_backend != "flashinfer"
):
self.cuda_graph_runner = None
if self.server_args.attention_backend != "flashinfer":
logger.warning(
f"Cuda graph is not supported for attention backend: {self.server_args.attention_backend}"
)
return
logger.info("Capture cuda graph begin. This can take up to several minutes.")
if self.server_args.disable_cuda_graph_padding:
batch_size_list = list(range(1, 32)) + [64, 128]
else:
batch_size_list = [1, 2, 4] + [i * 8 for i in range(1, 21)]
self.cuda_graph_runner = CudaGraphRunner(
self,
max_batch_size_to_capture=max(batch_size_list),
use_torch_compile=self.server_args.enable_torch_compile,
disable_padding=self.server_args.disable_cuda_graph_padding,
)
try:
self.cuda_graph_runner.capture(batch_size_list)
except RuntimeError as e:
raise Exception(
f"Capture cuda graph failed: {e}\n"
"Possible solutions:\n"
"1. disable cuda graph by --disable-cuda-graph\n"
"2. set --mem-fraction-static to a smaller value\n"
"3. disable torch compile by not using --enable-torch-compile\n"
"Open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose \n"
)
self.cuda_graph_runner = CudaGraphRunner(self)
@torch.inference_mode()
def forward_decode(self, batch: ScheduleBatch):
......
......@@ -143,18 +143,16 @@ class SamplingBatchInfo:
self.linear_penalties = penalizer.apply(self.linear_penalties)
def update_regex_vocab_mask(self, batch: ScheduleBatch):
bs, reqs = batch.batch_size(), batch.reqs
device = "cuda"
has_regex = any(req.regex_fsm is not None for req in reqs)
has_regex = any(req.regex_fsm is not None for req in batch.reqs)
# Reset the vocab mask
self.vocab_mask = None
if has_regex:
self.vocab_mask = torch.zeros(
bs, self.vocab_size, dtype=torch.bool, device=device
batch.batch_size(), self.vocab_size, dtype=torch.bool, device="cuda"
)
for i, req in enumerate(reqs):
for i, req in enumerate(batch.reqs):
if req.regex_fsm is not None:
self.vocab_mask[i].fill_(1)
self.vocab_mask[i][
......
......@@ -335,23 +335,19 @@ def launch_server(
return
# Launch processes
tokenizer_manager = TokenizerManager(server_args, port_args)
if server_args.chat_template:
load_chat_template_for_openai_api(tokenizer_manager, server_args.chat_template)
pipe_controller_reader, pipe_controller_writer = mp.Pipe(duplex=False)
pipe_detoken_reader, pipe_detoken_writer = mp.Pipe(duplex=False)
if server_args.dp_size == 1:
start_controller_process = start_controller_process_single
else:
start_controller_process = start_controller_process_multi
proc_controller = mp.Process(
target=start_controller_process,
args=(server_args, port_args, pipe_controller_writer),
)
proc_controller.start()
pipe_detoken_reader, pipe_detoken_writer = mp.Pipe(duplex=False)
proc_detoken = mp.Process(
target=start_detokenizer_process,
args=(
......@@ -362,6 +358,10 @@ def launch_server(
)
proc_detoken.start()
tokenizer_manager = TokenizerManager(server_args, port_args)
if server_args.chat_template:
load_chat_template_for_openai_api(tokenizer_manager, server_args.chat_template)
# Wait for the model to finish loading
controller_init_state = pipe_controller_reader.recv()
detoken_init_state = pipe_detoken_reader.recv()
......
......@@ -83,8 +83,8 @@ class ServerArgs:
json_model_override_args: str = "{}"
# Optimization/debug options
attention_backend: str = "flashinfer"
sampling_backend: str = "flashinfer"
attention_backend: Optional[str] = None
sampling_backend: Optional[str] = None
disable_flashinfer: bool = False
disable_flashinfer_sampling: bool = False
......@@ -148,6 +148,17 @@ class ServerArgs:
)
self.sampling_backend = "pytorch"
# Default kernel backends
if self.enable_mla:
logger.info("MLA optimization is tunred on. Use triton backend.")
self.attention_backend = "triton"
if self.attention_backend is None:
self.attention_backend = "flashinfer"
if self.sampling_backend is None:
self.sampling_backend = "flashinfer"
# Model-specific patches
if "Alibaba-NLP/gte-Qwen2-1.5B-instruct" == self.model_path:
logger.info(
......
......@@ -55,8 +55,8 @@ class TestCreateKvIndices(unittest.TestCase):
paged_kernel_lens,
kv_indptr,
None,
req_to_token.size(1),
kv_indices_triton,
req_to_token.size(1),
)
# Check
......
......@@ -19,7 +19,8 @@ class TestServingThroughput(unittest.TestCase):
other_args = []
if disable_radix_cache:
other_args.append("--disable-radix-cache")
other_args.extend(["--attention-backend", attention_backend])
if attention_backend:
other_args.extend(["--attention-backend", attention_backend])
other_args.extend(["--chunked-prefill-size", str(chunked_prefill_size)])
other_args.extend(["--tensor-parallel-size", "2"])
......
......@@ -19,7 +19,8 @@ class TestServingThroughput(unittest.TestCase):
other_args = []
if disable_radix_cache:
other_args.append("--disable-radix-cache")
other_args.extend(["--attention-backend", attention_backend])
if attention_backend:
other_args.extend(["--attention-backend", attention_backend])
other_args.extend(["--chunked-prefill-size", str(chunked_prefill_size)])
model = DEFAULT_MODEL_NAME_FOR_TEST
......
......@@ -96,23 +96,17 @@ class TestExtendAttention(unittest.TestCase):
v_buffer,
req_to_tokens,
b_req_idx,
b_start_loc,
b_seq_len,
b_seq_len_prefix,
b_start_loc_extend,
b_seq_len_extend,
max_len_in_batch,
b_start_loc_extend,
max_len_extend,
)
redundant_attention(
q_extend,
k_extend,
v_extend,
o_redundant,
k_buffer,
v_buffer,
req_to_tokens,
b_req_idx,
b_start_loc,
b_seq_len,
......