Unverified Commit 3efa7981 authored by Lianmin Zheng's avatar Lianmin Zheng Committed by GitHub
Browse files

Support cuda graph in the triton attention backend (#1401)

parent 2a71be5e
......@@ -36,14 +36,41 @@ class AttentionBackend(ABC):
def init_forward_metadata(
self, batch: ScheduleBatch, input_metadata: InputMetadata
):
pass
"""Init the metadata for a forward pass."""
raise NotImplementedError()
def forward(self, q, k, v, layer, input_metadata: InputMetadata):
def init_cuda_graph_state(self, max_bs: int):
"""Init the global shared states for cuda graph."""
raise NotImplementedError()
def init_forward_metadata_capture_cuda_graph(
self, bs: int, req_pool_indices, seq_lens
):
"""Init the metadata for a forward pass for capturing a cuda graph."""
raise NotImplementedError()
def init_forward_metadata_replay_cuda_graph(
self, bs: int, req_pool_indices, seq_lens
):
"""Init the metadata for a forward pass for replying a cuda graph."""
raise NotImplementedError()
def get_cuda_graph_seq_len_fill_value(self):
raise NotImplementedError()
def forward(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata):
"""Run forward on an attention layer."""
if input_metadata.forward_mode.is_decode():
return self.forward_decode(q, k, v, layer, input_metadata)
else:
return self.forward_extend(q, k, v, layer, input_metadata)
def forward_decode(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata):
raise NotImplementedError()
def forward_extend(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata):
raise NotImplementedError()
class FlashInferAttnBackend(AttentionBackend):
"""Flashinfer attention kernels."""
......@@ -153,7 +180,9 @@ class FlashInferAttnBackend(AttentionBackend):
self.cuda_graph_kv_indices.clone(),
]
def capture_cuda_graph_init(self, bs: int, req_pool_indices, seq_lens):
def init_forward_metadata_capture_cuda_graph(
self, bs: int, req_pool_indices, seq_lens
):
if self.model_runner.sliding_window_size is None:
decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
self.workspace_buffer,
......@@ -194,7 +223,9 @@ class FlashInferAttnBackend(AttentionBackend):
self.forward_metadata = (False, None, decode_wrapper)
def replay_cuda_graph_init(self, bs: int, req_pool_indices, seq_lens):
def init_forward_metadata_replay_cuda_graph(
self, bs: int, req_pool_indices, seq_lens
):
update_flashinfer_indices(
ForwardMode.DECODE,
self.model_runner,
......@@ -204,6 +235,9 @@ class FlashInferAttnBackend(AttentionBackend):
self.cuda_graph_metadata[bs],
)
def get_cuda_graph_seq_len_fill_value(self):
return 0
def forward_extend(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata):
if not isinstance(self.prefill_wrapper_paged, list):
prefill_wrapper_paged = self.prefill_wrapper_paged
......@@ -290,6 +324,7 @@ class TritonAttnBackend(AttentionBackend):
def __init__(self, model_runner: ModelRunner):
# Lazy import to avoid the initialization of cuda context
from sglang.srt.layers.triton_attention.decode_attention import (
REDUCE_TORCH_TYPE,
decode_attention_fwd,
)
from sglang.srt.layers.triton_attention.extend_attention import (
......@@ -300,29 +335,78 @@ class TritonAttnBackend(AttentionBackend):
self.decode_attention_fwd = decode_attention_fwd
self.extend_attention_fwd = extend_attention_fwd
self.REDUCE_TORCH_TYPE = REDUCE_TORCH_TYPE
self.num_head = model_runner.model_config.num_attention_heads
self.forward_metadata = None
self.cuda_graph_max_seq_len = model_runner.model_config.context_len
def init_forward_metadata(
self, batch: ScheduleBatch, input_metadata: InputMetadata
):
"""Init auxiliary variables for triton attention backend."""
if input_metadata.forward_mode.is_decode():
max_seq_len = torch.max(input_metadata.seq_lens).item()
start_loc = torch.zeros_like(input_metadata.seq_lens, dtype=torch.int32)
start_loc[1:] = torch.cumsum(input_metadata.seq_lens[:-1], dim=0)
total_num_tokens = torch.sum(input_metadata.seq_lens).item()
attn_logits = torch.empty(
(self.num_head, total_num_tokens),
dtype=self.REDUCE_TORCH_TYPE,
device="cuda",
)
max_seq_len = torch.max(input_metadata.seq_lens).item()
max_extend_len = None
else:
start_loc = max_seq_len = total_num_tokens = None
start_loc = attn_logits = max_seq_len = None
prefix_lens = torch.tensor(batch.prefix_lens_cpu, device="cuda")
max_extend_len = torch.max(input_metadata.seq_lens - prefix_lens).item()
self.forward_metadata = start_loc, max_seq_len, max_extend_len, total_num_tokens
self.forward_metadata = start_loc, attn_logits, max_seq_len, max_extend_len
def init_cuda_graph_state(self, max_bs: int):
self.cuda_graph_max_total_num_tokens = max_bs * self.cuda_graph_max_seq_len
self.cuda_graph_start_loc = torch.zeros(
(max_bs,), dtype=torch.int32, device="cuda"
)
self.cuda_graph_attn_logits = torch.empty(
(self.num_head, self.cuda_graph_max_total_num_tokens),
dtype=self.REDUCE_TORCH_TYPE,
device="cuda",
)
def init_forward_metadata_capture_cuda_graph(
self, bs: int, req_pool_indices, seq_lens
):
self.forward_metadata = (
self.cuda_graph_start_loc,
self.cuda_graph_attn_logits,
self.cuda_graph_max_seq_len,
None,
)
def init_forward_metadata_replay_cuda_graph(
self, bs: int, req_pool_indices, seq_lens
):
self.cuda_graph_start_loc.zero_()
self.cuda_graph_start_loc[1:bs] = torch.cumsum(seq_lens[: bs - 1], dim=0)
self.forward_metadata = (
self.cuda_graph_start_loc,
self.cuda_graph_attn_logits,
self.cuda_graph_max_seq_len,
None,
)
def get_cuda_graph_seq_len_fill_value(self):
return 1
def forward_extend(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata):
# TODO: reuse the buffer across layers
if layer.qk_head_dim != layer.v_head_dim:
o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
else:
......@@ -332,8 +416,7 @@ class TritonAttnBackend(AttentionBackend):
layer.layer_id, input_metadata.out_cache_loc, k, v
)
start_loc, max_seq_len, max_extend_len, total_num_tokens = self.forward_metadata
start_loc, attn_logits, max_seq_len, max_extend_len = self.forward_metadata
self.extend_attention_fwd(
q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
k.contiguous(),
......@@ -350,16 +433,16 @@ class TritonAttnBackend(AttentionBackend):
layer.scaling,
layer.logit_cap,
)
return o
def forward_decode(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata):
# TODO: reuse the buffer across layers
if layer.qk_head_dim != layer.v_head_dim:
o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
else:
o = torch.empty_like(q)
start_loc, max_seq_len, max_extend_len, total_num_tokens = self.forward_metadata
start_loc, attn_logits, max_seq_len, max_extend_len = self.forward_metadata
input_metadata.token_to_kv_pool.set_kv_buffer(
layer.layer_id, input_metadata.out_cache_loc, k, v
......@@ -374,10 +457,9 @@ class TritonAttnBackend(AttentionBackend):
input_metadata.req_pool_indices,
start_loc,
input_metadata.seq_lens,
attn_logits,
max_seq_len,
total_num_tokens,
layer.scaling,
layer.logit_cap,
)
return o
......@@ -66,18 +66,18 @@ class FlashinferUpdater:
self.head_dim = model_runner.model_config.head_dim
self.batch_size = len(req_pool_indices)
self.kv_last_page_len = torch.ones(
(self.batch_size,), dtype=torch.int32, device="cuda"
self.decode_wrapper = (
decode_wrapper or self.model_runner.attn_backend.decode_wrapper
)
self.prefill_wrapper_ragged = (
self.model_runner.attn_backend.prefill_wrapper_ragged
)
self.prefill_wrapper_paged = (
self.model_runner.attn_backend.prefill_wrapper_paged
)
(
self.decode_wrapper,
self.prefill_wrapper_ragged,
self.prefill_wrapper_paged,
) = (
decode_wrapper or self.model_runner.attn_backend.decode_wrapper,
self.model_runner.attn_backend.prefill_wrapper_ragged,
self.model_runner.attn_backend.prefill_wrapper_paged,
self.kv_last_page_len = torch.ones(
(self.batch_size,), dtype=torch.int32, device="cuda"
)
def _init_indices_no_sliding_window(self):
......
......@@ -114,7 +114,7 @@ def _fwd_kernel_stage1(
@triton.jit
def _fwd_kernel_stage2(
Logics,
logits,
V_Buffer,
Out,
Req_to_tokens,
......@@ -162,7 +162,7 @@ def _fwd_kernel_stage2(
)
qk = tl.load(
Logics
logits
+ cur_head * stride_logic_h
+ (cur_batch_start_loc + start_n + offs_n),
mask=start_n + offs_n < cur_batch_seq_len,
......@@ -238,7 +238,7 @@ def _decode_att_m_fwd(
def _decode_softmax_reducev_fwd(
logics,
logits,
v_buffer,
o,
req_to_tokens,
......@@ -247,9 +247,9 @@ def _decode_softmax_reducev_fwd(
b_seq_len,
):
BLOCK = 64
batch, head = b_seq_len.shape[0], logics.shape[0]
batch, head = b_seq_len.shape[0], logits.shape[0]
grid = (batch, head, 1)
kv_group_num = logics.shape[0] // v_buffer.shape[1]
kv_group_num = logits.shape[0] // v_buffer.shape[1]
num_warps = 1
......@@ -257,14 +257,14 @@ def _decode_softmax_reducev_fwd(
BLOCK_DMODEL = triton.next_power_of_2(Lv)
_fwd_kernel_stage2[grid](
logics,
logits,
v_buffer,
o,
req_to_tokens,
b_req_idx,
b_start_loc,
b_seq_len,
logics.stride(0),
logits.stride(0),
v_buffer.stride(0),
v_buffer.stride(1),
o.stride(0),
......@@ -387,7 +387,7 @@ def _fwd_grouped_kernel_stage1(
@triton.jit
def _fwd_grouped_kernel_stage2(
Logics,
logits,
V_Buffer,
Out,
Req_to_tokens,
......@@ -443,7 +443,7 @@ def _fwd_grouped_kernel_stage2(
)
qk = tl.load(
Logics + offs_qk,
logits + offs_qk,
mask=mask_h[:, None] & (start_n + offs_n[None, :] < cur_batch_seq_len),
other=float("-inf"),
)
......@@ -531,7 +531,7 @@ def _decode_grouped_att_m_fwd(
def _decode_grouped_softmax_reducev_fwd(
logics,
logits,
v_buffer,
o,
req_to_tokens,
......@@ -540,8 +540,8 @@ def _decode_grouped_softmax_reducev_fwd(
b_seq_len,
):
BLOCK = 128
batch, head_num = b_seq_len.shape[0], logics.shape[0]
kv_group_num = logics.shape[0] // v_buffer.shape[1]
batch, head_num = b_seq_len.shape[0], logits.shape[0]
kv_group_num = logits.shape[0] // v_buffer.shape[1]
BLOCK_H = max(16, triton.next_power_of_2(kv_group_num))
grid = (batch, triton.cdiv(head_num, min(BLOCK_H, kv_group_num)), 1)
......@@ -551,14 +551,14 @@ def _decode_grouped_softmax_reducev_fwd(
BLOCK_DMODEL = triton.next_power_of_2(Lv)
_fwd_grouped_kernel_stage2[grid](
logics,
logits,
v_buffer,
o,
req_to_tokens,
b_req_idx,
b_start_loc,
b_seq_len,
logics.stride(0),
logits.stride(0),
v_buffer.stride(0),
v_buffer.stride(1),
o.stride(0),
......@@ -584,17 +584,11 @@ def decode_attention_fwd(
b_req_idx,
b_start_loc,
b_seq_len,
attn_logits,
max_len_in_batch,
total_num_tokens,
sm_scale,
logit_cap=0.0,
att_m=None,
):
if att_m is None:
att_m = torch.empty(
(q.shape[-2], total_num_tokens), dtype=REDUCE_TORCH_TYPE, device="cuda"
)
kv_group_num = q.shape[1] // v_buffer.shape[1]
if kv_group_num == 1:
......@@ -602,7 +596,7 @@ def decode_attention_fwd(
_decode_att_m_fwd(
q,
k_buffer,
att_m,
attn_logits,
req_to_token,
b_req_idx,
b_start_loc,
......@@ -612,7 +606,7 @@ def decode_attention_fwd(
logit_cap,
)
_decode_softmax_reducev_fwd(
att_m,
attn_logits,
v_buffer,
o,
req_to_token,
......@@ -625,7 +619,7 @@ def decode_attention_fwd(
_decode_grouped_att_m_fwd(
q,
k_buffer,
att_m,
attn_logits,
req_to_token,
b_req_idx,
b_start_loc,
......@@ -635,7 +629,7 @@ def decode_attention_fwd(
logit_cap,
)
_decode_grouped_softmax_reducev_fwd(
att_m,
attn_logits,
v_buffer,
o,
req_to_token,
......
from __future__ import annotations
"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
......@@ -17,13 +19,12 @@ limitations under the License.
import bisect
from contextlib import contextmanager
from typing import Callable
from typing import TYPE_CHECKING, Callable
import torch
from vllm.distributed.parallel_state import graph_capture
from vllm.model_executor.custom_op import CustomOp
from sglang.srt.layers.flashinfer_utils import update_flashinfer_indices
from sglang.srt.layers.logits_processor import (
LogitsMetadata,
LogitsProcessor,
......@@ -35,6 +36,9 @@ from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetad
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.utils import monkey_patch_vllm_all_gather
if TYPE_CHECKING:
from sglang.srt.model_executor.model_runner import ModelRunner
def _to_torch(model: torch.nn.Module, reverse: bool = False):
for sub in model._modules.values():
......@@ -111,7 +115,7 @@ class CudaGraphRunner:
self.req_pool_indices = torch.zeros(
(self.max_bs,), dtype=torch.int32, device="cuda"
)
self.seq_lens = torch.zeros((self.max_bs,), dtype=torch.int32, device="cuda")
self.seq_lens = torch.ones((self.max_bs,), dtype=torch.int32, device="cuda")
self.position_ids_offsets = torch.ones(
(self.max_bs,), dtype=torch.int32, device="cuda"
)
......@@ -121,6 +125,9 @@ class CudaGraphRunner:
# Attention backend
self.model_runner.attn_backend.init_cuda_graph_state(self.max_bs)
self.seq_len_fill_value = (
self.model_runner.attn_backend.get_cuda_graph_seq_len_fill_value()
)
# Sampling info
vocab_size = model_runner.model_config.vocab_size
......@@ -176,7 +183,7 @@ class CudaGraphRunner:
out_cache_loc = self.out_cache_loc[:bs]
# Attention backend
self.model_runner.attn_backend.capture_cuda_graph_init(
self.model_runner.attn_backend.init_forward_metadata_capture_cuda_graph(
bs, req_pool_indices, seq_lens
)
......@@ -227,7 +234,7 @@ class CudaGraphRunner:
index = bisect.bisect_left(self.capture_bs, raw_bs)
bs = self.capture_bs[index]
if bs != raw_bs:
self.seq_lens.zero_()
self.seq_lens.fill_(self.seq_len_fill_value)
self.position_ids_offsets.fill_(1)
self.out_cache_loc.zero_()
......@@ -239,7 +246,7 @@ class CudaGraphRunner:
self.out_cache_loc[:raw_bs] = batch.out_cache_loc
# Attention backend
self.model_runner.attn_backend.replay_cuda_graph_init(
self.model_runner.attn_backend.init_forward_metadata_replay_cuda_graph(
bs, self.req_pool_indices, self.seq_lens
)
......
......@@ -445,12 +445,6 @@ class ModelRunner:
if self.server_args.disable_cuda_graph:
return
if self.server_args.attention_backend != "flashinfer":
logger.warning(
f"Cuda graph is not supported for attention backend: {self.server_args.attention_backend}"
)
return
logger.info("Capture cuda graph begin. This can take up to several minutes.")
self.cuda_graph_runner = CudaGraphRunner(self)
......
......@@ -96,6 +96,16 @@ class TestServingThroughput(unittest.TestCase):
if os.getenv("SGLANG_IS_IN_CI", "false") == "true":
assert res["output_throughput"] > 2400
def test_default_with_triton_attention_backend(self):
res = self.run_test(
disable_radix_cache=ServerArgs.disable_radix_cache,
attention_backend="triton",
chunked_prefill_size=-1,
)
if os.getenv("SGLANG_IS_IN_CI", "false") == "true":
assert res["output_throughput"] > 2400
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment