Unverified commit be1a3cd9 authored by Ke Bao, committed by GitHub

Fix swa eagle verify accuracy for Triton backend (#9279)

parent 4b74c3fc
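
Summary of the fix, as read from the hunks below: for sliding-window (SWA) layers the Triton extend kernel only receives kv indices for the tokens still inside the window, while the custom mask built for EAGLE verify is laid out over the full sequence, so the mask addressing drifted once the window started sliding. The sliding-window buffer helpers now also return a per-request start offset; it travels through ForwardMetadata and the CUDA graph buffers as window_kv_offsets and into extend_attention_fwd, where _fwd_kernel uses it as both an extra row stride and a column base when loading the mask. The decode kernels additionally take their launch batch from q.shape[0]. A minimal sketch of the corrected flat mask index (hypothetical helper, not part of the diff):

# Sketch only: each request's custom mask is assumed to be flattened row-major,
# one row per query token, one column per token of the FULL sequence
# (window_kv_len + window_kv_offset).
def custom_mask_index(row, col, window_kv_len, window_kv_offset):
    full_row_len = window_kv_len + window_kv_offset  # true row stride of the mask
    # old kernel: row * window_kv_len + col -> wrong stride, missing column base
    return row * full_row_len + window_kv_offset + col
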
@@ -35,6 +35,7 @@ class ForwardMetadata:
window_kv_indptr: torch.Tensor
window_kv_indices: torch.Tensor
window_num_kv_splits: torch.Tensor
window_kv_offsets: torch.Tensor
class TritonAttnBackend(AttentionBackend):
@@ -163,6 +164,7 @@ class TritonAttnBackend(AttentionBackend):
window_kv_indptr = self.window_kv_indptr
window_kv_indices = None
window_num_kv_splits = None
window_kv_offsets = None
spec_info = forward_batch.spec_info
if forward_batch.forward_mode.is_decode_or_idle():
@@ -186,7 +188,7 @@ class TritonAttnBackend(AttentionBackend):
self.sliding_window_size is not None
and self.sliding_window_size > 0
):
window_kv_indptr, window_kv_indices, window_kv_lens = (
window_kv_indptr, window_kv_indices, window_kv_lens, _ = (
update_sliding_window_buffer(
self.window_kv_indptr,
self.req_to_token,
@@ -249,17 +251,21 @@ class TritonAttnBackend(AttentionBackend):
)
if self.sliding_window_size is not None and self.sliding_window_size > 0:
window_kv_indptr, window_kv_indices, window_kv_lens = (
update_sliding_window_buffer(
self.window_kv_indptr,
self.req_to_token,
self.sliding_window_size,
forward_batch.seq_lens,
forward_batch.req_pool_indices,
bs,
self.device,
self.token_to_kv_pool_allocator,
)
# window_kv_offsets is used to calculate the start position in custom mask
(
window_kv_indptr,
window_kv_indices,
window_kv_lens,
window_kv_offsets,
) = update_sliding_window_buffer(
self.window_kv_indptr,
self.req_to_token,
self.sliding_window_size,
forward_batch.seq_lens,
forward_batch.req_pool_indices,
bs,
self.device,
self.token_to_kv_pool_allocator,
)
custom_mask = spec_info.custom_mask
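
The window_kv_offsets returned here is the per-request position in the full sequence at which the in-window kv starts, i.e. how many prefix tokens have already slid out of the window. A hedged sketch of that relationship (the exact clamp inside the helper is an assumption, not copied from it):

import torch

seq_lens = torch.tensor([3, 40, 100])          # hypothetical full kv lengths
sliding_window_size = 32
window_kv_lens = torch.minimum(seq_lens, torch.tensor(sliding_window_size))
window_kv_offsets = seq_lens - window_kv_lens  # -> tensor([ 0,  8, 68])
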
@@ -312,15 +318,17 @@ class TritonAttnBackend(AttentionBackend):
)
# Sliding window
if self.sliding_window_size is not None and self.sliding_window_size > 0:
window_kv_indptr, window_kv_indices, _ = update_sliding_window_buffer(
self.window_kv_indptr,
self.req_to_token,
self.sliding_window_size,
forward_batch.extend_prefix_lens,
forward_batch.req_pool_indices,
bs,
self.device,
self.token_to_kv_pool_allocator,
window_kv_indptr, window_kv_indices, _, _ = (
update_sliding_window_buffer(
self.window_kv_indptr,
self.req_to_token,
self.sliding_window_size,
forward_batch.extend_prefix_lens,
forward_batch.req_pool_indices,
bs,
self.device,
self.token_to_kv_pool_allocator,
)
)
qo_indptr = self.qo_indptr
@@ -346,6 +354,7 @@ class TritonAttnBackend(AttentionBackend):
window_kv_indptr,
window_kv_indices,
window_num_kv_splits,
window_kv_offsets,
)
def init_cuda_graph_state(
@@ -400,6 +409,12 @@ class TritonAttnBackend(AttentionBackend):
device=self.device,
)
self.cuda_graph_window_kv_offsets = torch.zeros(
(max_bs,),
dtype=torch.int32,
device=self.device,
)
def init_forward_metadata_capture_cuda_graph(
self,
bs: int,
@@ -414,6 +429,7 @@ class TritonAttnBackend(AttentionBackend):
window_kv_indptr = self.window_kv_indptr
window_kv_indices = None
window_num_kv_splits = None
window_kv_offsets = None
if forward_mode.is_decode_or_idle():
if spec_info is None:
@@ -436,7 +452,7 @@ class TritonAttnBackend(AttentionBackend):
):
window_kv_indices = self.cuda_graph_window_kv_indices
window_num_kv_splits = self.cuda_graph_window_num_kv_splits
window_kv_indptr, window_kv_indices, _ = (
window_kv_indptr, window_kv_indices, _, _ = (
update_sliding_window_buffer_cuda_graph(
self.window_kv_indptr,
window_kv_indices,
@@ -483,13 +499,14 @@ class TritonAttnBackend(AttentionBackend):
if self.sliding_window_size is not None and self.sliding_window_size > 0:
window_kv_indices = self.cuda_graph_window_kv_indices
window_num_kv_splits = self.cuda_graph_window_num_kv_splits
window_kv_indptr, window_kv_indices, _ = (
window_kv_offsets = self.cuda_graph_window_kv_offsets
window_kv_indptr, window_kv_indices, _, window_kv_offsets[:bs] = (
update_sliding_window_buffer_cuda_graph(
self.window_kv_indptr,
window_kv_indices,
self.req_to_token,
self.sliding_window_size,
seq_lens,
seq_lens[:bs],
req_pool_indices,
bs,
self.token_to_kv_pool_allocator,
@@ -551,6 +568,7 @@ class TritonAttnBackend(AttentionBackend):
window_kv_indptr,
window_kv_indices,
window_num_kv_splits,
window_kv_offsets,
)
def init_forward_metadata_replay_cuda_graph(
@@ -589,7 +607,7 @@ class TritonAttnBackend(AttentionBackend):
):
window_num_kv_splits = self.cuda_graph_window_num_kv_splits
window_kv_indices = self.cuda_graph_window_kv_indices
_, _, window_kv_lens = update_sliding_window_buffer_cuda_graph(
_, _, window_kv_lens, _ = update_sliding_window_buffer_cuda_graph(
self.window_kv_indptr,
window_kv_indices,
self.req_to_token,
@@ -635,15 +653,18 @@ class TritonAttnBackend(AttentionBackend):
if self.sliding_window_size is not None and self.sliding_window_size > 0:
window_num_kv_splits = self.cuda_graph_window_num_kv_splits
window_kv_indices = self.cuda_graph_window_kv_indices
_, _, window_kv_lens = update_sliding_window_buffer_cuda_graph(
self.window_kv_indptr,
window_kv_indices,
self.req_to_token,
self.sliding_window_size,
seq_lens,
req_pool_indices,
bs,
self.token_to_kv_pool_allocator,
window_kv_offsets = self.cuda_graph_window_kv_offsets
_, _, window_kv_lens, window_kv_offsets[:bs] = (
update_sliding_window_buffer_cuda_graph(
self.window_kv_indptr,
window_kv_indices,
self.req_to_token,
self.sliding_window_size,
seq_lens[:bs],
req_pool_indices,
bs,
self.token_to_kv_pool_allocator,
)
)
custom_mask = self.cuda_graph_custom_mask
custom_mask[: spec_info.custom_mask.shape[0]] = spec_info.custom_mask
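
In the CUDA graph capture and replay paths, the offsets are not bound to a fresh tensor; they are written into the preallocated self.cuda_graph_window_kv_offsets buffer via window_kv_offsets[:bs] = ... (tuple unpacking straight into a slice target is valid Python). A fixed buffer updated in place is what lets the captured graph keep reading the same storage on every replay. A small sketch of the pattern, with hypothetical sizes:

import torch

def recompute():  # hypothetical stand-in for update_sliding_window_buffer_cuda_graph
    return "indptr", "indices", "lens", torch.tensor([5, 0, 2], dtype=torch.int32)

max_bs, bs = 8, 3
window_kv_offsets = torch.zeros(max_bs, dtype=torch.int32)  # buffer reused by the graph
# Unpacking the 4th element into a slice writes through the existing storage,
# so the address recorded at capture time sees the refreshed values.
_, _, _, window_kv_offsets[:bs] = recompute()
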
@@ -706,10 +727,12 @@ class TritonAttnBackend(AttentionBackend):
) # Needed for sliding window mask
kv_indptr = self.forward_metadata.window_kv_indptr
kv_indices = self.forward_metadata.window_kv_indices
window_kv_offsets = self.forward_metadata.window_kv_offsets
else:
sliding_window_size = -1
kv_indptr = self.forward_metadata.kv_indptr
kv_indices = self.forward_metadata.kv_indices
window_kv_offsets = None
self.extend_attention_fwd(
q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
@@ -729,6 +752,7 @@ class TritonAttnBackend(AttentionBackend):
layer.logit_cap,
sliding_window_size=sliding_window_size,
sinks=sinks,
window_kv_offsets=window_kv_offsets,
)
return o
@@ -1011,7 +1035,7 @@ def update_sliding_window_buffer(
window_kv_indices[:kv_last_index]
)
)
return window_kv_indptr, window_kv_indices, window_kv_lens
return window_kv_indptr, window_kv_indices, window_kv_lens, window_kv_start_idx
def update_sliding_window_buffer_cuda_graph(
@@ -1048,4 +1072,4 @@ def update_sliding_window_buffer_cuda_graph(
window_kv_indices[:kv_last_index]
)
)
return window_kv_indptr, window_kv_indices, window_kv_lens
return window_kv_indptr, window_kv_indices, window_kv_lens, window_kv_start_idx
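
Both helpers now return a 4-tuple whose last element is window_kv_start_idx; every caller in this diff either keeps it as window_kv_offsets (the verify paths, which need it for the custom mask) or discards it with _ (the decode and draft-extend paths). A sketch of the two unpack patterns, using a dummy stand-in rather than the real helper:

def update_sliding_window_buffer_stub():  # dummy, not the real helper
    return "indptr", "indices", "lens", "start_idx"

# decode / draft-extend: the offset is not needed
window_kv_indptr, window_kv_indices, window_kv_lens, _ = update_sliding_window_buffer_stub()
# target-verify: the offset feeds the custom-mask addressing
window_kv_indptr, window_kv_indices, window_kv_lens, window_kv_offsets = (
    update_sliding_window_buffer_stub()
)
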
@@ -190,7 +190,7 @@ def _decode_att_m_fwd(
Lk = k_buffer.shape[-1]
Lv = v_buffer.shape[-1]
batch, head_num = kv_indptr.shape[0] - 1, q.shape[1]
batch, head_num = q.shape[0], q.shape[1]
grid = (batch, head_num, MAX_KV_SPLITS)
kv_group_num = q.shape[1] // k_buffer.shape[1]
@@ -433,7 +433,7 @@ def _decode_grouped_att_m_fwd(
BLOCK_DPE = 0
BLOCK_DV = triton.next_power_of_2(Lv)
batch, head_num = kv_indptr.shape[0] - 1, q.shape[1]
batch, head_num = q.shape[0], q.shape[1]
kv_group_num = q.shape[1] // k_buffer.shape[1]
BLOCK_H = 16
......
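
Both decode kernels above now take the launch batch from q.shape[0] instead of kv_indptr.shape[0] - 1, tying the grid to the tensor each program actually reads a query row from rather than to however the indptr buffer happened to be allocated or sliced. Hypothetical shapes showing how the two expressions can disagree when indptr comes from a padded, preallocated buffer:

import torch

bs, max_bs, num_heads, head_dim = 3, 8, 16, 64
q = torch.zeros(bs, num_heads, head_dim)                 # one row per query token
kv_indptr = torch.zeros(max_bs + 1, dtype=torch.int32)   # padded CUDA-graph-style buffer
print(q.shape[0])               # 3 -> rows actually launched
print(kv_indptr.shape[0] - 1)   # 8 -> over-counts unless sliced to bs + 1
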
@@ -52,6 +52,7 @@ def _fwd_kernel(
mask_ptr,
mask_indptr,
sink_ptr,
window_kv_offset_ptr,
sm_scale,
kv_group_num,
stride_qbs,
@@ -95,6 +96,11 @@ def _fwd_kernel(
if USE_CUSTOM_MASK:
cur_seq_mask_start_idx = tl.load(mask_indptr + cur_seq)
# For SWA, we should only load the mask in the sliding window
window_kv_offset = 0
if USE_CUSTOM_MASK and SLIDING_WINDOW_SIZE > 0:
window_kv_offset = tl.load(window_kv_offset_ptr + cur_seq)
offs_d = tl.arange(0, BLOCK_DMODEL)
offs_dv = tl.arange(0, BLOCK_DV)
offs_m = tl.arange(0, BLOCK_M)
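
The offset is only loaded when a custom mask is in use and the layer actually has a sliding window, which is also why the backend can pass window_kv_offsets=None for full-attention layers. A plain-Python rendering of the same guard (hypothetical harness, not Triton):

def mask_column_base(use_custom_mask, sliding_window_size, window_kv_offsets, cur_seq):
    window_kv_offset = 0
    if use_custom_mask and sliding_window_size > 0:
        # only SWA layers index the offsets tensor, so a None passed for
        # full-attention layers is never dereferenced
        window_kv_offset = window_kv_offsets[cur_seq]
    return window_kv_offset
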
@@ -139,7 +145,9 @@ def _fwd_kernel(
custom_mask = tl.load(
mask_ptr
+ cur_seq_mask_start_idx
+ (cur_block_m * BLOCK_M + offs_m[:, None]) * cur_seq_len
+ (cur_block_m * BLOCK_M + offs_m[:, None])
* (cur_seq_len + window_kv_offset)
+ window_kv_offset
+ start_n
+ offs_n[None, :],
mask=(mask_m[:, None] & mask_n[None, :]),
@@ -236,7 +244,9 @@ def _fwd_kernel(
custom_mask = tl.load(
mask_ptr
+ cur_seq_mask_start_idx
+ (cur_block_m * BLOCK_M + offs_m[:, None]) * cur_seq_len
+ (cur_block_m * BLOCK_M + offs_m[:, None])
* (cur_seq_len + window_kv_offset)
+ window_kv_offset
+ cur_seq_len_prefix
+ start_n
+ offs_n[None, :],
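
A worked example of the corrected addressing in the two loads above, with hypothetical numbers: a request whose mask rows span 10 positions, 4 of which have already slid out of the window, so the kernel sees cur_seq_len = 6 but must stride rows by 10 and start columns at 4:

cur_seq_len = 6          # kv length visible through the window indices
window_kv_offset = 4     # tokens that slid out of the window
row, col = 1, 3          # 2nd query token, 4th in-window kv position

old_idx = row * cur_seq_len + col                                          # 9
new_idx = row * (cur_seq_len + window_kv_offset) + window_kv_offset + col  # 17
# The old index used the wrong row stride (6 instead of 10) and skipped the
# 4 leading columns, which is what broke EAGLE verify accuracy for SWA layers.
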
@@ -362,6 +372,7 @@ def extend_attention_fwd(
skip_prefix_custom_mask=True,
sliding_window_size=-1,
sinks=None,
window_kv_offsets=None,
):
"""
q_extend, k_extend, v_extend, o_extend: contiguous tensors
@@ -449,6 +460,7 @@ def extend_attention_fwd(
custom_mask,
mask_indptr,
sinks,
window_kv_offsets,
sm_scale,
kv_group_num,
q_extend.stride(0),
......