Unverified commit be1a3cd9, authored by Ke Bao, committed by GitHub

Fix swa eagle verify accuracy for Triton backend (#9279)

parent 4b74c3fc
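This change threads a per-request `window_kv_offsets` tensor from `update_sliding_window_buffer` through `ForwardMetadata` and `extend_attention_fwd` into the Triton `_fwd_kernel`. The EAGLE target-verify custom mask is laid out per request over the full KV length, while a sliding-window (SWA) layer only hands the kernel the KV indices inside the window, so mask reads have to be shifted by the number of tokens that fall before the window. A minimal sketch of the corrected indexing, assuming a row-major mask and using illustrative names (`window_kv_len`, `window_kv_offset`), not the repository code:

# Minimal sketch of how a per-request window offset shifts reads into a
# row-major custom mask that covers the full KV length.
# window_kv_len   = KV tokens kept inside the sliding window
# window_kv_offset = KV tokens that were slid out before the window starts

def mask_flat_index(row, col_in_window, window_kv_len, window_kv_offset):
    # Each mask row spans the full KV length: slid-out tokens plus the window.
    row_stride = window_kv_len + window_kv_offset
    # Kernel columns are window-relative, so shift them back to full-sequence
    # positions before indexing the mask.
    return row * row_stride + window_kv_offset + col_in_window

# The old kernel effectively used row * window_kv_len + col_in_window, which
# reads the wrong mask bits as soon as the window has slid (offset > 0).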
@@ -35,6 +35,7 @@ class ForwardMetadata:
     window_kv_indptr: torch.Tensor
     window_kv_indices: torch.Tensor
     window_num_kv_splits: torch.Tensor
+    window_kv_offsets: torch.Tensor

 class TritonAttnBackend(AttentionBackend):
@@ -163,6 +164,7 @@ class TritonAttnBackend(AttentionBackend):
         window_kv_indptr = self.window_kv_indptr
         window_kv_indices = None
         window_num_kv_splits = None
+        window_kv_offsets = None
         spec_info = forward_batch.spec_info

         if forward_batch.forward_mode.is_decode_or_idle():
@@ -186,7 +188,7 @@ class TritonAttnBackend(AttentionBackend):
                     self.sliding_window_size is not None
                     and self.sliding_window_size > 0
                 ):
-                    window_kv_indptr, window_kv_indices, window_kv_lens = (
+                    window_kv_indptr, window_kv_indices, window_kv_lens, _ = (
                         update_sliding_window_buffer(
                             self.window_kv_indptr,
                             self.req_to_token,
@@ -249,17 +251,21 @@ class TritonAttnBackend(AttentionBackend):
             )
             if self.sliding_window_size is not None and self.sliding_window_size > 0:
-                window_kv_indptr, window_kv_indices, window_kv_lens = (
-                    update_sliding_window_buffer(
-                        self.window_kv_indptr,
-                        self.req_to_token,
-                        self.sliding_window_size,
-                        forward_batch.seq_lens,
-                        forward_batch.req_pool_indices,
-                        bs,
-                        self.device,
-                        self.token_to_kv_pool_allocator,
-                    )
+                # window_kv_offsets is used to calculate the start position in custom mask
+                (
+                    window_kv_indptr,
+                    window_kv_indices,
+                    window_kv_lens,
+                    window_kv_offsets,
+                ) = update_sliding_window_buffer(
+                    self.window_kv_indptr,
+                    self.req_to_token,
+                    self.sliding_window_size,
+                    forward_batch.seq_lens,
+                    forward_batch.req_pool_indices,
+                    bs,
+                    self.device,
+                    self.token_to_kv_pool_allocator,
                 )
             custom_mask = spec_info.custom_mask
@@ -312,15 +318,17 @@ class TritonAttnBackend(AttentionBackend):
             )
             # Sliding window
             if self.sliding_window_size is not None and self.sliding_window_size > 0:
-                window_kv_indptr, window_kv_indices, _ = update_sliding_window_buffer(
-                    self.window_kv_indptr,
-                    self.req_to_token,
-                    self.sliding_window_size,
-                    forward_batch.extend_prefix_lens,
-                    forward_batch.req_pool_indices,
-                    bs,
-                    self.device,
-                    self.token_to_kv_pool_allocator,
+                window_kv_indptr, window_kv_indices, _, _ = (
+                    update_sliding_window_buffer(
+                        self.window_kv_indptr,
+                        self.req_to_token,
+                        self.sliding_window_size,
+                        forward_batch.extend_prefix_lens,
+                        forward_batch.req_pool_indices,
+                        bs,
+                        self.device,
+                        self.token_to_kv_pool_allocator,
+                    )
                 )
             qo_indptr = self.qo_indptr
@@ -346,6 +354,7 @@ class TritonAttnBackend(AttentionBackend):
             window_kv_indptr,
             window_kv_indices,
             window_num_kv_splits,
+            window_kv_offsets,
         )

     def init_cuda_graph_state(
@@ -400,6 +409,12 @@ class TritonAttnBackend(AttentionBackend):
             device=self.device,
         )

+        self.cuda_graph_window_kv_offsets = torch.zeros(
+            (max_bs,),
+            dtype=torch.int32,
+            device=self.device,
+        )
+
     def init_forward_metadata_capture_cuda_graph(
         self,
         bs: int,
@@ -414,6 +429,7 @@ class TritonAttnBackend(AttentionBackend):
         window_kv_indptr = self.window_kv_indptr
         window_kv_indices = None
         window_num_kv_splits = None
+        window_kv_offsets = None

         if forward_mode.is_decode_or_idle():
             if spec_info is None:
@@ -436,7 +452,7 @@ class TritonAttnBackend(AttentionBackend):
                 ):
                     window_kv_indices = self.cuda_graph_window_kv_indices
                     window_num_kv_splits = self.cuda_graph_window_num_kv_splits
-                    window_kv_indptr, window_kv_indices, _ = (
+                    window_kv_indptr, window_kv_indices, _, _ = (
                         update_sliding_window_buffer_cuda_graph(
                             self.window_kv_indptr,
                             window_kv_indices,
@@ -483,13 +499,14 @@ class TritonAttnBackend(AttentionBackend):
             if self.sliding_window_size is not None and self.sliding_window_size > 0:
                 window_kv_indices = self.cuda_graph_window_kv_indices
                 window_num_kv_splits = self.cuda_graph_window_num_kv_splits
-                window_kv_indptr, window_kv_indices, _ = (
+                window_kv_offsets = self.cuda_graph_window_kv_offsets
+                window_kv_indptr, window_kv_indices, _, window_kv_offsets[:bs] = (
                     update_sliding_window_buffer_cuda_graph(
                         self.window_kv_indptr,
                         window_kv_indices,
                         self.req_to_token,
                         self.sliding_window_size,
-                        seq_lens,
+                        seq_lens[:bs],
                         req_pool_indices,
                         bs,
                         self.token_to_kv_pool_allocator,
@@ -551,6 +568,7 @@ class TritonAttnBackend(AttentionBackend):
             window_kv_indptr,
             window_kv_indices,
             window_num_kv_splits,
+            window_kv_offsets,
         )

     def init_forward_metadata_replay_cuda_graph(
@@ -589,7 +607,7 @@ class TritonAttnBackend(AttentionBackend):
                 ):
                     window_num_kv_splits = self.cuda_graph_window_num_kv_splits
                     window_kv_indices = self.cuda_graph_window_kv_indices
-                    _, _, window_kv_lens = update_sliding_window_buffer_cuda_graph(
+                    _, _, window_kv_lens, _ = update_sliding_window_buffer_cuda_graph(
                         self.window_kv_indptr,
                         window_kv_indices,
                         self.req_to_token,
@@ -635,15 +653,18 @@ class TritonAttnBackend(AttentionBackend):
             if self.sliding_window_size is not None and self.sliding_window_size > 0:
                 window_num_kv_splits = self.cuda_graph_window_num_kv_splits
                 window_kv_indices = self.cuda_graph_window_kv_indices
-                _, _, window_kv_lens = update_sliding_window_buffer_cuda_graph(
-                    self.window_kv_indptr,
-                    window_kv_indices,
-                    self.req_to_token,
-                    self.sliding_window_size,
-                    seq_lens,
-                    req_pool_indices,
-                    bs,
-                    self.token_to_kv_pool_allocator,
+                window_kv_offsets = self.cuda_graph_window_kv_offsets
+                _, _, window_kv_lens, window_kv_offsets[:bs] = (
+                    update_sliding_window_buffer_cuda_graph(
+                        self.window_kv_indptr,
+                        window_kv_indices,
+                        self.req_to_token,
+                        self.sliding_window_size,
+                        seq_lens[:bs],
+                        req_pool_indices,
+                        bs,
+                        self.token_to_kv_pool_allocator,
+                    )
                 )
             custom_mask = self.cuda_graph_custom_mask
             custom_mask[: spec_info.custom_mask.shape[0]] = spec_info.custom_mask
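For CUDA graph capture and replay, the offsets are written in place into the pre-allocated `(max_bs,)` buffer `cuda_graph_window_kv_offsets`, so the captured kernel keeps reading the same storage; only the first `bs` entries are refreshed each time. A small sketch of that slice-assignment unpacking pattern, with a stub standing in for the real helper (shapes and stub return values are assumptions):

# Sketch of the in-place update pattern used with CUDA graph buffers.
import torch

max_bs, bs = 8, 3
cuda_graph_window_kv_offsets = torch.zeros((max_bs,), dtype=torch.int32)

def _stub_update_sliding_window_buffer_cuda_graph(bs):
    # Stand-in for the real helper: (indptr, indices, lens, start_idx).
    return None, None, torch.ones(bs, dtype=torch.int32), torch.arange(bs, dtype=torch.int32)

window_kv_offsets = cuda_graph_window_kv_offsets
# Tuple unpacking can assign straight into a slice of the persistent buffer,
# so the graph-captured kernel keeps reading the same storage on every replay.
_, _, window_kv_lens, window_kv_offsets[:bs] = _stub_update_sliding_window_buffer_cuda_graph(bs)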
@@ -706,10 +727,12 @@ class TritonAttnBackend(AttentionBackend):
             )  # Needed for sliding window mask
             kv_indptr = self.forward_metadata.window_kv_indptr
             kv_indices = self.forward_metadata.window_kv_indices
+            window_kv_offsets = self.forward_metadata.window_kv_offsets
         else:
             sliding_window_size = -1
             kv_indptr = self.forward_metadata.kv_indptr
             kv_indices = self.forward_metadata.kv_indices
+            window_kv_offsets = None

         self.extend_attention_fwd(
             q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
@@ -729,6 +752,7 @@ class TritonAttnBackend(AttentionBackend):
             layer.logit_cap,
             sliding_window_size=sliding_window_size,
             sinks=sinks,
+            window_kv_offsets=window_kv_offsets,
         )
         return o
@@ -1011,7 +1035,7 @@ def update_sliding_window_buffer(
             window_kv_indices[:kv_last_index]
         )
     )
-    return window_kv_indptr, window_kv_indices, window_kv_lens
+    return window_kv_indptr, window_kv_indices, window_kv_lens, window_kv_start_idx

 def update_sliding_window_buffer_cuda_graph(
@@ -1048,4 +1072,4 @@ def update_sliding_window_buffer_cuda_graph(
             window_kv_indices[:kv_last_index]
         )
     )
-    return window_kv_indptr, window_kv_indices, window_kv_lens
+    return window_kv_indptr, window_kv_indices, window_kv_lens, window_kv_start_idx
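The fourth return value, `window_kv_start_idx`, is what reaches the kernel as `window_kv_offsets`. As a hedged sketch of where such a start index could come from (an assumption about the helper's internals, not the repository code): a request with `seq_len` tokens keeps at most `sliding_window_size` of them in the window, so the window starts `seq_len - window_kv_len` tokens into the full sequence.

# Hedged sketch only: deriving a per-request window start index from lengths.
import torch

def window_lens_and_start_idx(seq_lens: torch.Tensor, sliding_window_size: int):
    # A request keeps at most sliding_window_size KV tokens in the window.
    window_kv_lens = torch.clamp(seq_lens, max=sliding_window_size)
    # The window starts this many tokens into the full sequence (0 until it slides).
    window_kv_start_idx = seq_lens - window_kv_lens
    return window_kv_lens, window_kv_start_idx

# Example: sliding_window_size=4 and seq_lens=[2, 6] give lens=[2, 4] and
# start_idx=[0, 2]; the second request's mask reads must shift right by 2.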
@@ -190,7 +190,7 @@ def _decode_att_m_fwd(
     Lk = k_buffer.shape[-1]
     Lv = v_buffer.shape[-1]

-    batch, head_num = kv_indptr.shape[0] - 1, q.shape[1]
+    batch, head_num = q.shape[0], q.shape[1]
     grid = (batch, head_num, MAX_KV_SPLITS)
     kv_group_num = q.shape[1] // k_buffer.shape[1]
@@ -433,7 +433,7 @@ def _decode_grouped_att_m_fwd(
         BLOCK_DPE = 0
     BLOCK_DV = triton.next_power_of_2(Lv)

-    batch, head_num = kv_indptr.shape[0] - 1, q.shape[1]
+    batch, head_num = q.shape[0], q.shape[1]
     kv_group_num = q.shape[1] // k_buffer.shape[1]

     BLOCK_H = 16
......
@@ -52,6 +52,7 @@ def _fwd_kernel(
     mask_ptr,
     mask_indptr,
     sink_ptr,
+    window_kv_offset_ptr,
     sm_scale,
     kv_group_num,
     stride_qbs,
@@ -95,6 +96,11 @@ def _fwd_kernel(
     if USE_CUSTOM_MASK:
         cur_seq_mask_start_idx = tl.load(mask_indptr + cur_seq)

+    # For SWA, we should only load the mask in the sliding window
+    window_kv_offset = 0
+    if USE_CUSTOM_MASK and SLIDING_WINDOW_SIZE > 0:
+        window_kv_offset = tl.load(window_kv_offset_ptr + cur_seq)
+
     offs_d = tl.arange(0, BLOCK_DMODEL)
     offs_dv = tl.arange(0, BLOCK_DV)
     offs_m = tl.arange(0, BLOCK_M)
@@ -139,7 +145,9 @@ def _fwd_kernel(
             custom_mask = tl.load(
                 mask_ptr
                 + cur_seq_mask_start_idx
-                + (cur_block_m * BLOCK_M + offs_m[:, None]) * cur_seq_len
+                + (cur_block_m * BLOCK_M + offs_m[:, None])
+                * (cur_seq_len + window_kv_offset)
+                + window_kv_offset
                 + start_n
                 + offs_n[None, :],
                 mask=(mask_m[:, None] & mask_n[None, :]),
@@ -236,7 +244,9 @@ def _fwd_kernel(
             custom_mask = tl.load(
                 mask_ptr
                 + cur_seq_mask_start_idx
-                + (cur_block_m * BLOCK_M + offs_m[:, None]) * cur_seq_len
+                + (cur_block_m * BLOCK_M + offs_m[:, None])
+                * (cur_seq_len + window_kv_offset)
+                + window_kv_offset
                 + cur_seq_len_prefix
                 + start_n
                 + offs_n[None, :],
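A tiny worked check of the new mask index arithmetic, with illustrative numbers rather than values taken from the test suite:

# Suppose a request has 6 KV tokens total and a window of 4, so inside the
# kernel window_kv_offset = 2 and cur_seq_len = 4, and the custom mask is
# row-major over all 6 columns.
cur_seq_len, window_kv_offset = 4, 2
row, col_in_window = 1, 3  # second query row, last column inside the window

old_index = row * cur_seq_len + col_in_window                                        # 7
new_index = row * (cur_seq_len + window_kv_offset) + window_kv_offset + col_in_window  # 11

# The true entry for (row 1, absolute KV position 2 + 3 = 5) in a width-6 mask
# sits at 1 * 6 + 5 = 11, matching new_index; the old formula (7) lands on
# row 1, column 1 and drifts into the wrong row entirely for later rows.
assert new_index == 1 * 6 + 5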
@@ -362,6 +372,7 @@ def extend_attention_fwd(
     skip_prefix_custom_mask=True,
     sliding_window_size=-1,
     sinks=None,
+    window_kv_offsets=None,
 ):
     """
     q_extend, k_extend, v_extend, o_extend: contiguous tensors
@@ -449,6 +460,7 @@ def extend_attention_fwd(
         custom_mask,
         mask_indptr,
         sinks,
+        window_kv_offsets,
         sm_scale,
         kv_group_num,
         q_extend.stride(0),
......