"src/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "f0c6d9784b6b5ec01e3c3a3795d22680567429aa"
Unverified Commit 3efa7981 authored by Lianmin Zheng's avatar Lianmin Zheng Committed by GitHub
Browse files

Support cuda graph in the triton attention backend (#1401)

parent 2a71be5e
...@@ -36,14 +36,41 @@ class AttentionBackend(ABC): ...@@ -36,14 +36,41 @@ class AttentionBackend(ABC):
def init_forward_metadata( def init_forward_metadata(
self, batch: ScheduleBatch, input_metadata: InputMetadata self, batch: ScheduleBatch, input_metadata: InputMetadata
): ):
pass """Init the metadata for a forward pass."""
raise NotImplementedError()
def forward(self, q, k, v, layer, input_metadata: InputMetadata): def init_cuda_graph_state(self, max_bs: int):
"""Init the global shared states for cuda graph."""
raise NotImplementedError()
def init_forward_metadata_capture_cuda_graph(
self, bs: int, req_pool_indices, seq_lens
):
"""Init the metadata for a forward pass for capturing a cuda graph."""
raise NotImplementedError()
def init_forward_metadata_replay_cuda_graph(
self, bs: int, req_pool_indices, seq_lens
):
"""Init the metadata for a forward pass for replying a cuda graph."""
raise NotImplementedError()
def get_cuda_graph_seq_len_fill_value(self):
raise NotImplementedError()
def forward(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata):
"""Run forward on an attention layer."""
if input_metadata.forward_mode.is_decode(): if input_metadata.forward_mode.is_decode():
return self.forward_decode(q, k, v, layer, input_metadata) return self.forward_decode(q, k, v, layer, input_metadata)
else: else:
return self.forward_extend(q, k, v, layer, input_metadata) return self.forward_extend(q, k, v, layer, input_metadata)
def forward_decode(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata):
raise NotImplementedError()
def forward_extend(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata):
raise NotImplementedError()
class FlashInferAttnBackend(AttentionBackend): class FlashInferAttnBackend(AttentionBackend):
"""Flashinfer attention kernels.""" """Flashinfer attention kernels."""
...@@ -153,7 +180,9 @@ class FlashInferAttnBackend(AttentionBackend): ...@@ -153,7 +180,9 @@ class FlashInferAttnBackend(AttentionBackend):
self.cuda_graph_kv_indices.clone(), self.cuda_graph_kv_indices.clone(),
] ]
def capture_cuda_graph_init(self, bs: int, req_pool_indices, seq_lens): def init_forward_metadata_capture_cuda_graph(
self, bs: int, req_pool_indices, seq_lens
):
if self.model_runner.sliding_window_size is None: if self.model_runner.sliding_window_size is None:
decode_wrapper = BatchDecodeWithPagedKVCacheWrapper( decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
self.workspace_buffer, self.workspace_buffer,
...@@ -194,7 +223,9 @@ class FlashInferAttnBackend(AttentionBackend): ...@@ -194,7 +223,9 @@ class FlashInferAttnBackend(AttentionBackend):
self.forward_metadata = (False, None, decode_wrapper) self.forward_metadata = (False, None, decode_wrapper)
def replay_cuda_graph_init(self, bs: int, req_pool_indices, seq_lens): def init_forward_metadata_replay_cuda_graph(
self, bs: int, req_pool_indices, seq_lens
):
update_flashinfer_indices( update_flashinfer_indices(
ForwardMode.DECODE, ForwardMode.DECODE,
self.model_runner, self.model_runner,
...@@ -204,6 +235,9 @@ class FlashInferAttnBackend(AttentionBackend): ...@@ -204,6 +235,9 @@ class FlashInferAttnBackend(AttentionBackend):
self.cuda_graph_metadata[bs], self.cuda_graph_metadata[bs],
) )
def get_cuda_graph_seq_len_fill_value(self):
return 0
def forward_extend(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata): def forward_extend(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata):
if not isinstance(self.prefill_wrapper_paged, list): if not isinstance(self.prefill_wrapper_paged, list):
prefill_wrapper_paged = self.prefill_wrapper_paged prefill_wrapper_paged = self.prefill_wrapper_paged
...@@ -290,6 +324,7 @@ class TritonAttnBackend(AttentionBackend): ...@@ -290,6 +324,7 @@ class TritonAttnBackend(AttentionBackend):
def __init__(self, model_runner: ModelRunner): def __init__(self, model_runner: ModelRunner):
# Lazy import to avoid the initialization of cuda context # Lazy import to avoid the initialization of cuda context
from sglang.srt.layers.triton_attention.decode_attention import ( from sglang.srt.layers.triton_attention.decode_attention import (
REDUCE_TORCH_TYPE,
decode_attention_fwd, decode_attention_fwd,
) )
from sglang.srt.layers.triton_attention.extend_attention import ( from sglang.srt.layers.triton_attention.extend_attention import (
...@@ -300,29 +335,78 @@ class TritonAttnBackend(AttentionBackend): ...@@ -300,29 +335,78 @@ class TritonAttnBackend(AttentionBackend):
self.decode_attention_fwd = decode_attention_fwd self.decode_attention_fwd = decode_attention_fwd
self.extend_attention_fwd = extend_attention_fwd self.extend_attention_fwd = extend_attention_fwd
self.REDUCE_TORCH_TYPE = REDUCE_TORCH_TYPE
self.num_head = model_runner.model_config.num_attention_heads
self.forward_metadata = None self.forward_metadata = None
self.cuda_graph_max_seq_len = model_runner.model_config.context_len
def init_forward_metadata( def init_forward_metadata(
self, batch: ScheduleBatch, input_metadata: InputMetadata self, batch: ScheduleBatch, input_metadata: InputMetadata
): ):
"""Init auxiliary variables for triton attention backend.""" """Init auxiliary variables for triton attention backend."""
if input_metadata.forward_mode.is_decode(): if input_metadata.forward_mode.is_decode():
max_seq_len = torch.max(input_metadata.seq_lens).item()
start_loc = torch.zeros_like(input_metadata.seq_lens, dtype=torch.int32) start_loc = torch.zeros_like(input_metadata.seq_lens, dtype=torch.int32)
start_loc[1:] = torch.cumsum(input_metadata.seq_lens[:-1], dim=0) start_loc[1:] = torch.cumsum(input_metadata.seq_lens[:-1], dim=0)
total_num_tokens = torch.sum(input_metadata.seq_lens).item() total_num_tokens = torch.sum(input_metadata.seq_lens).item()
attn_logits = torch.empty(
(self.num_head, total_num_tokens),
dtype=self.REDUCE_TORCH_TYPE,
device="cuda",
)
max_seq_len = torch.max(input_metadata.seq_lens).item()
max_extend_len = None max_extend_len = None
else: else:
start_loc = max_seq_len = total_num_tokens = None start_loc = attn_logits = max_seq_len = None
prefix_lens = torch.tensor(batch.prefix_lens_cpu, device="cuda") prefix_lens = torch.tensor(batch.prefix_lens_cpu, device="cuda")
max_extend_len = torch.max(input_metadata.seq_lens - prefix_lens).item() max_extend_len = torch.max(input_metadata.seq_lens - prefix_lens).item()
self.forward_metadata = start_loc, max_seq_len, max_extend_len, total_num_tokens self.forward_metadata = start_loc, attn_logits, max_seq_len, max_extend_len
def init_cuda_graph_state(self, max_bs: int):
self.cuda_graph_max_total_num_tokens = max_bs * self.cuda_graph_max_seq_len
self.cuda_graph_start_loc = torch.zeros(
(max_bs,), dtype=torch.int32, device="cuda"
)
self.cuda_graph_attn_logits = torch.empty(
(self.num_head, self.cuda_graph_max_total_num_tokens),
dtype=self.REDUCE_TORCH_TYPE,
device="cuda",
)
def init_forward_metadata_capture_cuda_graph(
self, bs: int, req_pool_indices, seq_lens
):
self.forward_metadata = (
self.cuda_graph_start_loc,
self.cuda_graph_attn_logits,
self.cuda_graph_max_seq_len,
None,
)
def init_forward_metadata_replay_cuda_graph(
self, bs: int, req_pool_indices, seq_lens
):
self.cuda_graph_start_loc.zero_()
self.cuda_graph_start_loc[1:bs] = torch.cumsum(seq_lens[: bs - 1], dim=0)
self.forward_metadata = (
self.cuda_graph_start_loc,
self.cuda_graph_attn_logits,
self.cuda_graph_max_seq_len,
None,
)
def get_cuda_graph_seq_len_fill_value(self):
return 1
def forward_extend(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata): def forward_extend(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata):
# TODO: reuse the buffer across layers
if layer.qk_head_dim != layer.v_head_dim: if layer.qk_head_dim != layer.v_head_dim:
o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim)) o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
else: else:
...@@ -332,8 +416,7 @@ class TritonAttnBackend(AttentionBackend): ...@@ -332,8 +416,7 @@ class TritonAttnBackend(AttentionBackend):
layer.layer_id, input_metadata.out_cache_loc, k, v layer.layer_id, input_metadata.out_cache_loc, k, v
) )
start_loc, max_seq_len, max_extend_len, total_num_tokens = self.forward_metadata start_loc, attn_logits, max_seq_len, max_extend_len = self.forward_metadata
self.extend_attention_fwd( self.extend_attention_fwd(
q.view(-1, layer.tp_q_head_num, layer.qk_head_dim), q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
k.contiguous(), k.contiguous(),
...@@ -350,16 +433,16 @@ class TritonAttnBackend(AttentionBackend): ...@@ -350,16 +433,16 @@ class TritonAttnBackend(AttentionBackend):
layer.scaling, layer.scaling,
layer.logit_cap, layer.logit_cap,
) )
return o return o
def forward_decode(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata): def forward_decode(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata):
# TODO: reuse the buffer across layers
if layer.qk_head_dim != layer.v_head_dim: if layer.qk_head_dim != layer.v_head_dim:
o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim)) o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
else: else:
o = torch.empty_like(q) o = torch.empty_like(q)
start_loc, max_seq_len, max_extend_len, total_num_tokens = self.forward_metadata start_loc, attn_logits, max_seq_len, max_extend_len = self.forward_metadata
input_metadata.token_to_kv_pool.set_kv_buffer( input_metadata.token_to_kv_pool.set_kv_buffer(
layer.layer_id, input_metadata.out_cache_loc, k, v layer.layer_id, input_metadata.out_cache_loc, k, v
...@@ -374,10 +457,9 @@ class TritonAttnBackend(AttentionBackend): ...@@ -374,10 +457,9 @@ class TritonAttnBackend(AttentionBackend):
input_metadata.req_pool_indices, input_metadata.req_pool_indices,
start_loc, start_loc,
input_metadata.seq_lens, input_metadata.seq_lens,
attn_logits,
max_seq_len, max_seq_len,
total_num_tokens,
layer.scaling, layer.scaling,
layer.logit_cap, layer.logit_cap,
) )
return o return o
...@@ -66,18 +66,18 @@ class FlashinferUpdater: ...@@ -66,18 +66,18 @@ class FlashinferUpdater:
self.head_dim = model_runner.model_config.head_dim self.head_dim = model_runner.model_config.head_dim
self.batch_size = len(req_pool_indices) self.batch_size = len(req_pool_indices)
self.kv_last_page_len = torch.ones( self.decode_wrapper = (
(self.batch_size,), dtype=torch.int32, device="cuda" decode_wrapper or self.model_runner.attn_backend.decode_wrapper
)
self.prefill_wrapper_ragged = (
self.model_runner.attn_backend.prefill_wrapper_ragged
)
self.prefill_wrapper_paged = (
self.model_runner.attn_backend.prefill_wrapper_paged
) )
( self.kv_last_page_len = torch.ones(
self.decode_wrapper, (self.batch_size,), dtype=torch.int32, device="cuda"
self.prefill_wrapper_ragged,
self.prefill_wrapper_paged,
) = (
decode_wrapper or self.model_runner.attn_backend.decode_wrapper,
self.model_runner.attn_backend.prefill_wrapper_ragged,
self.model_runner.attn_backend.prefill_wrapper_paged,
) )
def _init_indices_no_sliding_window(self): def _init_indices_no_sliding_window(self):
......
...@@ -114,7 +114,7 @@ def _fwd_kernel_stage1( ...@@ -114,7 +114,7 @@ def _fwd_kernel_stage1(
@triton.jit @triton.jit
def _fwd_kernel_stage2( def _fwd_kernel_stage2(
Logics, logits,
V_Buffer, V_Buffer,
Out, Out,
Req_to_tokens, Req_to_tokens,
...@@ -162,7 +162,7 @@ def _fwd_kernel_stage2( ...@@ -162,7 +162,7 @@ def _fwd_kernel_stage2(
) )
qk = tl.load( qk = tl.load(
Logics logits
+ cur_head * stride_logic_h + cur_head * stride_logic_h
+ (cur_batch_start_loc + start_n + offs_n), + (cur_batch_start_loc + start_n + offs_n),
mask=start_n + offs_n < cur_batch_seq_len, mask=start_n + offs_n < cur_batch_seq_len,
...@@ -238,7 +238,7 @@ def _decode_att_m_fwd( ...@@ -238,7 +238,7 @@ def _decode_att_m_fwd(
def _decode_softmax_reducev_fwd( def _decode_softmax_reducev_fwd(
logics, logits,
v_buffer, v_buffer,
o, o,
req_to_tokens, req_to_tokens,
...@@ -247,9 +247,9 @@ def _decode_softmax_reducev_fwd( ...@@ -247,9 +247,9 @@ def _decode_softmax_reducev_fwd(
b_seq_len, b_seq_len,
): ):
BLOCK = 64 BLOCK = 64
batch, head = b_seq_len.shape[0], logics.shape[0] batch, head = b_seq_len.shape[0], logits.shape[0]
grid = (batch, head, 1) grid = (batch, head, 1)
kv_group_num = logics.shape[0] // v_buffer.shape[1] kv_group_num = logits.shape[0] // v_buffer.shape[1]
num_warps = 1 num_warps = 1
...@@ -257,14 +257,14 @@ def _decode_softmax_reducev_fwd( ...@@ -257,14 +257,14 @@ def _decode_softmax_reducev_fwd(
BLOCK_DMODEL = triton.next_power_of_2(Lv) BLOCK_DMODEL = triton.next_power_of_2(Lv)
_fwd_kernel_stage2[grid]( _fwd_kernel_stage2[grid](
logics, logits,
v_buffer, v_buffer,
o, o,
req_to_tokens, req_to_tokens,
b_req_idx, b_req_idx,
b_start_loc, b_start_loc,
b_seq_len, b_seq_len,
logics.stride(0), logits.stride(0),
v_buffer.stride(0), v_buffer.stride(0),
v_buffer.stride(1), v_buffer.stride(1),
o.stride(0), o.stride(0),
...@@ -387,7 +387,7 @@ def _fwd_grouped_kernel_stage1( ...@@ -387,7 +387,7 @@ def _fwd_grouped_kernel_stage1(
@triton.jit @triton.jit
def _fwd_grouped_kernel_stage2( def _fwd_grouped_kernel_stage2(
Logics, logits,
V_Buffer, V_Buffer,
Out, Out,
Req_to_tokens, Req_to_tokens,
...@@ -443,7 +443,7 @@ def _fwd_grouped_kernel_stage2( ...@@ -443,7 +443,7 @@ def _fwd_grouped_kernel_stage2(
) )
qk = tl.load( qk = tl.load(
Logics + offs_qk, logits + offs_qk,
mask=mask_h[:, None] & (start_n + offs_n[None, :] < cur_batch_seq_len), mask=mask_h[:, None] & (start_n + offs_n[None, :] < cur_batch_seq_len),
other=float("-inf"), other=float("-inf"),
) )
...@@ -531,7 +531,7 @@ def _decode_grouped_att_m_fwd( ...@@ -531,7 +531,7 @@ def _decode_grouped_att_m_fwd(
def _decode_grouped_softmax_reducev_fwd( def _decode_grouped_softmax_reducev_fwd(
logics, logits,
v_buffer, v_buffer,
o, o,
req_to_tokens, req_to_tokens,
...@@ -540,8 +540,8 @@ def _decode_grouped_softmax_reducev_fwd( ...@@ -540,8 +540,8 @@ def _decode_grouped_softmax_reducev_fwd(
b_seq_len, b_seq_len,
): ):
BLOCK = 128 BLOCK = 128
batch, head_num = b_seq_len.shape[0], logics.shape[0] batch, head_num = b_seq_len.shape[0], logits.shape[0]
kv_group_num = logics.shape[0] // v_buffer.shape[1] kv_group_num = logits.shape[0] // v_buffer.shape[1]
BLOCK_H = max(16, triton.next_power_of_2(kv_group_num)) BLOCK_H = max(16, triton.next_power_of_2(kv_group_num))
grid = (batch, triton.cdiv(head_num, min(BLOCK_H, kv_group_num)), 1) grid = (batch, triton.cdiv(head_num, min(BLOCK_H, kv_group_num)), 1)
...@@ -551,14 +551,14 @@ def _decode_grouped_softmax_reducev_fwd( ...@@ -551,14 +551,14 @@ def _decode_grouped_softmax_reducev_fwd(
BLOCK_DMODEL = triton.next_power_of_2(Lv) BLOCK_DMODEL = triton.next_power_of_2(Lv)
_fwd_grouped_kernel_stage2[grid]( _fwd_grouped_kernel_stage2[grid](
logics, logits,
v_buffer, v_buffer,
o, o,
req_to_tokens, req_to_tokens,
b_req_idx, b_req_idx,
b_start_loc, b_start_loc,
b_seq_len, b_seq_len,
logics.stride(0), logits.stride(0),
v_buffer.stride(0), v_buffer.stride(0),
v_buffer.stride(1), v_buffer.stride(1),
o.stride(0), o.stride(0),
...@@ -584,17 +584,11 @@ def decode_attention_fwd( ...@@ -584,17 +584,11 @@ def decode_attention_fwd(
b_req_idx, b_req_idx,
b_start_loc, b_start_loc,
b_seq_len, b_seq_len,
attn_logits,
max_len_in_batch, max_len_in_batch,
total_num_tokens,
sm_scale, sm_scale,
logit_cap=0.0, logit_cap=0.0,
att_m=None,
): ):
if att_m is None:
att_m = torch.empty(
(q.shape[-2], total_num_tokens), dtype=REDUCE_TORCH_TYPE, device="cuda"
)
kv_group_num = q.shape[1] // v_buffer.shape[1] kv_group_num = q.shape[1] // v_buffer.shape[1]
if kv_group_num == 1: if kv_group_num == 1:
...@@ -602,7 +596,7 @@ def decode_attention_fwd( ...@@ -602,7 +596,7 @@ def decode_attention_fwd(
_decode_att_m_fwd( _decode_att_m_fwd(
q, q,
k_buffer, k_buffer,
att_m, attn_logits,
req_to_token, req_to_token,
b_req_idx, b_req_idx,
b_start_loc, b_start_loc,
...@@ -612,7 +606,7 @@ def decode_attention_fwd( ...@@ -612,7 +606,7 @@ def decode_attention_fwd(
logit_cap, logit_cap,
) )
_decode_softmax_reducev_fwd( _decode_softmax_reducev_fwd(
att_m, attn_logits,
v_buffer, v_buffer,
o, o,
req_to_token, req_to_token,
...@@ -625,7 +619,7 @@ def decode_attention_fwd( ...@@ -625,7 +619,7 @@ def decode_attention_fwd(
_decode_grouped_att_m_fwd( _decode_grouped_att_m_fwd(
q, q,
k_buffer, k_buffer,
att_m, attn_logits,
req_to_token, req_to_token,
b_req_idx, b_req_idx,
b_start_loc, b_start_loc,
...@@ -635,7 +629,7 @@ def decode_attention_fwd( ...@@ -635,7 +629,7 @@ def decode_attention_fwd(
logit_cap, logit_cap,
) )
_decode_grouped_softmax_reducev_fwd( _decode_grouped_softmax_reducev_fwd(
att_m, attn_logits,
v_buffer, v_buffer,
o, o,
req_to_token, req_to_token,
......
from __future__ import annotations
""" """
Copyright 2023-2024 SGLang Team Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
...@@ -17,13 +19,12 @@ limitations under the License. ...@@ -17,13 +19,12 @@ limitations under the License.
import bisect import bisect
from contextlib import contextmanager from contextlib import contextmanager
from typing import Callable from typing import TYPE_CHECKING, Callable
import torch import torch
from vllm.distributed.parallel_state import graph_capture from vllm.distributed.parallel_state import graph_capture
from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.custom_op import CustomOp
from sglang.srt.layers.flashinfer_utils import update_flashinfer_indices
from sglang.srt.layers.logits_processor import ( from sglang.srt.layers.logits_processor import (
LogitsMetadata, LogitsMetadata,
LogitsProcessor, LogitsProcessor,
...@@ -35,6 +36,9 @@ from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetad ...@@ -35,6 +36,9 @@ from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetad
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.utils import monkey_patch_vllm_all_gather from sglang.srt.utils import monkey_patch_vllm_all_gather
if TYPE_CHECKING:
from sglang.srt.model_executor.model_runner import ModelRunner
def _to_torch(model: torch.nn.Module, reverse: bool = False): def _to_torch(model: torch.nn.Module, reverse: bool = False):
for sub in model._modules.values(): for sub in model._modules.values():
...@@ -111,7 +115,7 @@ class CudaGraphRunner: ...@@ -111,7 +115,7 @@ class CudaGraphRunner:
self.req_pool_indices = torch.zeros( self.req_pool_indices = torch.zeros(
(self.max_bs,), dtype=torch.int32, device="cuda" (self.max_bs,), dtype=torch.int32, device="cuda"
) )
self.seq_lens = torch.zeros((self.max_bs,), dtype=torch.int32, device="cuda") self.seq_lens = torch.ones((self.max_bs,), dtype=torch.int32, device="cuda")
self.position_ids_offsets = torch.ones( self.position_ids_offsets = torch.ones(
(self.max_bs,), dtype=torch.int32, device="cuda" (self.max_bs,), dtype=torch.int32, device="cuda"
) )
...@@ -121,6 +125,9 @@ class CudaGraphRunner: ...@@ -121,6 +125,9 @@ class CudaGraphRunner:
# Attention backend # Attention backend
self.model_runner.attn_backend.init_cuda_graph_state(self.max_bs) self.model_runner.attn_backend.init_cuda_graph_state(self.max_bs)
self.seq_len_fill_value = (
self.model_runner.attn_backend.get_cuda_graph_seq_len_fill_value()
)
# Sampling info # Sampling info
vocab_size = model_runner.model_config.vocab_size vocab_size = model_runner.model_config.vocab_size
...@@ -176,7 +183,7 @@ class CudaGraphRunner: ...@@ -176,7 +183,7 @@ class CudaGraphRunner:
out_cache_loc = self.out_cache_loc[:bs] out_cache_loc = self.out_cache_loc[:bs]
# Attention backend # Attention backend
self.model_runner.attn_backend.capture_cuda_graph_init( self.model_runner.attn_backend.init_forward_metadata_capture_cuda_graph(
bs, req_pool_indices, seq_lens bs, req_pool_indices, seq_lens
) )
...@@ -227,7 +234,7 @@ class CudaGraphRunner: ...@@ -227,7 +234,7 @@ class CudaGraphRunner:
index = bisect.bisect_left(self.capture_bs, raw_bs) index = bisect.bisect_left(self.capture_bs, raw_bs)
bs = self.capture_bs[index] bs = self.capture_bs[index]
if bs != raw_bs: if bs != raw_bs:
self.seq_lens.zero_() self.seq_lens.fill_(self.seq_len_fill_value)
self.position_ids_offsets.fill_(1) self.position_ids_offsets.fill_(1)
self.out_cache_loc.zero_() self.out_cache_loc.zero_()
...@@ -239,7 +246,7 @@ class CudaGraphRunner: ...@@ -239,7 +246,7 @@ class CudaGraphRunner:
self.out_cache_loc[:raw_bs] = batch.out_cache_loc self.out_cache_loc[:raw_bs] = batch.out_cache_loc
# Attention backend # Attention backend
self.model_runner.attn_backend.replay_cuda_graph_init( self.model_runner.attn_backend.init_forward_metadata_replay_cuda_graph(
bs, self.req_pool_indices, self.seq_lens bs, self.req_pool_indices, self.seq_lens
) )
......
...@@ -445,12 +445,6 @@ class ModelRunner: ...@@ -445,12 +445,6 @@ class ModelRunner:
if self.server_args.disable_cuda_graph: if self.server_args.disable_cuda_graph:
return return
if self.server_args.attention_backend != "flashinfer":
logger.warning(
f"Cuda graph is not supported for attention backend: {self.server_args.attention_backend}"
)
return
logger.info("Capture cuda graph begin. This can take up to several minutes.") logger.info("Capture cuda graph begin. This can take up to several minutes.")
self.cuda_graph_runner = CudaGraphRunner(self) self.cuda_graph_runner = CudaGraphRunner(self)
......
...@@ -96,6 +96,16 @@ class TestServingThroughput(unittest.TestCase): ...@@ -96,6 +96,16 @@ class TestServingThroughput(unittest.TestCase):
if os.getenv("SGLANG_IS_IN_CI", "false") == "true": if os.getenv("SGLANG_IS_IN_CI", "false") == "true":
assert res["output_throughput"] > 2400 assert res["output_throughput"] > 2400
def test_default_with_triton_attention_backend(self):
res = self.run_test(
disable_radix_cache=ServerArgs.disable_radix_cache,
attention_backend="triton",
chunked_prefill_size=-1,
)
if os.getenv("SGLANG_IS_IN_CI", "false") == "true":
assert res["output_throughput"] > 2400
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment