Unverified commit 9e93ef3f, authored by JieXin Liang and committed by GitHub

[fix] fix illegal mem access and clean up triton attention backend (#4571)

parent fad86a68
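
The hunks below drop the `num_kv_heads` parameter from `init_forward_metadata_replay_cuda_graph` across every attention backend; the Triton backend now derives its KV-head count once in its constructor (see the `self.num_kv_head` hunk further down), so callers such as `CudaGraphRunner` and the multi-step draft backends no longer probe `token_to_kv_pool.k_buffer` to pass it in. A minimal sketch of the trimmed call shape, using a stub backend (the stub and the tensor values are invented for illustration; real backends take additional arguments this commit does not touch):

```python
import torch

class StubAttnBackend:
    # Illustrative stand-in for an attention backend after this commit:
    # the replay hook no longer receives num_kv_heads from the caller.
    def init_forward_metadata_replay_cuda_graph(
        self,
        bs: int,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        seq_lens_sum: int,
    ):
        print(f"replay metadata: bs={bs}, seq_lens_sum={seq_lens_sum}")

backend = StubAttnBackend()
seq_lens = torch.tensor([7, 12, 3], dtype=torch.int32)
backend.init_forward_metadata_replay_cuda_graph(
    bs=seq_lens.numel(),
    req_pool_indices=torch.arange(seq_lens.numel(), dtype=torch.int32),
    seq_lens=seq_lens,
    seq_lens_sum=int(seq_lens.sum()),  # note: no num_kv_heads argument anymore
)
```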
@@ -39,7 +39,6 @@ class AttentionBackend(ABC):
     def init_forward_metadata_replay_cuda_graph(
         self,
         bs: int,
-        num_kv_heads: int,
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
...
@@ -349,7 +349,6 @@ class FlashInferAttnBackend(AttentionBackend):
     def init_forward_metadata_replay_cuda_graph(
         self,
         bs: int,
-        num_kv_heads: int,
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
@@ -1063,7 +1062,6 @@ class FlashInferMultiStepDraftBackend:
         def call_fn(i, forward_batch):
             self.attn_backends[i].init_forward_metadata_replay_cuda_graph(
                 bs,
-                -1,
                 forward_batch.req_pool_indices,
                 forward_batch.seq_lens,
                 seq_lens_sum=-1,
...
@@ -279,7 +279,6 @@ class FlashInferMLAAttnBackend(AttentionBackend):
     def init_forward_metadata_replay_cuda_graph(
         self,
         bs: int,
-        num_kv_heads: int,
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
@@ -792,7 +791,6 @@ class FlashInferMLAMultiStepDraftBackend:
         def call_fn(i, forward_batch):
             self.attn_backends[i].init_forward_metadata_replay_cuda_graph(
                 bs,
-                -1,
                 forward_batch.req_pool_indices,
                 forward_batch.seq_lens,
                 seq_lens_sum=-1,
...
 from __future__ import annotations
+from dataclasses import dataclass
 from typing import TYPE_CHECKING, Optional, Union
 import torch
@@ -22,20 +23,21 @@ if TYPE_CHECKING:
 def get_num_kv_splits_triton(
     num_kv_splits_ptr,
     seq_lens_ptr,
-    bs,
+    num_seq,
+    num_group,
     num_head,
     num_kv_head,
     max_kv_splits,
     device_core_count,
-    MAX_BS: tl.constexpr,
+    MAX_NUM_SEQ: tl.constexpr,
 ):
-    # TODO: this method is tunable
-    offs_b = tl.arange(0, MAX_BS)
-    mask_b = offs_b < bs
-    seq_lens = tl.load(seq_lens_ptr + offs_b, mask=mask_b, other=0)
+    # TODO: this method is tunable, we need more online serving data to tune it
+    offs_seq = tl.arange(0, MAX_NUM_SEQ)
+    mask_seq = offs_seq < num_seq
+    seq_lens = tl.load(seq_lens_ptr + offs_seq, mask=mask_seq, other=0)
     max_seq_len = tl.max(seq_lens)
-    seq_lens = tl.load(seq_lens_ptr + offs_b, mask=mask_b, other=max_seq_len)
+    seq_lens = tl.load(seq_lens_ptr + offs_seq, mask=mask_seq, other=max_seq_len)
     min_seq_len = tl.min(seq_lens)
     if max_seq_len * 8 < min_seq_len * 10:
         min_seq_len = max_seq_len
@@ -43,24 +45,43 @@ def get_num_kv_splits_triton(
     kv_chunk_size_1 = tl.cdiv(max_seq_len, max_kv_splits_1)
     # NOTE: this is a hack to let num_kv_split grows up with seqlen gradually
-    ext_seq_len = tl.cast(tl.cdiv(max_seq_len, 256), tl.float32)
-    ext_device_core_count = device_core_count * tl.maximum(
-        tl.cast(tl.ceil(tl.log2(ext_seq_len)), tl.int32), 1
+    ext_seq_len = tl.cast(max_seq_len, tl.float32) / 64.0
+    ext_device_core_count = tl.cast(
+        device_core_count * tl.maximum(tl.log2(ext_seq_len), 1.0), tl.int32
     )
     block_h, num_kv_group = 16, num_head // num_kv_head
     if num_kv_group == 1:
-        bh_grid = bs * num_head
+        token_grid = num_seq * num_group * num_head
     else:
         # from triton_ops/decode_attention.py:_decode_grouped_att_m_fwd
         block_h = tl.minimum(block_h, num_kv_group)
-        bh_grid = bs * tl.cdiv(num_head, block_h)
-    max_kv_splits_2 = tl.minimum(tl.cdiv(ext_device_core_count, bh_grid), max_kv_splits)
+        token_grid = num_seq * num_group * tl.cdiv(num_head, block_h)
+    max_kv_splits_2 = tl.minimum(
+        tl.cdiv(ext_device_core_count, token_grid), max_kv_splits
+    )
     kv_chunk_size_2 = tl.cdiv(max_seq_len, max_kv_splits_2)
     num_kv_splits = tl.maximum(
         tl.cdiv(seq_lens, kv_chunk_size_1), tl.cdiv(seq_lens, kv_chunk_size_2)
     )
-    tl.store(num_kv_splits_ptr + offs_b, num_kv_splits, mask=mask_b)
+    offs_token = offs_seq * num_group
+    mask_token = offs_token < num_seq * num_group
+    for i in range(0, num_group):
+        tl.store(num_kv_splits_ptr + i + offs_token, num_kv_splits, mask=mask_token)
+
+
+@dataclass
+class ForwardMetadata:
+    attn_logits: torch.Tensor
+    attn_lse: torch.Tensor
+    max_extend_len: int
+    num_kv_splits: torch.Tensor
+    kv_indptr: torch.Tensor
+    kv_indices: torch.Tensor
+    qo_indptr: torch.Tensor
+    custom_mask: torch.Tensor
+    mask_indptr: torch.Tensor
+
+
 class TritonAttnBackend(AttentionBackend):
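
The kernel above now sizes and writes `num_kv_splits` per token rather than per batch entry: with speculative decoding each sequence can own `num_group` query tokens, and writing only `bs` entries while the decode kernels index per token leaves part of the buffer out of step, which appears to be related to the illegal memory access this commit fixes. The snippet below only illustrates the resulting memory layout (the per-sequence split values are made up; the real ones come from the occupancy heuristic in the kernel):

```python
import torch

# Hypothetical example: 3 sequences, 2 speculative tokens per sequence.
num_seq, num_group = 3, 2
per_seq_splits = torch.tensor([4, 2, 8], dtype=torch.int32)  # one value per sequence

# num_kv_splits is sized by tokens (num_seq * num_group entries). The kernel's
# `offs_token = offs_seq * num_group` store loop gives sequence s the contiguous
# slots [s * num_group, (s + 1) * num_group).
num_kv_splits = torch.empty(num_seq * num_group, dtype=torch.int32)
for i in range(num_group):
    num_kv_splits[i::num_group] = per_seq_splits

print(num_kv_splits.tolist())  # [4, 4, 2, 2, 8, 8]
```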
@@ -110,6 +131,9 @@ class TritonAttnBackend(AttentionBackend):
         self.num_head = (
             model_runner.model_config.num_attention_heads // get_attention_tp_size()
         )
+        self.num_kv_head = model_runner.model_config.get_num_kv_heads(
+            get_attention_tp_size()
+        )
         self.static_kv_splits = get_bool_env_var(
             "SGLANG_TRITON_DECODE_ATTN_STATIC_KV_SPLITS", "false"
@@ -117,7 +141,7 @@ class TritonAttnBackend(AttentionBackend):
         self.max_kv_splits = model_runner.server_args.triton_attention_num_kv_splits
         self.v_head_dim = model_runner.token_to_kv_pool.get_value_buffer(0).shape[-1]
-        self.forward_metadata = None
+        self.forward_metadata: ForwardMetadata = None
         self.max_context_len = model_runner.model_config.context_len
@@ -128,23 +152,33 @@ class TritonAttnBackend(AttentionBackend):
         self,
         num_kv_splits: torch.Tensor,
         seq_lens: torch.Tensor,
-        bs: int,
-        num_kv_head: int,
     ):
-        MAX_SCHEDULE_BS = 4096
-        if self.static_kv_splits or self.device_core_count <= 0 or bs > MAX_SCHEDULE_BS:
+        num_token, num_seq = num_kv_splits.shape[0], seq_lens.shape[0]
+        num_group = num_token // num_seq
+        assert (
+            num_group * num_seq == num_token
+        ), f"num_seq({num_seq}), num_token({num_token}), something goes wrong!"
+        if self.static_kv_splits or self.device_core_count <= 0:
             num_kv_splits.fill_(self.max_kv_splits)
             return
+        if num_seq < 256:
+            SCHEDULE_SEQ = 256
+        else:
+            SCHEDULE_SEQ = triton.next_power_of_2(num_seq)
         get_num_kv_splits_triton[(1,)](
             num_kv_splits,
             seq_lens,
-            bs,
+            num_seq,
+            num_group,
             self.num_head,
-            num_kv_head,
+            self.num_kv_head,
             self.max_kv_splits,
             self.device_core_count,
-            MAX_BS=MAX_SCHEDULE_BS,
+            MAX_NUM_SEQ=SCHEDULE_SEQ,
         )
     def init_forward_metadata(self, forward_batch: ForwardBatch):
@@ -174,36 +208,19 @@ class TritonAttnBackend(AttentionBackend):
                 kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
                 bs = kv_indptr.shape[0] - 1
-            attn_logits = [
-                torch.empty(
-                    (
-                        bs,
-                        self.num_head,
-                        self.max_kv_splits,
-                        self.v_head_dim,
-                    ),
-                    dtype=torch.float32,
-                    device=self.device,
-                ),
-                torch.empty(
-                    (
-                        bs,
-                        self.num_head,
-                        self.max_kv_splits,
-                    ),
-                    dtype=torch.float32,
-                    device=self.device,
-                ),
-            ]
+            attn_logits = torch.empty(
+                (bs, self.num_head, self.max_kv_splits, self.v_head_dim),
+                dtype=torch.float32,
+                device=self.device,
+            )
+            attn_lse = torch.empty(
+                (bs, self.num_head, self.max_kv_splits),
+                dtype=torch.float32,
+                device=self.device,
+            )
             num_kv_splits = torch.empty((bs,), dtype=torch.int32, device=self.device)
-            num_kv_heads = self.num_head
-            if hasattr(forward_batch.token_to_kv_pool, "k_buffer"):
-                if isinstance(forward_batch.token_to_kv_pool.k_buffer, list):
-                    num_kv_heads = forward_batch.token_to_kv_pool.k_buffer[0].shape[1]
-            self.get_num_kv_splits(
-                num_kv_splits, forward_batch.seq_lens, bs, num_kv_heads
-            )
+            self.get_num_kv_splits(num_kv_splits, forward_batch.seq_lens)
             qo_indptr = None
             custom_mask = None
@@ -244,6 +261,7 @@ class TritonAttnBackend(AttentionBackend):
             max_extend_len = self.num_draft_tokens
             num_kv_splits = None
             attn_logits = None
+            attn_lse = None
         elif forward_batch.forward_mode.is_draft_extend():
             kv_indices, kv_indptr, qo_indptr, custom_mask = (
                 spec_info.generate_attn_arg_prefill(
@@ -254,9 +272,13 @@ class TritonAttnBackend(AttentionBackend):
                 )
             )
             mask_indptr = None
+            # TODO(FIXME): This will trigger an invalid Eagle tree when using
+            # `max(spec_info.accept_length_cpu)`.
+            # It might have been forgotten to update somewhere.
             max_extend_len = torch.max(spec_info.accept_length).item()
             num_kv_splits = None
             attn_logits = None
+            attn_lse = None
         else:
             kv_indptr[1 : bs + 1] = torch.cumsum(
                 forward_batch.extend_prefix_lens, dim=0
@@ -283,11 +305,13 @@ class TritonAttnBackend(AttentionBackend):
             custom_mask = None
             mask_indptr = None
             attn_logits = None
+            attn_lse = None
             max_extend_len = torch.max(forward_batch.extend_seq_lens).item()
             num_kv_splits = None
-        self.forward_metadata = (
+        self.forward_metadata = ForwardMetadata(
             attn_logits,
+            attn_lse,
             max_extend_len,
             num_kv_splits,
             kv_indptr,
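
Switching `self.forward_metadata` from an anonymous tuple to the `ForwardMetadata` dataclass is what lets the later hunks delete the positional unpacking in `forward_extend` and `forward_decode`, where a newly inserted field (here `attn_lse`) could silently shift every subsequent element. A small sketch of the difference, with placeholder values and a local copy of the dataclass for self-containment:

```python
from dataclasses import dataclass
import torch

@dataclass
class ForwardMetadata:  # mirrors the fields of the dataclass introduced above
    attn_logits: torch.Tensor
    attn_lse: torch.Tensor
    max_extend_len: int
    num_kv_splits: torch.Tensor
    kv_indptr: torch.Tensor
    kv_indices: torch.Tensor
    qo_indptr: torch.Tensor
    custom_mask: torch.Tensor
    mask_indptr: torch.Tensor

metadata = ForwardMetadata(
    attn_logits=None, attn_lse=None, max_extend_len=7, num_kv_splits=None,
    kv_indptr=torch.tensor([0, 5, 9], dtype=torch.int32),
    kv_indices=torch.arange(9, dtype=torch.int32),
    qo_indptr=None, custom_mask=None, mask_indptr=None,
)

# Before: positional unpacking that had to stay in sync with the tuple layout, e.g.
#   _, max_extend_len, _, kv_indptr, kv_indices, qo_indptr, custom_mask, mask_indptr = metadata
# After: fields are read by name, so adding attn_lse cannot reorder anything.
print(metadata.max_extend_len, metadata.kv_indptr.tolist())
```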
@@ -300,18 +324,16 @@ class TritonAttnBackend(AttentionBackend):
     def init_cuda_graph_state(
         self, max_bs: int, kv_indices_buf: Optional[torch.Tensor] = None
     ):
-        self.cuda_graph_attn_logits = [
-            torch.zeros(
-                (max_bs, self.num_head, self.max_kv_splits, self.v_head_dim),
-                dtype=torch.float32,
-                device=self.device,
-            ),
-            torch.zeros(
-                (max_bs, self.num_head, self.max_kv_splits),
-                dtype=torch.float32,
-                device=self.device,
-            ),
-        ]
+        self.cuda_graph_attn_logits = torch.zeros(
+            (max_bs, self.num_head, self.max_kv_splits, self.v_head_dim),
+            dtype=torch.float32,
+            device=self.device,
+        )
+        self.cuda_graph_attn_lse = torch.zeros(
+            (max_bs, self.num_head, self.max_kv_splits),
+            dtype=torch.float32,
+            device=self.device,
+        )
         self.cuda_graph_num_kv_splits = torch.full(
             (max_bs,), self.max_kv_splits, dtype=torch.int32, device=self.device
         )
@@ -362,6 +384,7 @@ class TritonAttnBackend(AttentionBackend):
                 kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
             attn_logits = self.cuda_graph_attn_logits
+            attn_lse = self.cuda_graph_attn_lse
             max_extend_len = None
             num_kv_splits = self.cuda_graph_num_kv_splits
             qo_indptr = None
@@ -396,13 +419,15 @@ class TritonAttnBackend(AttentionBackend):
             max_extend_len = self.num_draft_tokens
             num_kv_splits = None
             attn_logits = None
+            attn_lse = None
         else:
             raise ValueError(
                 f"Invalid forward mode: {forward_mode=} for CUDA Graph capture."
             )
-        self.forward_metadata = (
+        self.forward_metadata = ForwardMetadata(
             attn_logits,
+            attn_lse,
             max_extend_len,
             num_kv_splits,
             kv_indptr,
@@ -415,7 +440,6 @@ class TritonAttnBackend(AttentionBackend):
     def init_forward_metadata_replay_cuda_graph(
         self,
         bs: int,
-        num_kv_head: int,
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
@@ -442,10 +466,12 @@ class TritonAttnBackend(AttentionBackend):
                     kv_indices,
                     self.req_to_token.stride(0),
                 )
+                num_token = bs
             else:
                 kv_indptr[: spec_info.kv_indptr.shape[0]] = spec_info.kv_indptr
                 kv_indices[: spec_info.kv_indices.shape[0]] = spec_info.kv_indices
-            self.get_num_kv_splits(num_kv_splits, seq_lens, bs, num_kv_head)
+                num_token = spec_info.kv_indptr.shape[0] - 1
+            self.get_num_kv_splits(num_kv_splits[:num_token], seq_lens[:bs])
         elif forward_mode.is_target_verify():
             # Update qo_indptr, kv_indptr, kv_indices, custom_mask, mask_indptr
             bs = len(req_pool_indices)
@@ -502,17 +528,6 @@ class TritonAttnBackend(AttentionBackend):
                 layer, forward_batch.out_cache_loc, k, v
             )
-        (
-            _,
-            max_extend_len,
-            _,
-            kv_indptr,
-            kv_indices,
-            qo_indptr,
-            custom_mask,
-            mask_indptr,
-        ) = self.forward_metadata
         self.extend_attention_fwd(
             q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
             k.contiguous(),
@@ -520,12 +535,12 @@ class TritonAttnBackend(AttentionBackend):
             o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
             forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
             forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
-            qo_indptr,
-            kv_indptr,
-            kv_indices,
-            custom_mask,
-            mask_indptr,
-            max_extend_len,
+            self.forward_metadata.qo_indptr,
+            self.forward_metadata.kv_indptr,
+            self.forward_metadata.kv_indices,
+            self.forward_metadata.custom_mask,
+            self.forward_metadata.mask_indptr,
+            self.forward_metadata.max_extend_len,
             layer.scaling,
             layer.logit_cap,
         )
@@ -550,10 +565,6 @@ class TritonAttnBackend(AttentionBackend):
         else:
             o = torch.empty_like(q)
-        attn_logits, _, num_kv_splits, kv_indptr, kv_indices, _, _, _ = (
-            self.forward_metadata
-        )
         if save_kv_cache:
             forward_batch.token_to_kv_pool.set_kv_buffer(
                 layer, forward_batch.out_cache_loc, k, v
@@ -564,10 +575,11 @@ class TritonAttnBackend(AttentionBackend):
             forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
             forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
             o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
-            kv_indptr,
-            kv_indices,
-            attn_logits,
-            num_kv_splits,
+            self.forward_metadata.kv_indptr,
+            self.forward_metadata.kv_indices,
+            self.forward_metadata.attn_logits,
+            self.forward_metadata.attn_lse,
+            self.forward_metadata.num_kv_splits,
             self.max_kv_splits,
             layer.scaling,
             layer.logit_cap,
@@ -700,15 +712,9 @@ class TritonMultiStepDraftBackend:
     def init_forward_metadata_replay_cuda_graph(
         self, forward_batch: ForwardBatch, bs: int
     ):
-        num_kv_heads = self.num_head
-        if hasattr(forward_batch.token_to_kv_pool, "k_buffer"):
-            if isinstance(forward_batch.token_to_kv_pool.k_buffer, list):
-                num_kv_heads = forward_batch.token_to_kv_pool.k_buffer[0].shape[1]
         def call_fn(i, forward_batch):
             self.attn_backends[i].init_forward_metadata_replay_cuda_graph(
                 bs,
-                num_kv_heads,
                 forward_batch.req_pool_indices,
                 forward_batch.seq_lens,
                 seq_lens_sum=-1,
...
@@ -609,6 +609,7 @@ def decode_attention_fwd_normal(
     kv_indptr,
     kv_indices,
     attn_logits,
+    attn_lse,
     num_kv_splits,
     max_kv_splits,
     sm_scale,
@@ -618,8 +619,8 @@ def decode_attention_fwd_normal(
         q,
         k_buffer,
         v_buffer,
-        attn_logits[0],
-        attn_logits[1],
+        attn_logits,
+        attn_lse,
         kv_indptr,
         kv_indices,
         num_kv_splits,
@@ -628,8 +629,8 @@ def decode_attention_fwd_normal(
         logit_cap,
     )
     _decode_softmax_reducev_fwd(
-        attn_logits[0],
-        attn_logits[1],
+        attn_logits,
+        attn_lse,
         q,
         o,
         v_buffer,
@@ -647,6 +648,7 @@ def decode_attention_fwd_grouped(
     kv_indptr,
     kv_indices,
     attn_logits,
+    attn_lse,
     num_kv_splits,
     max_kv_splits,
     sm_scale,
@@ -656,8 +658,8 @@ def decode_attention_fwd_grouped(
         q,
         k_buffer,
         v_buffer,
-        attn_logits[0],
-        attn_logits[1],
+        attn_logits,
+        attn_lse,
         kv_indptr,
         kv_indices,
         num_kv_splits,
@@ -666,8 +668,8 @@ def decode_attention_fwd_grouped(
         logit_cap,
     )
     _decode_softmax_reducev_fwd(
-        attn_logits[0],
-        attn_logits[1],
+        attn_logits,
+        attn_lse,
         q,
         o,
         v_buffer,
@@ -685,14 +687,15 @@ def decode_attention_fwd(
     kv_indptr,
     kv_indices,
     attn_logits,
+    attn_lse,
    num_kv_splits,
     max_kv_splits,
     sm_scale,
     logit_cap=0.0,
 ):
-    assert max_kv_splits == attn_logits[0].shape[2]
+    assert max_kv_splits == attn_logits.shape[2]
     assert q.shape[0] <= kv_indptr.shape[0] - 1
-    assert q.shape[0] <= attn_logits[0].shape[0]
+    assert q.shape[0] <= attn_logits.shape[0]
     kv_group_num = q.shape[1] // v_buffer.shape[1]
@@ -706,6 +709,7 @@ def decode_attention_fwd(
             kv_indptr,
             kv_indices,
             attn_logits,
+            attn_lse,
             num_kv_splits,
             max_kv_splits,
             sm_scale,
@@ -721,6 +725,7 @@ def decode_attention_fwd(
             kv_indptr,
             kv_indices,
             attn_logits,
+            attn_lse,
             num_kv_splits,
             max_kv_splits,
             sm_scale,
...
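
After this change, `decode_attention_fwd` (and the `_normal`/`_grouped` variants it dispatches to) receive the partial-attention buffer and the log-sum-exp buffer as two separate tensors instead of a two-element list. A hedged, standalone sketch of a call with the buffer shapes used by the backend hunks above; the sizes are placeholders, a CUDA device with Triton is assumed, and the import path is assumed from the `triton_ops/decode_attention.py` module referenced in the comments:

```python
import torch

# Import path assumed from the triton_ops/decode_attention.py module referenced above.
from sglang.srt.layers.attention.triton_ops.decode_attention import decode_attention_fwd

device = "cuda"
num_token, num_head, num_kv_head = 4, 32, 8      # placeholder sizes
head_dim, v_head_dim, max_kv_splits = 128, 128, 8
kv_per_seq = 64                                  # KV entries per decoding sequence

q = torch.randn(num_token, num_head, head_dim, dtype=torch.float16, device=device)
k_buffer = torch.randn(num_token * kv_per_seq, num_kv_head, head_dim, dtype=torch.float16, device=device)
v_buffer = torch.randn(num_token * kv_per_seq, num_kv_head, v_head_dim, dtype=torch.float16, device=device)
o = torch.empty_like(q)

kv_indptr = torch.arange(0, (num_token + 1) * kv_per_seq, kv_per_seq, dtype=torch.int32, device=device)
kv_indices = torch.arange(num_token * kv_per_seq, dtype=torch.int32, device=device)

# The two intermediates are now separate tensors (previously a two-element list).
attn_logits = torch.empty(
    (num_token, num_head, max_kv_splits, v_head_dim), dtype=torch.float32, device=device
)
attn_lse = torch.empty(
    (num_token, num_head, max_kv_splits), dtype=torch.float32, device=device
)
num_kv_splits = torch.full((num_token,), max_kv_splits, dtype=torch.int32, device=device)

decode_attention_fwd(
    q, k_buffer, v_buffer, o,
    kv_indptr, kv_indices,
    attn_logits, attn_lse,
    num_kv_splits, max_kv_splits,
    sm_scale=head_dim**-0.5,
)
```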
@@ -26,7 +26,6 @@ import tqdm
 from sglang.srt.custom_op import CustomOp
 from sglang.srt.distributed import get_tensor_model_parallel_rank
 from sglang.srt.distributed.parallel_state import GroupCoordinator, graph_capture
-from sglang.srt.layers.dp_attention import get_attention_tp_size
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.layers.moe.fused_moe_native import fused_moe_forward_native
 from sglang.srt.layers.torchao_utils import save_gemlite_cache
@@ -196,9 +195,6 @@ class CudaGraphRunner:
         # Attention backend
         self.max_bs = max(self.capture_bs)
         self.max_num_token = self.max_bs * self.num_tokens_per_bs
-        self.num_head = (
-            model_runner.model_config.num_attention_heads // get_attention_tp_size()
-        )
         self.model_runner.attn_backend.init_cuda_graph_state(self.max_num_token)
         self.seq_len_fill_value = (
             self.model_runner.attn_backend.get_cuda_graph_seq_len_fill_value()
@@ -507,15 +503,9 @@ class CudaGraphRunner:
         if hasattr(forward_batch.spec_info, "hidden_states"):
             self.hidden_states[:raw_num_token] = forward_batch.spec_info.hidden_states
-        num_kv_heads = self.num_head
-        if hasattr(forward_batch.token_to_kv_pool, "k_buffer"):
-            if isinstance(forward_batch.token_to_kv_pool.k_buffer, list):
-                num_kv_heads = forward_batch.token_to_kv_pool.k_buffer[0].shape[1]
         # Attention backend
         self.model_runner.attn_backend.init_forward_metadata_replay_cuda_graph(
             bs,
-            num_kv_heads,
             self.req_pool_indices,
             self.seq_lens,
             forward_batch.seq_lens_sum + (bs - raw_bs),
...
@@ -265,7 +265,8 @@ class TestTritonAttention(unittest.TestCase):
             o,
             kv_indptr,
             kv_indices,
-            (attn_logits, attn_lse),
+            attn_logits,
+            attn_lse,
             num_kv_splits,
             max_kv_splits,
             sm_scale,
@@ -329,7 +330,8 @@ class TestTritonAttention(unittest.TestCase):
             o,
             kv_indptr,
             kv_indices,
-            (attn_logits, attn_lse),
+            attn_logits,
+            attn_lse,
             num_kv_splits,
             max_kv_splits,
             sm_scale,
@@ -353,7 +355,8 @@ class TestTritonAttention(unittest.TestCase):
             o_grouped,
             kv_indptr,
             kv_indices,
-            (attn_logits1, attn_lse1),
+            attn_logits1,
+            attn_lse1,
             num_kv_splits,
             max_kv_splits,
             sm_scale,
...