Unverified commit 399e7ec8 authored by Ke Bao, committed by GitHub

Refine naming (#8868)

Rename the attention-sink arguments across the Triton attention backend: sk → sinks, sk_ptr → sink_ptr, and HAS_SK → HAS_SINK in the decode and extend kernels, with the GptOssAttention call site updated to match.
parent 1bd53168
@@ -686,7 +686,7 @@ class TritonAttnBackend(AttentionBackend):
         layer: RadixAttention,
         forward_batch: ForwardBatch,
         save_kv_cache=True,
-        sk=None,
+        sinks=None,
     ):
         # TODO: reuse the buffer across layers
         if layer.qk_head_dim != layer.v_head_dim:
@@ -731,7 +731,7 @@ class TritonAttnBackend(AttentionBackend):
             layer.scaling,
             layer.logit_cap,
             sliding_window_size=sliding_window_size,
-            sk=sk,
+            sinks=sinks,
         )
         return o
@@ -743,7 +743,7 @@ class TritonAttnBackend(AttentionBackend):
         layer: RadixAttention,
         forward_batch: ForwardBatch,
         save_kv_cache=True,
-        sk=None,
+        sinks=None,
     ):
         # During torch.compile, there is a bug in rotary_emb that causes the
         # output value to have a 3D tensor shape. This reshapes the output correctly.
@@ -780,7 +780,7 @@ class TritonAttnBackend(AttentionBackend):
             self.max_kv_splits,
             layer.scaling,
             layer.logit_cap,
-            sk=sk,
+            sinks=sinks,
         )
         return o
...
@@ -495,7 +495,7 @@ def _fwd_kernel_stage2(
     O,
     kv_indptr,
     num_kv_splits,
-    sk_ptr,
+    sink_ptr,
     stride_mid_ob,
     stride_mid_oh,
     stride_mid_os,
@@ -505,7 +505,7 @@ def _fwd_kernel_stage2(
     MIN_BLOCK_KV: tl.constexpr,
     BLOCK_DV: tl.constexpr,
     Lv: tl.constexpr,
-    HAS_SK: tl.constexpr,
+    HAS_SINK: tl.constexpr,
 ):
     cur_batch = tl.program_id(0)
     cur_head = tl.program_id(1)
@@ -547,9 +547,9 @@ def _fwd_kernel_stage2(
             e_sum = e_sum * old_scale + exp_logic
             e_max = n_e_max

-    if HAS_SK:
-        cur_sk = tl.load(sk_ptr + cur_head)
-        e_sum += tl.exp(cur_sk - e_max)
+    if HAS_SINK:
+        cur_sink = tl.load(sink_ptr + cur_head)
+        e_sum += tl.exp(cur_sink - e_max)

     tl.store(
         O + cur_batch * stride_obs + cur_head * stride_oh + offs_d,
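For context on what the renamed values mean: `sinks` carries one learned attention-sink logit per head, and the reduction above folds `exp(sink - e_max)` into the softmax denominator without attaching a value row to it. A minimal PyTorch sketch of that math, with illustrative names and shapes rather than the kernel's actual interface:

```python
import torch

def sink_attention(logits: torch.Tensor, v: torch.Tensor, sinks: torch.Tensor) -> torch.Tensor:
    # logits: [H, N] scores of one query against N keys, v: [N, D] values, sinks: [H]
    e_max = logits.max(dim=-1, keepdim=True).values        # max over the real logits only
    e_sum = torch.exp(logits - e_max).sum(-1)               # normalizer from the real keys
    e_sum = e_sum + torch.exp(sinks - e_max.squeeze(-1))    # e_sum += exp(sink - e_max)
    weights = torch.exp(logits - e_max) / e_sum[:, None]    # rows now sum to < 1
    return weights @ v                                       # the sink adds no value vector

# Equivalent view: an ordinary softmax over [logits, sink] whose sink column has a zero value.
H, N, D = 4, 8, 16
logits, v, sinks = torch.randn(H, N), torch.randn(N, D), torch.randn(H)
full = torch.softmax(torch.cat([logits, sinks[:, None]], dim=-1), dim=-1)
assert torch.allclose(sink_attention(logits, v, sinks), full[:, :N] @ v, atol=1e-6)
```

The equivalence check illustrates that the kernel's adjustment is just a softmax over the real logits plus one extra sink column whose value contribution is zero.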
@@ -567,14 +567,14 @@ def _decode_softmax_reducev_fwd(
     kv_indptr,
     num_kv_splits,
     max_kv_splits,
-    sk=None,
+    sinks=None,
 ):
     batch, head_num = q.shape[0], q.shape[1]
     Lv = v_buffer.shape[-1]
     BLOCK_DV = triton.next_power_of_2(Lv)

     MAX_KV_SPLITS = max_kv_splits
-    HAS_SK = sk is not None
+    HAS_SINK = sinks is not None

     extra_kargs = {}
     if _is_hip:
@@ -589,7 +589,7 @@ def _decode_softmax_reducev_fwd(
         o,
         kv_indptr,
         num_kv_splits,
-        sk,
+        sinks,
         logits.stride(0),
         logits.stride(1),
         logits.stride(2),
@@ -599,7 +599,7 @@ def _decode_softmax_reducev_fwd(
         MIN_BLOCK_KV=_MIN_BLOCK_KV,
         BLOCK_DV=BLOCK_DV,
         Lv=Lv,
-        HAS_SK=HAS_SK,
+        HAS_SINK=HAS_SINK,
         num_warps=4,
         num_stages=2,
         **extra_kargs,
@@ -619,7 +619,7 @@ def decode_attention_fwd_normal(
     max_kv_splits,
     sm_scale,
     logit_cap=0.0,
-    sk=None,
+    sinks=None,
 ):
     _decode_att_m_fwd(
         q,
@@ -643,7 +643,7 @@ def decode_attention_fwd_normal(
         kv_indptr,
         num_kv_splits,
         max_kv_splits,
-        sk,
+        sinks,
     )
@@ -660,7 +660,7 @@ def decode_attention_fwd_grouped(
     max_kv_splits,
     sm_scale,
     logit_cap=0.0,
-    sk=None,
+    sinks=None,
 ):
     _decode_grouped_att_m_fwd(
         q,
@@ -684,7 +684,7 @@ def decode_attention_fwd_grouped(
         kv_indptr,
         num_kv_splits,
         max_kv_splits,
-        sk,
+        sinks,
     )
@@ -701,7 +701,7 @@ def decode_attention_fwd(
     max_kv_splits,
     sm_scale,
     logit_cap=0.0,
-    sk=None,
+    sinks=None,
 ):
     assert max_kv_splits == attn_logits.shape[2]
     assert q.shape[0] <= kv_indptr.shape[0] - 1
@@ -724,7 +724,7 @@ def decode_attention_fwd(
             max_kv_splits,
             sm_scale,
             logit_cap=logit_cap,
-            sk=sk,
+            sinks=sinks,
         )
     else:
         # GQA/MQA/MLA
@@ -741,5 +741,5 @@ def decode_attention_fwd(
             max_kv_splits,
             sm_scale,
             logit_cap=logit_cap,
-            sk=sk,
+            sinks=sinks,
         )
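All three decode entry points (decode_attention_fwd_normal, decode_attention_fwd_grouped, and decode_attention_fwd) keep sinks=None as the default and simply forward it to the stage-2 reduction, so callers without attention sinks are unaffected. Because the kernel reads the value with tl.load(sink_ptr + cur_head), the tensor is expected to hold one logit per query head; a hedged sketch of preparing such a tensor (the num_q_heads name, dtype, and device are assumptions, not requirements stated in this diff):

```python
import torch

# Assumed shape: one logit per query head (the kernel indexes sink_ptr + cur_head).
num_q_heads = 64  # assumption: equals q.shape[1]
sinks = torch.nn.Parameter(torch.zeros(num_q_heads, dtype=torch.float32, device="cuda"))

# A logit of -inf contributes exp(-inf - e_max) == 0 to the denominator, matching the
# sinks=None behaviour; finite logits shift probability mass away from the real keys.
no_op_sinks = torch.full((num_q_heads,), float("-inf"), device="cuda")
```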
@@ -51,7 +51,7 @@ def _fwd_kernel(
     kv_indices,
     mask_ptr,
     mask_indptr,
-    sk_ptr,
+    sink_ptr,
     sm_scale,
     kv_group_num,
     stride_qbs,
@@ -79,7 +79,7 @@ def _fwd_kernel(
     IS_CAUSAL: tl.constexpr,
     SKIP_PREFIX_CUSTOM_MASK: tl.constexpr,
     STORE_TRANSPOSE: tl.constexpr,
-    HAS_SK: tl.constexpr,
+    HAS_SINK: tl.constexpr,
 ):
     cur_seq = tl.program_id(0)
     cur_head = tl.program_id(1)
@@ -302,9 +302,9 @@ def _fwd_kernel(
         e_max = n_e_max

-    if HAS_SK:
-        cur_sk = tl.load(sk_ptr + cur_head)
-        deno += tl.exp(cur_sk - e_max)
+    if HAS_SINK:
+        cur_sink = tl.load(sink_ptr + cur_head)
+        deno += tl.exp(cur_sink - e_max)

     offs_o = (
         (cur_seq_extend_start_idx + cur_block_m * BLOCK_M + offs_m[:, None])
@@ -344,7 +344,7 @@ def extend_attention_fwd(
     logit_cap=0.0,
     skip_prefix_custom_mask=True,
     sliding_window_size=-1,
-    sk=None,
+    sinks=None,
 ):
     """
     q_extend, k_extend, v_extend, o_extend: contiguous tensors
@@ -410,7 +410,7 @@ def extend_attention_fwd(
     # Skip custom mask for prefix part
     SKIP_PREFIX_CUSTOM_MASK = skip_prefix_custom_mask
-    HAS_SK = sk is not None
+    HAS_SINK = sinks is not None

     grid = (batch_size, head_num, triton.cdiv(max_len_extend, BLOCK_M))
     num_stages = 1
@@ -431,7 +431,7 @@ def extend_attention_fwd(
         kv_indices,
         custom_mask,
         mask_indptr,
-        sk,
+        sinks,
         sm_scale,
         kv_group_num,
         q_extend.stride(0),
@@ -458,7 +458,7 @@ def extend_attention_fwd(
         USE_CUSTOM_MASK=USE_CUSTOM_MASK,
         IS_CAUSAL=is_causal,
         SKIP_PREFIX_CUSTOM_MASK=SKIP_PREFIX_CUSTOM_MASK,
-        HAS_SK=HAS_SK,
+        HAS_SINK=HAS_SINK,
         STORE_TRANSPOSE=_is_hip,
         num_warps=num_warps,
         num_stages=num_stages,
...
@@ -301,7 +301,7 @@ class GptOssAttention(nn.Module):
         hidden_states, forward_batch, inner_state = intermediate_state
         if inner_state is None:
             return hidden_states
-        attn_output = self.attn(*inner_state, sk=self.sinks)
+        attn_output = self.attn(*inner_state, sinks=self.sinks)
         output, _ = self.o_proj(attn_output)
         return output
...
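On the model side, only the keyword at the call site changes: GptOssAttention already owns the learned self.sinks parameter and now passes it as sinks=. A rough sketch of that ownership pattern (GptOssAttention's real definition is outside this diff, so the module below is illustrative only):

```python
import torch
from torch import nn

class SinkOwningAttention(nn.Module):
    """Illustrative only: a module that owns per-head sink logits and forwards them."""

    def __init__(self, num_heads: int, attn):
        super().__init__()
        self.attn = attn                                    # backend-style attention callable
        self.sinks = nn.Parameter(torch.zeros(num_heads))   # one learned logit per head

    def forward(self, *inner_state):
        # mirrors the updated call site: self.attn(*inner_state, sinks=self.sinks)
        return self.attn(*inner_state, sinks=self.sinks)
```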