Unverified Commit 399e7ec8 authored by Ke Bao, committed by GitHub

Refine naming (#8868)

parent 1bd53168
......@@ -686,7 +686,7 @@ class TritonAttnBackend(AttentionBackend):
layer: RadixAttention,
forward_batch: ForwardBatch,
save_kv_cache=True,
sk=None,
sinks=None,
):
# TODO: reuse the buffer across layers
if layer.qk_head_dim != layer.v_head_dim:
......@@ -731,7 +731,7 @@ class TritonAttnBackend(AttentionBackend):
layer.scaling,
layer.logit_cap,
sliding_window_size=sliding_window_size,
sk=sk,
sinks=sinks,
)
return o
......@@ -743,7 +743,7 @@ class TritonAttnBackend(AttentionBackend):
layer: RadixAttention,
forward_batch: ForwardBatch,
save_kv_cache=True,
sk=None,
sinks=None,
):
# During torch.compile, there is a bug in rotary_emb that causes the
# output value to have a 3D tensor shape. This reshapes the output correctly.
......@@ -780,7 +780,7 @@ class TritonAttnBackend(AttentionBackend):
self.max_kv_splits,
layer.scaling,
layer.logit_cap,
sk=sk,
sinks=sinks,
)
return o
......
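The renamed `sinks` argument carries per-head attention-sink logits: extra logits that enter the softmax denominator without contributing a value row, as the `e_sum += tl.exp(cur_sink - e_max)` and `deno += tl.exp(cur_sink - e_max)` updates in the kernels below make explicit. A minimal reference sketch of that math in plain PyTorch, assuming `sinks` is a float tensor of shape `[num_heads]` (the function and tensor names here are illustrative, not part of this diff):

```python
import torch

def decode_attn_with_sinks(q, k, v, sinks, sm_scale):
    # q: [num_heads, head_dim]; k, v: [seq_len, num_heads, head_dim]
    logits = torch.einsum("hd,shd->hs", q, k) * sm_scale  # [num_heads, seq_len]
    if sinks is not None:
        # Append one sink logit per head; it has no corresponding value row.
        logits = torch.cat([logits, sinks[:, None]], dim=-1)
    probs = torch.softmax(logits, dim=-1)
    if sinks is not None:
        probs = probs[:, :-1]  # drop the sink column before the value matmul
    return torch.einsum("hs,shd->hd", probs, v)
```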
......@@ -495,7 +495,7 @@ def _fwd_kernel_stage2(
O,
kv_indptr,
num_kv_splits,
sk_ptr,
sink_ptr,
stride_mid_ob,
stride_mid_oh,
stride_mid_os,
......@@ -505,7 +505,7 @@ def _fwd_kernel_stage2(
MIN_BLOCK_KV: tl.constexpr,
BLOCK_DV: tl.constexpr,
Lv: tl.constexpr,
HAS_SK: tl.constexpr,
HAS_SINK: tl.constexpr,
):
cur_batch = tl.program_id(0)
cur_head = tl.program_id(1)
......@@ -547,9 +547,9 @@ def _fwd_kernel_stage2(
e_sum = e_sum * old_scale + exp_logic
e_max = n_e_max
if HAS_SK:
cur_sk = tl.load(sk_ptr + cur_head)
e_sum += tl.exp(cur_sk - e_max)
if HAS_SINK:
cur_sink = tl.load(sink_ptr + cur_head)
e_sum += tl.exp(cur_sink - e_max)
tl.store(
O + cur_batch * stride_obs + cur_head * stride_oh + offs_d,
......@@ -567,14 +567,14 @@ def _decode_softmax_reducev_fwd(
kv_indptr,
num_kv_splits,
max_kv_splits,
sk=None,
sinks=None,
):
batch, head_num = q.shape[0], q.shape[1]
Lv = v_buffer.shape[-1]
BLOCK_DV = triton.next_power_of_2(Lv)
MAX_KV_SPLITS = max_kv_splits
HAS_SK = sk is not None
HAS_SINK = sinks is not None
extra_kargs = {}
if _is_hip:
......@@ -589,7 +589,7 @@ def _decode_softmax_reducev_fwd(
o,
kv_indptr,
num_kv_splits,
sk,
sinks,
logits.stride(0),
logits.stride(1),
logits.stride(2),
......@@ -599,7 +599,7 @@ def _decode_softmax_reducev_fwd(
MIN_BLOCK_KV=_MIN_BLOCK_KV,
BLOCK_DV=BLOCK_DV,
Lv=Lv,
HAS_SK=HAS_SK,
HAS_SINK=HAS_SINK,
num_warps=4,
num_stages=2,
**extra_kargs,
......@@ -619,7 +619,7 @@ def decode_attention_fwd_normal(
max_kv_splits,
sm_scale,
logit_cap=0.0,
sk=None,
sinks=None,
):
_decode_att_m_fwd(
q,
......@@ -643,7 +643,7 @@ def decode_attention_fwd_normal(
kv_indptr,
num_kv_splits,
max_kv_splits,
sk,
sinks,
)
......@@ -660,7 +660,7 @@ def decode_attention_fwd_grouped(
max_kv_splits,
sm_scale,
logit_cap=0.0,
sk=None,
sinks=None,
):
_decode_grouped_att_m_fwd(
q,
......@@ -684,7 +684,7 @@ def decode_attention_fwd_grouped(
kv_indptr,
num_kv_splits,
max_kv_splits,
sk,
sinks,
)
......@@ -701,7 +701,7 @@ def decode_attention_fwd(
max_kv_splits,
sm_scale,
logit_cap=0.0,
sk=None,
sinks=None,
):
assert max_kv_splits == attn_logits.shape[2]
assert q.shape[0] <= kv_indptr.shape[0] - 1
......@@ -724,7 +724,7 @@ def decode_attention_fwd(
max_kv_splits,
sm_scale,
logit_cap=logit_cap,
sk=sk,
sinks=sinks,
)
else:
# GQA/MQA/MLA
......@@ -741,5 +741,5 @@ def decode_attention_fwd(
max_kv_splits,
sm_scale,
logit_cap=logit_cap,
sk=sk,
sinks=sinks,
)
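In the split-KV decode path, `_fwd_kernel_stage2` merges the per-split partial softmax statistics and then folds the sink into the running denominator (`e_sum += tl.exp(cur_sink - e_max)`) before the final store. A small sketch of that log-sum-exp bookkeeping, assuming scalar `(max, sum)` pairs per split (helper names are illustrative):

```python
import math

def merge_splits(m1, s1, m2, s2):
    # Online-softmax merge of two partial (running max, running sum) pairs.
    m = max(m1, m2)
    return m, s1 * math.exp(m1 - m) + s2 * math.exp(m2 - m)

def fold_in_sink(m, s, sink_logit):
    # The sink logit only enlarges the denominator; no value row is added,
    # mirroring `e_sum += tl.exp(cur_sink - e_max)` in the kernel above.
    return s + math.exp(sink_logit - m)
```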
......@@ -51,7 +51,7 @@ def _fwd_kernel(
kv_indices,
mask_ptr,
mask_indptr,
sk_ptr,
sink_ptr,
sm_scale,
kv_group_num,
stride_qbs,
......@@ -79,7 +79,7 @@ def _fwd_kernel(
IS_CAUSAL: tl.constexpr,
SKIP_PREFIX_CUSTOM_MASK: tl.constexpr,
STORE_TRANSPOSE: tl.constexpr,
HAS_SK: tl.constexpr,
HAS_SINK: tl.constexpr,
):
cur_seq = tl.program_id(0)
cur_head = tl.program_id(1)
......@@ -302,9 +302,9 @@ def _fwd_kernel(
e_max = n_e_max
if HAS_SK:
cur_sk = tl.load(sk_ptr + cur_head)
deno += tl.exp(cur_sk - e_max)
if HAS_SINK:
cur_sink = tl.load(sink_ptr + cur_head)
deno += tl.exp(cur_sink - e_max)
offs_o = (
(cur_seq_extend_start_idx + cur_block_m * BLOCK_M + offs_m[:, None])
......@@ -344,7 +344,7 @@ def extend_attention_fwd(
logit_cap=0.0,
skip_prefix_custom_mask=True,
sliding_window_size=-1,
sk=None,
sinks=None,
):
"""
q_extend, k_extend, v_extend, o_extend: contiguous tensors
......@@ -410,7 +410,7 @@ def extend_attention_fwd(
# Skip custom mask for prefix part
SKIP_PREFIX_CUSTOM_MASK = skip_prefix_custom_mask
HAS_SK = sk is not None
HAS_SINK = sinks is not None
grid = (batch_size, head_num, triton.cdiv(max_len_extend, BLOCK_M))
num_stages = 1
......@@ -431,7 +431,7 @@ def extend_attention_fwd(
kv_indices,
custom_mask,
mask_indptr,
sk,
sinks,
sm_scale,
kv_group_num,
q_extend.stride(0),
......@@ -458,7 +458,7 @@ def extend_attention_fwd(
USE_CUSTOM_MASK=USE_CUSTOM_MASK,
IS_CAUSAL=is_causal,
SKIP_PREFIX_CUSTOM_MASK=SKIP_PREFIX_CUSTOM_MASK,
HAS_SK=HAS_SK,
HAS_SINK=HAS_SINK,
STORE_TRANSPOSE=_is_hip,
num_warps=num_warps,
num_stages=num_stages,
......
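Both the decode and extend kernels gate the sink handling behind a `tl.constexpr` flag (`HAS_SINK`), so separate specializations are compiled for the sinks-on and sinks-off cases and the extra load and exp vanish when `sinks=None`. A generic illustration of that pattern, not taken from this diff:

```python
import triton
import triton.language as tl

@triton.jit
def _demo_kernel(x_ptr, out_ptr, sink_ptr, n, HAS_SINK: tl.constexpr, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    acc = tl.load(x_ptr + offs, mask=mask, other=0.0)
    if HAS_SINK:
        # Resolved at compile time; when HAS_SINK is False this load never
        # appears in the generated kernel.
        acc += tl.load(sink_ptr)
    tl.store(out_ptr + offs, acc, mask=mask)
```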
......@@ -301,7 +301,7 @@ class GptOssAttention(nn.Module):
hidden_states, forward_batch, inner_state = intermediate_state
if inner_state is None:
return hidden_states
attn_output = self.attn(*inner_state, sk=self.sinks)
attn_output = self.attn(*inner_state, sinks=self.sinks)
output, _ = self.o_proj(attn_output)
return output
......
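On the model side, `GptOssAttention` forwards its `self.sinks` through the renamed keyword. The definition of `self.sinks` is not shown in this diff; a hypothetical sketch of a per-head sink parameter wired into an attention call that accepts a `sinks` keyword:

```python
import torch
from torch import nn

class SinkAttention(nn.Module):
    # Hypothetical module; the name and fields are illustrative, not from the diff.
    def __init__(self, num_heads: int):
        super().__init__()
        # One learnable sink logit per attention head.
        self.sinks = nn.Parameter(torch.zeros(num_heads))

    def forward(self, q, k, v, attn_fn):
        # attn_fn stands in for the backend call that now accepts `sinks=`.
        return attn_fn(q, k, v, sinks=self.sinks)
```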