Unverified Commit be2d985d authored by Lianmin Zheng, committed by GitHub

Minor style change of triton backend (#7165)

parent 5b1afa78
@@ -20,117 +20,6 @@ if TYPE_CHECKING:
    from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
@triton.jit
def get_num_kv_splits_triton(
num_kv_splits_ptr,
seq_lens_ptr,
num_seq,
num_group,
num_head,
num_kv_head,
max_kv_splits,
device_core_count,
MAX_NUM_SEQ: tl.constexpr,
):
    # TODO: this heuristic is tunable; we need more online serving data to tune it
offs_seq = tl.arange(0, MAX_NUM_SEQ)
mask_seq = offs_seq < num_seq
seq_lens = tl.load(seq_lens_ptr + offs_seq, mask=mask_seq, other=0)
max_seq_len = tl.max(seq_lens)
seq_lens = tl.load(seq_lens_ptr + offs_seq, mask=mask_seq, other=max_seq_len)
min_seq_len = tl.min(seq_lens)
if max_seq_len * 8 < min_seq_len * 10:
min_seq_len = max_seq_len
max_kv_splits_1 = tl.minimum(tl.cdiv(max_seq_len, min_seq_len), max_kv_splits)
kv_chunk_size_1 = tl.cdiv(max_seq_len, max_kv_splits_1)
    # NOTE: this is a hack to let num_kv_splits grow gradually with seq_len
ext_seq_len = tl.cast(max_seq_len, tl.float32) / 64.0
ext_device_core_count = tl.cast(
device_core_count * tl.maximum(tl.log2(ext_seq_len), 1.0), tl.int32
)
block_h, num_kv_group = 16, num_head // num_kv_head
if num_kv_group == 1:
token_grid = num_seq * num_group * num_head
else:
# from triton_ops/decode_attention.py:_decode_grouped_att_m_fwd
block_h = tl.minimum(block_h, num_kv_group)
token_grid = num_seq * num_group * tl.cdiv(num_head, block_h)
max_kv_splits_2 = tl.minimum(
tl.cdiv(ext_device_core_count, token_grid), max_kv_splits
)
kv_chunk_size_2 = tl.cdiv(max_seq_len, max_kv_splits_2)
num_kv_splits = tl.maximum(
tl.cdiv(seq_lens, kv_chunk_size_1), tl.cdiv(seq_lens, kv_chunk_size_2)
)
offs_token = offs_seq * num_group
mask_token = offs_token < num_seq * num_group
for i in range(0, num_group):
tl.store(num_kv_splits_ptr + i + offs_token, num_kv_splits, mask=mask_token)

def update_sliding_window_buffer(
window_kv_indptr,
req_to_token,
sliding_window_size,
seq_lens,
req_pool_indices,
bs,
device,
):
window_kv_lens = torch.minimum(
seq_lens,
torch.tensor(sliding_window_size + 1),
)
window_kv_indptr[1 : bs + 1] = torch.cumsum(window_kv_lens, dim=0)
window_kv_indptr = window_kv_indptr[: bs + 1]
window_kv_indices = torch.empty(
window_kv_indptr[-1], dtype=torch.int32, device=device
)
window_kv_start_idx = seq_lens - window_kv_lens
create_flashinfer_kv_indices_triton[(bs,)](
req_to_token,
req_pool_indices,
window_kv_lens,
window_kv_indptr,
window_kv_start_idx,
window_kv_indices,
req_to_token.stride(0),
)
return window_kv_indptr, window_kv_indices, window_kv_lens

def update_sliding_window_buffer_cuda_graph(
window_kv_indptr,
window_kv_indices,
req_to_token,
sliding_window_size,
seq_lens,
req_pool_indices,
bs,
):
window_kv_lens = torch.minimum(
seq_lens,
torch.tensor(sliding_window_size + 1),
)
window_kv_indptr[1 : bs + 1] = torch.cumsum(window_kv_lens, dim=0)
window_kv_indptr = window_kv_indptr[: bs + 1]
window_kv_start_idx = seq_lens - window_kv_lens
create_flashinfer_kv_indices_triton[(bs,)](
req_to_token,
req_pool_indices,
window_kv_lens,
window_kv_indptr,
window_kv_start_idx,
window_kv_indices,
req_to_token.stride(0),
)
return window_kv_indptr, window_kv_lens
@dataclass
class ForwardMetadata:
    attn_logits: torch.Tensor
@@ -165,8 +54,8 @@ class TritonAttnBackend(AttentionBackend):
        super().__init__()
-        self.decode_attention_fwd = decode_attention_fwd
-        self.extend_attention_fwd = extend_attention_fwd
+        self.decode_attention_fwd = torch.compiler.disable(decode_attention_fwd)
+        self.extend_attention_fwd = torch.compiler.disable(extend_attention_fwd)
        self.skip_prefill = skip_prefill
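Besides moving the helper functions to the end of the file, the hunk above wraps the Triton attention entry points in torch.compiler.disable. The snippet below is an illustrative sketch, not part of this commit, and the function names are made up; it only shows the effect of that pattern: a wrapped callable is skipped by torch.compile, so the handwritten Triton path runs eagerly even when the surrounding model code is compiled.

import torch

def fused_decode_attention(q):
    # Hypothetical stand-in for decode_attention_fwd, which dispatches to Triton kernels.
    return q * 2.0

# Same wrapping style as the commit: keep torch.compile from tracing into the function.
fused_decode_attention = torch.compiler.disable(fused_decode_attention)

@torch.compile
def forward(q):
    # Dynamo compiles this function but graph-breaks around the disabled call.
    return fused_decode_attention(q) + 1.0

print(forward(torch.randn(4)))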
@@ -973,3 +862,114 @@ class TritonMultiStepDraftBackend:
        )
        self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)
@triton.jit
def get_num_kv_splits_triton(
num_kv_splits_ptr,
seq_lens_ptr,
num_seq,
num_group,
num_head,
num_kv_head,
max_kv_splits,
device_core_count,
MAX_NUM_SEQ: tl.constexpr,
):
    # TODO: this heuristic is tunable; we need more online serving data to tune it
offs_seq = tl.arange(0, MAX_NUM_SEQ)
mask_seq = offs_seq < num_seq
seq_lens = tl.load(seq_lens_ptr + offs_seq, mask=mask_seq, other=0)
max_seq_len = tl.max(seq_lens)
seq_lens = tl.load(seq_lens_ptr + offs_seq, mask=mask_seq, other=max_seq_len)
min_seq_len = tl.min(seq_lens)
if max_seq_len * 8 < min_seq_len * 10:
min_seq_len = max_seq_len
max_kv_splits_1 = tl.minimum(tl.cdiv(max_seq_len, min_seq_len), max_kv_splits)
kv_chunk_size_1 = tl.cdiv(max_seq_len, max_kv_splits_1)
    # NOTE: this is a hack to let num_kv_splits grow gradually with seq_len
ext_seq_len = tl.cast(max_seq_len, tl.float32) / 64.0
ext_device_core_count = tl.cast(
device_core_count * tl.maximum(tl.log2(ext_seq_len), 1.0), tl.int32
)
block_h, num_kv_group = 16, num_head // num_kv_head
if num_kv_group == 1:
token_grid = num_seq * num_group * num_head
else:
# from triton_ops/decode_attention.py:_decode_grouped_att_m_fwd
block_h = tl.minimum(block_h, num_kv_group)
token_grid = num_seq * num_group * tl.cdiv(num_head, block_h)
max_kv_splits_2 = tl.minimum(
tl.cdiv(ext_device_core_count, token_grid), max_kv_splits
)
kv_chunk_size_2 = tl.cdiv(max_seq_len, max_kv_splits_2)
num_kv_splits = tl.maximum(
tl.cdiv(seq_lens, kv_chunk_size_1), tl.cdiv(seq_lens, kv_chunk_size_2)
)
offs_token = offs_seq * num_group
mask_token = offs_token < num_seq * num_group
for i in range(0, num_group):
tl.store(num_kv_splits_ptr + i + offs_token, num_kv_splits, mask=mask_token)

def update_sliding_window_buffer(
window_kv_indptr,
req_to_token,
sliding_window_size,
seq_lens,
req_pool_indices,
bs,
device,
):
window_kv_lens = torch.minimum(
seq_lens,
torch.tensor(sliding_window_size + 1),
)
window_kv_indptr[1 : bs + 1] = torch.cumsum(window_kv_lens, dim=0)
window_kv_indptr = window_kv_indptr[: bs + 1]
window_kv_indices = torch.empty(
window_kv_indptr[-1], dtype=torch.int32, device=device
)
window_kv_start_idx = seq_lens - window_kv_lens
create_flashinfer_kv_indices_triton[(bs,)](
req_to_token,
req_pool_indices,
window_kv_lens,
window_kv_indptr,
window_kv_start_idx,
window_kv_indices,
req_to_token.stride(0),
)
return window_kv_indptr, window_kv_indices, window_kv_lens

def update_sliding_window_buffer_cuda_graph(
window_kv_indptr,
window_kv_indices,
req_to_token,
sliding_window_size,
seq_lens,
req_pool_indices,
bs,
):
window_kv_lens = torch.minimum(
seq_lens,
torch.tensor(sliding_window_size + 1),
)
window_kv_indptr[1 : bs + 1] = torch.cumsum(window_kv_lens, dim=0)
window_kv_indptr = window_kv_indptr[: bs + 1]
window_kv_start_idx = seq_lens - window_kv_lens
create_flashinfer_kv_indices_triton[(bs,)](
req_to_token,
req_pool_indices,
window_kv_lens,
window_kv_indptr,
window_kv_start_idx,
window_kv_indices,
req_to_token.stride(0),
)
return window_kv_indptr, window_kv_lens
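For context on how the relocated get_num_kv_splits_triton kernel can be launched, here is a minimal sketch; the example shapes, values, and single-program grid are assumptions for illustration, not code from this commit. The kernel covers all sequences with a single tl.arange over MAX_NUM_SEQ lanes, so one program instance is enough, and tl.arange requires MAX_NUM_SEQ to be a power of two.

import torch
import triton

num_seq, num_group, num_head, num_kv_head = 4, 1, 32, 8
max_kv_splits, device_core_count = 16, 132  # e.g. the GPU's SM count
seq_lens = torch.tensor([17, 512, 2048, 900], dtype=torch.int32, device="cuda")

# One split count per (sequence, group) slot, written by the kernel's tl.store.
num_kv_splits = torch.empty(num_seq * num_group, dtype=torch.int32, device="cuda")

get_num_kv_splits_triton[(1,)](
    num_kv_splits,
    seq_lens,
    num_seq,
    num_group,
    num_head,
    num_kv_head,
    max_kv_splits,
    device_core_count,
    MAX_NUM_SEQ=triton.next_power_of_2(num_seq),  # tl.arange needs a power of 2
)
# num_kv_splits now holds, per sequence, how many KV chunks the decode kernel may use.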
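Similarly, a small usage sketch for update_sliding_window_buffer; the toy req_to_token layout and all values below are assumptions for illustration, and the call relies on create_flashinfer_kv_indices_triton being importable from the surrounding module as in the original file.

import torch

device = "cuda"
bs, max_context_len, sliding_window_size = 2, 16, 4
seq_lens = torch.tensor([10, 6], dtype=torch.int32, device=device)
req_pool_indices = torch.tensor([0, 1], dtype=torch.int32, device=device)
# Toy mapping: row r holds the KV-cache slot of each token position of request r.
req_to_token = torch.arange(
    bs * max_context_len, dtype=torch.int32, device=device
).reshape(bs, max_context_len)
window_kv_indptr = torch.zeros(bs + 1, dtype=torch.int32, device=device)

indptr, indices, lens = update_sliding_window_buffer(
    window_kv_indptr,
    req_to_token,
    sliding_window_size,
    seq_lens,
    req_pool_indices,
    bs,
    device,
)
# lens == min(seq_lens, sliding_window_size + 1); indices holds the KV-cache slots of
# the last lens[i] tokens of each request, packed contiguously according to indptr.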