Unverified Commit 8cc26acd authored by Isotr0py's avatar Isotr0py Committed by GitHub
Browse files

[Performance] Improve Triton prefill attention kernel's performance (#32403)


Signed-off-by: default avatarIsotr0py <mozf@mail2.sysu.edu.cn>
parent 4a6af881
...@@ -46,7 +46,7 @@ def test_bert_models( ...@@ -46,7 +46,7 @@ def test_bert_models(
for hf_output, vllm_output in zip(hf_outputs, vllm_outputs): for hf_output, vllm_output in zip(hf_outputs, vllm_outputs):
hf_output = hf_output.detach().clone().cpu().float() hf_output = hf_output.detach().clone().cpu().float()
vllm_output = vllm_output.detach().clone().cpu().float() vllm_output = vllm_output.detach().clone().cpu().float()
torch.testing.assert_close(hf_output, vllm_output, atol=1.2e-2, rtol=1e-3) torch.testing.assert_close(hf_output, vllm_output, atol=3.2e-2, rtol=1e-3)
@pytest.mark.parametrize("model", ["disham993/electrical-ner-ModernBERT-base"]) @pytest.mark.parametrize("model", ["disham993/electrical-ner-ModernBERT-base"])
...@@ -86,7 +86,7 @@ def test_modernbert_models( ...@@ -86,7 +86,7 @@ def test_modernbert_models(
for hf_output, vllm_output in zip(hf_outputs, vllm_outputs): for hf_output, vllm_output in zip(hf_outputs, vllm_outputs):
hf_output = hf_output.detach().clone().cpu().float() hf_output = hf_output.detach().clone().cpu().float()
vllm_output = vllm_output.detach().clone().cpu().float() vllm_output = vllm_output.detach().clone().cpu().float()
torch.testing.assert_close(hf_output, vllm_output, atol=1.2e-2, rtol=1e-3) torch.testing.assert_close(hf_output, vllm_output, atol=3.2e-2, rtol=1e-3)
@pytest.mark.parametrize("model", ["bd2lcco/Qwen3-0.6B-finetuned"]) @pytest.mark.parametrize("model", ["bd2lcco/Qwen3-0.6B-finetuned"])
......
...@@ -2,6 +2,10 @@ ...@@ -2,6 +2,10 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Math utility functions for vLLM.""" """Math utility functions for vLLM."""
# Approximate value of 1/ln(2), used for log/exp base conversion
# Best FP32 approximation: 1.4426950216 (hex 0x3FB8AA3B)
RCP_LN2 = 1.4426950216
def cdiv(a: int, b: int) -> int: def cdiv(a: int, b: int) -> int:
"""Ceiling division.""" """Ceiling division."""
......
...@@ -30,6 +30,7 @@ import torch ...@@ -30,6 +30,7 @@ import torch
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton from vllm.triton_utils import tl, triton
from vllm.utils.math_utils import RCP_LN2
@triton.jit @triton.jit
...@@ -110,15 +111,7 @@ def _fwd_kernel( ...@@ -110,15 +111,7 @@ def _fwd_kernel(
end_n_limit = block_mask * end_n end_n_limit = block_mask * end_n
for start_n in range(start_n_limit, end_n_limit, BLOCK_N): for start_n in range(start_n_limit, end_n_limit, BLOCK_N):
start_n = tl.multiple_of(start_n, BLOCK_N) # -- prepare attention mask ----
# -- compute qk ----
k = tl.load(
k_ptrs + (cur_batch_in_all_start_index + start_n) * stride_kbs,
mask=((start_n + offs_n[None, :]) < cur_batch_seq_len) & (mask_d[:, None]),
other=0.0,
)
# Apply attention mask (causal + bidirectional sliding window)
# Position indices in the sequence # Position indices in the sequence
pos_q = offs_m[:, None] # Query positions [BLOCK_M, 1] pos_q = offs_m[:, None] # Query positions [BLOCK_M, 1]
pos_k = start_n + offs_n[None, :] # Key positions [1, BLOCK_N] pos_k = start_n + offs_n[None, :] # Key positions [1, BLOCK_N]
...@@ -141,53 +134,38 @@ def _fwd_kernel( ...@@ -141,53 +134,38 @@ def _fwd_kernel(
if sliding_mask_k is not None: if sliding_mask_k is not None:
mask &= sliding_mask_k mask &= sliding_mask_k
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) start_n = tl.multiple_of(start_n, BLOCK_N)
qk += tl.where(mask, 0, float("-inf")) # -- compute qk ----
qk += tl.dot(q, k) k = tl.load(
qk *= sm_scale k_ptrs + (cur_batch_in_all_start_index + start_n) * stride_kbs,
mask=(pos_k < cur_batch_seq_len) & (mask_d[:, None]),
# -- compute m_ij, p, l_ij other=0.0,
m_ij = tl.max(qk, 1) )
# For sliding window there's a chance the max is -inf due to masking of
# the entire row. In this case we need to set m_j 0 to avoid NaN qk = tl.dot(q, k)
m_ij_valid_mask = m_ij > float("-inf") qk = tl.where(mask, qk * sm_scale, -1.0e8)
m_ij_masked = tl.where(m_ij_valid_mask, m_ij, 0.0) m_ij = tl.maximum(m_i, tl.max(qk, 1))
# -- compute p and l_ij -- qk -= m_ij[:, None]
p = tl.exp(qk - m_ij_masked[:, None]) p = tl.math.exp2(qk)
l_ij = tl.sum(p, 1) l_ij = tl.sum(p, 1)
# -- update m_i and l_i # -- update m_i and l_i
m_i_new = tl.maximum(m_i, m_ij) alpha = tl.math.exp2(m_i - m_ij)
m_i_new_mask = m_i_new > float("-inf") l_i = l_i * alpha + l_ij
alpha = tl.exp(m_i - m_i_new)
beta = tl.exp(m_ij - m_i_new)
# mask alpha and beta for sliding window
alpha = tl.where(m_i_new_mask, alpha, 1.0)
beta = tl.where(m_i_new_mask, beta, 0.0)
l_i_new = alpha * l_i + beta * l_ij
# -- update output accumulator -- # -- update output accumulator --
# scale p acc = acc * alpha[:, None]
# For sliding window there's a chance the l_i_new is 0 due to masking
# the entire row. We need to set l_i_new 1 to avoid zero division
l_i_new_mask = (l_i_new != 0.0) & (m_i_new_mask > float("-inf"))
l_i_new_safe = tl.where(l_i_new_mask, l_i_new, 1.0)
p_scale = beta / l_i_new_safe
p = p * p_scale[:, None]
# scale acc
acc_scale = l_i / l_i_new_safe * alpha
acc = acc * acc_scale[:, None]
# update acc # update acc
v = tl.load( v = tl.load(
v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs, v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs,
mask=((start_n + offs_n[:, None]) < cur_batch_seq_len) & (mask_d[None, :]), mask=((start_n + offs_n[:, None]) < cur_batch_seq_len) & (mask_d[None, :]),
other=0.0, other=0.0,
) )
p = p.to(v.dtype) p = p.to(v.dtype)
acc += tl.dot(p, v) acc = tl.dot(p, v, acc)
# update m_i and l_i # update m_i
l_i = l_i_new m_i = m_ij
m_i = m_i_new
# initialize pointers to output acc = acc / l_i[:, None]
off_o = ( off_o = (
(cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs
+ cur_head * stride_oh + cur_head * stride_oh
...@@ -234,6 +212,9 @@ def context_attention_fwd( ...@@ -234,6 +212,9 @@ def context_attention_fwd(
Lq, Lk, _ = q.shape[-1], k.shape[-1], v.shape[-1] Lq, Lk, _ = q.shape[-1], k.shape[-1], v.shape[-1]
sm_scale = 1.0 / (Lq**0.5) if softmax_scale is None else softmax_scale sm_scale = 1.0 / (Lq**0.5) if softmax_scale is None else softmax_scale
# rescale with 1/ln(2) for triton exp2
sm_scale *= RCP_LN2
batch, head = b_seq_len.shape[0], q.shape[1] batch, head = b_seq_len.shape[0], q.shape[1]
kv_group_num = q.shape[1] // k.shape[1] kv_group_num = q.shape[1] // k.shape[1]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment