Unverified Commit bc7c4d20 authored by Aleksandr Malyshev's avatar Aleksandr Malyshev Committed by GitHub
Browse files

[Kernel][ROCM] Upstream prefix prefill speed up for vLLM V1 (#13305)


Signed-off-by: default avatarSage Moore <sage@neuralmagic.com>
Signed-off-by: default avatarroot <root@banff-cyxtera-s73-5.ctr.dcgpu>
Signed-off-by: default avatarAleksandr Malyshev <maleksan@amd.com>
Signed-off-by: default avatarroot <root@banff-cyxtera-s65-4.amd.com>
Signed-off-by: default avatarmaleksan85 <maleksan@amd.com>
Signed-off-by: <>
Co-authored-by: default avatarSage Moore <sage@neuralmagic.com>
Co-authored-by: default avatarroot <root@banff-cyxtera-s73-5.ctr.dcgpu>
Co-authored-by: default avatarAleksandr Malyshev <maleksan@amd.com>
Co-authored-by: default avatarqli88 <qiang.li2@amd.com>
Co-authored-by: default avatarroot <root@banff-cyxtera-s65-4.amd.com>
parent f67e9e9f
...@@ -195,15 +195,15 @@ def test_lookahead_greedy_equality_with_preemption(baseline_llm_generator, ...@@ -195,15 +195,15 @@ def test_lookahead_greedy_equality_with_preemption(baseline_llm_generator,
]) ])
@pytest.mark.parametrize("per_test_common_llm_kwargs", @pytest.mark.parametrize("per_test_common_llm_kwargs",
[{ [{
"block_size": 8, "block_size": 16,
"max_num_batched_tokens": 2, "max_num_batched_tokens": 2,
"max_num_seqs": 2, "max_num_seqs": 2,
}, { }, {
"block_size": 8, "block_size": 16,
"max_num_batched_tokens": 3, "max_num_batched_tokens": 3,
"max_num_seqs": 2, "max_num_seqs": 2,
}, { }, {
"block_size": 8, "block_size": 16,
"max_num_batched_tokens": 256, "max_num_batched_tokens": 256,
"max_num_seqs": 10, "max_num_seqs": 10,
}]) }])
......
...@@ -16,831 +16,778 @@ NUM_WARPS = 4 if current_platform.is_rocm() else 8 ...@@ -16,831 +16,778 @@ NUM_WARPS = 4 if current_platform.is_rocm() else 8
# To check compatibility # To check compatibility
IS_TURING = current_platform.get_device_capability() == (7, 5) IS_TURING = current_platform.get_device_capability() == (7, 5)
if triton.__version__ >= "2.1.0":
# Here's an example autotuner config for this kernel. This config does provide
@triton.jit # a performance improvement, but dramatically increases first call latency in
def _fwd_kernel( # triton 3.2. Because of this tradeoff, it's currently commented out.
Q, # @triton.autotune(
K, # configs=[
V, # triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, \
K_cache, # "num_unroll_cache": 4, \
V_cache, # "num_unroll_request": 1 } | \
B_Loc, # ({"kpack": 2, "waves_per_eu": 2} \
sm_scale, # if current_platform.is_rocm() else {}), \
k_scale, # num_warps=4, \
v_scale, # num_stages=1)
B_Start_Loc, # ],
B_Seqlen, # key=["BLOCK_SIZE", "MAX_Q_LEN", "MAX_CTX_LEN"]
block_size, # )
x, @triton.jit
Out, def _fwd_kernel(Q,
stride_b_loc_b, K,
stride_b_loc_s, V,
stride_qbs, K_cache,
stride_qh, V_cache,
stride_qd, B_Loc,
stride_kbs, sm_scale,
stride_kh, k_scale,
stride_kd, v_scale,
stride_vbs, B_Start_Loc,
stride_vh, B_Seqlen,
stride_vd, x: tl.constexpr,
stride_obs, Out,
stride_oh, stride_b_loc_b,
stride_od, stride_b_loc_s,
stride_k_cache_bs, stride_qbs,
stride_k_cache_h, stride_qh,
stride_k_cache_d, stride_qd,
stride_k_cache_bl, stride_kbs,
stride_k_cache_x, stride_kh,
stride_v_cache_bs, stride_kd,
stride_v_cache_h, stride_vbs,
stride_v_cache_d, stride_vh,
stride_v_cache_bl, stride_vd,
num_queries_per_kv: int, stride_obs,
IN_PRECISION: tl.constexpr, stride_oh,
BLOCK_M: tl.constexpr, stride_od,
BLOCK_DMODEL: tl.constexpr, # head size stride_k_cache_bs,
BLOCK_DMODEL_PADDED: tl.constexpr, # head size padded to a power of 2 stride_k_cache_h,
BLOCK_N: tl.constexpr, stride_k_cache_d,
SLIDING_WINDOW: tl.constexpr, stride_k_cache_bl: tl.constexpr,
SKIP_DECODE: tl.constexpr, stride_k_cache_x,
): stride_v_cache_bs,
stride_v_cache_h,
cur_batch = tl.program_id(0) stride_v_cache_d,
cur_head = tl.program_id(1) stride_v_cache_bl,
start_m = tl.program_id(2) num_queries_per_kv: tl.constexpr,
IN_PRECISION: tl.constexpr,
cur_kv_head = cur_head // num_queries_per_kv BLOCK_M: tl.constexpr,
BLOCK_DMODEL: tl.constexpr,
cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) BLOCK_DMODEL_PADDED: tl.constexpr,
cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) BLOCK_SIZE: tl.constexpr,
cur_batch_in_all_stop_index = tl.load(B_Start_Loc + cur_batch + 1) BLOCK_N: tl.constexpr,
cur_batch_query_len = (cur_batch_in_all_stop_index - SLIDING_WINDOW: tl.constexpr,
cur_batch_in_all_start_index) num_unroll_cache: tl.constexpr,
cur_batch_ctx_len = cur_batch_seq_len - cur_batch_query_len num_unroll_request: tl.constexpr,
SKIP_DECODE: tl.constexpr,
if SKIP_DECODE and cur_batch_query_len == 1: MAX_Q_LEN: tl.constexpr = 0,
return MAX_CTX_LEN: tl.constexpr = 0):
# start position inside of the query cur_batch = tl.program_id(0)
# generally, N goes over kv, while M goes over query_len cur_head = tl.program_id(1)
block_start_loc = BLOCK_M * start_m start_m = tl.program_id(2)
# initialize offsets cur_kv_head = cur_head // num_queries_per_kv
# [N]; starts at 0
offs_n = tl.arange(0, BLOCK_N) cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
# [D]; starts at 0 cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)
offs_d = tl.arange(0, BLOCK_DMODEL_PADDED) cur_batch_in_all_stop_index = tl.load(B_Start_Loc + cur_batch + 1)
# [M]; starts at current position in query cur_batch_query_len = (cur_batch_in_all_stop_index -
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) cur_batch_in_all_start_index)
# [M,D] cur_batch_ctx_len = cur_batch_seq_len - cur_batch_query_len
off_q = (
(cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs + if SKIP_DECODE and cur_batch_query_len == 1:
cur_head * stride_qh + offs_d[None, :] * stride_qd)
dim_mask = tl.where(
tl.arange(0, BLOCK_DMODEL_PADDED) < BLOCK_DMODEL, 1,
0).to(tl.int1) # [D]
q = tl.load(Q + off_q,
mask=dim_mask[None, :] &
(offs_m[:, None] < cur_batch_query_len),
other=0.0) # [M,D]
# initialize pointer to m and l
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") # [M]
l_i = tl.zeros([BLOCK_M], dtype=tl.float32) # [M]
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL_PADDED],
dtype=tl.float32) # [M,D]
# compute query against context (no causal mask here)
for start_n in range(0, cur_batch_ctx_len, BLOCK_N):
start_n = tl.multiple_of(start_n, BLOCK_N)
# -- compute qk ----
bn = tl.load(B_Loc + cur_batch * stride_b_loc_b +
((start_n + offs_n) // block_size) * stride_b_loc_s,
mask=(start_n + offs_n) < cur_batch_ctx_len,
other=0) # [N]
# [D,N]
off_k = (bn[None, :] * stride_k_cache_bs +
cur_kv_head * stride_k_cache_h +
(offs_d[:, None] // x) * stride_k_cache_d +
((start_n + offs_n[None, :]) % block_size) *
stride_k_cache_bl +
(offs_d[:, None] % x) * stride_k_cache_x)
# [N,D]
off_v = (
bn[:, None] * stride_v_cache_bs +
cur_kv_head * stride_v_cache_h +
offs_d[None, :] * stride_v_cache_d +
(start_n + offs_n[:, None]) % block_size * stride_v_cache_bl)
k_load = tl.load(K_cache + off_k,
mask=dim_mask[:, None] &
((start_n + offs_n[None, :]) < cur_batch_ctx_len),
other=0.0) # [D,N]
if k_load.dtype.is_fp8():
k = (k_load.to(tl.float32) * tl.load(k_scale)).to(q.dtype)
else:
k = k_load
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) # [M,N]
qk = tl.dot(q, k, acc=qk, input_precision=IN_PRECISION)
qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk,
float("-inf"))
qk *= sm_scale
if SLIDING_WINDOW > 0:
# (cur_batch_ctx_len + offs_m[:, None]) are the positions of
# Q entries in sequence
# (start_n + offs_n[None, :]) are the positions of
# KV entries in sequence
# So the condition makes sure each entry in Q only attends
# to KV entries not more than SLIDING_WINDOW away.
#
# We can't use -inf here, because the
# sliding window may lead to the entire row being masked.
# This then makes m_ij contain -inf, which causes NaNs in
# exp().
qk = tl.where((cur_batch_ctx_len + offs_m[:, None]) -
(start_n + offs_n[None, :]) < SLIDING_WINDOW, qk,
-10000)
# -- compute m_ij, p, l_ij
m_ij = tl.max(qk, 1) # [M]
p = tl.exp(qk - m_ij[:, None]) # [M,N]
l_ij = tl.sum(p, 1) # [M]
# -- update m_i and l_i
m_i_new = tl.maximum(m_i, m_ij) # [M]
alpha = tl.exp(m_i - m_i_new) # [M]
beta = tl.exp(m_ij - m_i_new) # [M]
l_i_new = alpha * l_i + beta * l_ij # [M]
# -- update output accumulator --
# scale p
p_scale = beta / l_i_new
p = p * p_scale[:, None]
# scale acc
acc_scale = l_i / l_i_new * alpha
acc = acc * acc_scale[:, None]
# update acc
v_load = tl.load(V_cache + off_v,
mask=dim_mask[None, :] &
((start_n + offs_n[:, None]) < cur_batch_ctx_len),
other=0.0) # [N,D]
if v_load.dtype.is_fp8():
v = (v_load.to(tl.float32) * tl.load(v_scale)).to(q.dtype)
else:
v = v_load
p = p.to(v.dtype)
acc = tl.dot(p, v, acc=acc, input_precision=IN_PRECISION)
# # update m_i and l_i
l_i = l_i_new
m_i = m_i_new
off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh +
offs_d[:, None] * stride_kd)
off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh +
offs_d[None, :] * stride_vd)
k_ptrs = K + off_k
v_ptrs = V + off_v
# block_mask is 0 when we're already past the current query length
block_mask = tl.where(block_start_loc < cur_batch_query_len, 1, 0)
# compute query against itself (with causal mask)
for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N):
start_n = tl.multiple_of(start_n, BLOCK_N)
# -- compute qk ----
k = tl.load(k_ptrs +
(cur_batch_in_all_start_index + start_n) * stride_kbs,
mask=dim_mask[:, None] &
((start_n + offs_n[None, :]) < cur_batch_query_len),
other=0.0)
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk = tl.dot(q, k, acc=qk, input_precision=IN_PRECISION)
qk *= sm_scale
# apply causal mask
qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk,
float("-inf"))
if SLIDING_WINDOW > 0:
qk = tl.where(
offs_m[:, None] - (start_n + offs_n[None, :])
< SLIDING_WINDOW, qk, -10000)
# -- compute m_ij, p, l_ij
m_ij = tl.max(qk, 1)
p = tl.exp(qk - m_ij[:, None])
l_ij = tl.sum(p, 1)
# -- update m_i and l_i
m_i_new = tl.maximum(m_i, m_ij)
alpha = tl.exp(m_i - m_i_new)
beta = tl.exp(m_ij - m_i_new)
l_i_new = alpha * l_i + beta * l_ij
# -- update output accumulator --
# scale p
p_scale = beta / l_i_new
p = p * p_scale[:, None]
# scale acc
acc_scale = l_i / l_i_new * alpha
acc = acc * acc_scale[:, None]
# update acc
v = tl.load(v_ptrs +
(cur_batch_in_all_start_index + start_n) * stride_vbs,
mask=dim_mask[None, :] &
((start_n + offs_n[:, None]) < cur_batch_query_len),
other=0.0)
p = p.to(v.dtype)
acc = tl.dot(p, v, acc=acc, input_precision=IN_PRECISION)
# update m_i and l_i
l_i = l_i_new
m_i = m_i_new
# initialize pointers to output
off_o = (
(cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs +
cur_head * stride_oh + offs_d[None, :] * stride_od)
out_ptrs = Out + off_o
tl.store(out_ptrs,
acc,
mask=dim_mask[None, :] &
(offs_m[:, None] < cur_batch_query_len))
return return
@triton.jit # start position inside of the query
def _fwd_kernel_flash_attn_v2( # generally, N goes over kv, while M goes over query_len
Q, block_start_loc = BLOCK_M * start_m
K,
V, # initialize offsets
K_cache, # [BLOCK_SIZE]; starts at 0
V_cache, offs_bs_n = tl.arange(0, BLOCK_SIZE)
B_Loc, # [N]; starts at 0
sm_scale, offs_n = tl.arange(0, BLOCK_N)
B_Start_Loc, # [D]; starts at 0
B_Seqlen, offs_d = tl.arange(0, BLOCK_DMODEL_PADDED)
B_Ctxlen, # [M]; starts at current position in query
block_size, offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
x, # [M,D]
Out, off_q = ((cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs +
stride_b_loc_b, cur_head * stride_qh + offs_d[None, :] * stride_qd)
stride_b_loc_s,
stride_qbs, dim_mask = tl.where(
stride_qh, tl.arange(0, BLOCK_DMODEL_PADDED) < BLOCK_DMODEL, 1,
stride_qd, 0).to(tl.int1) # [D]
stride_kbs,
stride_kh, q = tl.load(Q + off_q,
stride_kd, mask=dim_mask[None, :] &
stride_vbs, (offs_m[:, None] < cur_batch_query_len),
stride_vh, other=0.0) # [M,D]
stride_vd,
stride_obs, # initialize pointer to m and l
stride_oh, m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
stride_od, l_i = tl.full([BLOCK_M], 1.0, dtype=tl.float32)
stride_k_cache_bs, acc = tl.zeros([BLOCK_M, BLOCK_DMODEL_PADDED], dtype=tl.float32) # [M,D]
stride_k_cache_h,
stride_k_cache_d, # compute query against context (no causal mask here)
stride_k_cache_bl, for start_n in tl.range(0, cur_batch_ctx_len, BLOCK_SIZE, \
stride_k_cache_x, loop_unroll_factor=num_unroll_cache):
stride_v_cache_bs, start_n = tl.multiple_of(start_n, BLOCK_SIZE)
stride_v_cache_h, # -- compute qk ----
stride_v_cache_d, bn = tl.load(B_Loc + cur_batch * stride_b_loc_b +
stride_v_cache_bl, (start_n // BLOCK_SIZE) * stride_b_loc_s)
num_queries_per_kv: int, # [D,BLOCK_SIZE]
BLOCK_M: tl.constexpr, off_k = (
BLOCK_DMODEL: tl.constexpr, bn[None, :] * stride_k_cache_bs + cur_kv_head * stride_k_cache_h +
BLOCK_N: tl.constexpr, (offs_d[:, None] // x) * stride_k_cache_d +
): ((start_n + offs_bs_n[None, :]) % BLOCK_SIZE) * stride_k_cache_bl +
cur_batch = tl.program_id(0) (offs_d[:, None] % x) * stride_k_cache_x)
cur_head = tl.program_id(1)
start_m = tl.program_id(2) # [BLOCK_SIZE,D]
off_v = (bn[:, None] * stride_v_cache_bs +
cur_kv_head = cur_head // num_queries_per_kv cur_kv_head * stride_v_cache_h +
offs_d[None, :] * stride_v_cache_d +
cur_batch_ctx_len = tl.load(B_Ctxlen + cur_batch) offs_bs_n[:, None] * stride_v_cache_bl)
cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) if start_n + BLOCK_SIZE > cur_batch_ctx_len or \
BLOCK_DMODEL != BLOCK_DMODEL_PADDED:
block_start_loc = BLOCK_M * start_m k_load = tl.load(
K_cache + off_k,
# initialize offsets mask=dim_mask[:, None] &
offs_n = tl.arange(0, BLOCK_N) ((start_n + offs_bs_n[None, :]) < cur_batch_ctx_len),
offs_d = tl.arange(0, BLOCK_DMODEL) other=0.0) # [D,N]
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) else:
off_q = ( k_load = tl.load(K_cache + off_k)
(cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs +
cur_head * stride_qh + offs_d[None, :] * stride_qd) if k_load.dtype.is_fp8():
k = (k_load.to(tl.float32) * tl.load(k_scale)).to(q.dtype)
q = tl.load(Q + off_q, else:
mask=offs_m[:, None] k = k_load
< cur_batch_seq_len - cur_batch_ctx_len,
qk = tl.zeros([BLOCK_M, BLOCK_SIZE], dtype=tl.float32) # [M,N]
qk = tl.dot(q, k, acc=qk, input_precision=IN_PRECISION)
qk = tl.where((start_n + offs_bs_n[None, :]) < cur_batch_ctx_len, qk,
float("-inf"))
qk *= sm_scale
if SLIDING_WINDOW > 0:
# (cur_batch_ctx_len + offs_m[:, None]) are the positions of
# Q entries in sequence
# (start_n + offs_bs_n[None, :]) are the positions of
# KV entries in sequence
# So the condition makes sure each entry in Q only attends
# to KV entries not more than SLIDING_WINDOW away.
#
# We can't use -inf here, because the
# sliding window may lead to the entire row being masked.
# This then makes m_ij contain -inf, which causes NaNs in
# exp().
qk = tl.where((cur_batch_ctx_len + offs_m[:, None]) -
(start_n + offs_bs_n[None, :]) < SLIDING_WINDOW, qk,
-10000)
# compute running maximum
m_ij = tl.maximum(m_i, tl.max(qk, axis=1))
p = tl.exp(qk - m_ij[:, None])
l_ij = tl.sum(p, axis=1)
alpha = tl.exp(m_i - m_ij)
acc = acc * alpha[:, None]
# update acc
if start_n + BLOCK_SIZE > cur_batch_ctx_len or \
BLOCK_DMODEL != BLOCK_DMODEL_PADDED:
v_load = tl.load(
V_cache + off_v,
mask=dim_mask[None, :] &
((start_n + offs_bs_n[:, None]) < cur_batch_ctx_len),
other=0.0) # [N,D]
else:
v_load = tl.load(V_cache + off_v)
if v_load.dtype.is_fp8():
v = (v_load.to(tl.float32) * tl.load(v_scale)).to(q.dtype)
else:
v = v_load
p = p.to(v.dtype)
acc = tl.dot(p, v, acc=acc, input_precision=IN_PRECISION)
# # update m_i and l_i
l_i = l_i * alpha + l_ij
m_i = m_ij
off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh +
offs_d[:, None] * stride_kd)
off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh +
offs_d[None, :] * stride_vd)
k_ptrs = K + off_k
v_ptrs = V + off_v
# block_mask is 0 when we're already past the current query length
block_mask = tl.where(block_start_loc < cur_batch_query_len, 1, 0)
# compute query against itself (with causal mask)
for start_n in tl.range(0, \
block_mask * (start_m + 1) * BLOCK_M, BLOCK_N, \
loop_unroll_factor=num_unroll_request):
start_n = tl.multiple_of(start_n, BLOCK_N)
# -- compute qk ----
k = tl.load(k_ptrs +
(cur_batch_in_all_start_index + start_n) * stride_kbs,
mask=dim_mask[:, None] &
((start_n + offs_n[None, :]) < cur_batch_query_len),
other=0.0) other=0.0)
# # initialize pointer to m and l qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") qk = tl.dot(q, k, acc=qk, input_precision=IN_PRECISION)
l_i = tl.zeros([BLOCK_M], dtype=tl.float32) qk *= sm_scale
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) # apply causal mask
qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk,
for start_n in range(0, cur_batch_ctx_len, BLOCK_N): float("-inf"))
start_n = tl.multiple_of(start_n, BLOCK_N) if SLIDING_WINDOW > 0:
# -- compute qk ---- qk = tl.where(
bn = tl.load(B_Loc + cur_batch * stride_b_loc_b + offs_m[:, None] - (start_n + offs_n[None, :]) < SLIDING_WINDOW,
((start_n + offs_n) // block_size) * stride_b_loc_s, qk, -10000)
mask=(start_n + offs_n) < cur_batch_ctx_len,
other=0) # compute running maximum
off_k = (bn[None, :] * stride_k_cache_bs + m_ij = tl.maximum(m_i, tl.max(qk, axis=1))
cur_kv_head * stride_k_cache_h + p = tl.exp(qk - m_ij[:, None])
(offs_d[:, None] // x) * stride_k_cache_d + l_ij = tl.sum(p, axis=1)
((start_n + offs_n[None, :]) % block_size) * alpha = tl.exp(m_i - m_ij)
stride_k_cache_bl + acc = acc * alpha[:, None]
(offs_d[:, None] % x) * stride_k_cache_x)
off_v = ( # update acc
bn[:, None] * stride_v_cache_bs + v = tl.load(v_ptrs +
cur_kv_head * stride_v_cache_h + (cur_batch_in_all_start_index + start_n) * stride_vbs,
offs_d[None, :] * stride_v_cache_d +
(start_n + offs_n[:, None]) % block_size * stride_v_cache_bl)
k = tl.load(K_cache + off_k,
mask=(start_n + offs_n[None, :]) < cur_batch_ctx_len,
other=0.0)
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk += tl.dot(q, k)
qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk,
float("-inf"))
qk *= sm_scale
# -- compute m_ij, p, l_ij
m_ij = tl.max(qk, 1)
m_i_new = tl.maximum(m_i, m_ij)
p = tl.math.exp(qk - m_i_new[:, None])
l_ij = tl.sum(p, 1)
# -- update m_i and l_i
alpha = tl.math.exp(m_i - m_i_new)
l_i_new = alpha * l_i + l_ij
# -- update output accumulator --
# scale p
# scale acc
acc_scale = alpha
# acc_scale = l_i / l_i_new * alpha
acc = acc * acc_scale[:, None]
# update acc
v = tl.load(V_cache + off_v,
mask=(start_n + offs_n[:, None]) < cur_batch_ctx_len,
other=0.0)
p = p.to(v.dtype)
acc += tl.dot(p, v)
# update m_i and l_i
l_i = l_i_new
m_i = m_i_new
off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh +
offs_d[:, None] * stride_kd)
off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh +
offs_d[None, :] * stride_vd)
k_ptrs = K + off_k
v_ptrs = V + off_v
block_mask = tl.where(
block_start_loc < cur_batch_seq_len - cur_batch_ctx_len, 1, 0)
for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N):
start_n = tl.multiple_of(start_n, BLOCK_N)
# -- compute qk ----
k = tl.load(k_ptrs +
(cur_batch_in_all_start_index + start_n) * stride_kbs,
mask=(start_n + offs_n[None, :])
< cur_batch_seq_len - cur_batch_ctx_len,
other=0.0)
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk += tl.dot(q, k)
qk *= sm_scale
qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk,
float("-inf"))
# -- compute m_ij, p, l_ij
m_ij = tl.max(qk, 1)
m_i_new = tl.maximum(m_i, m_ij)
p = tl.math.exp(qk - m_i_new[:, None])
l_ij = tl.sum(p, 1)
# -- update m_i and l_i
alpha = tl.math.exp(m_i - m_i_new)
l_i_new = alpha * l_i + l_ij
# -- update output accumulator --
# scale p
# scale acc
acc_scale = alpha
# acc_scale = l_i / l_i_new * alpha
acc = acc * acc_scale[:, None]
# update acc
v = tl.load(v_ptrs +
(cur_batch_in_all_start_index + start_n) * stride_vbs,
mask=(start_n + offs_n[:, None])
< cur_batch_seq_len - cur_batch_ctx_len,
other=0.0)
p = p.to(v.dtype)
acc += tl.dot(p, v)
# update m_i and l_i
l_i = l_i_new
m_i = m_i_new
# acc /= l_i[:, None]
# initialize pointers to output
off_o = (
(cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs +
cur_head * stride_oh + offs_d[None, :] * stride_od)
out_ptrs = Out + off_o
tl.store(out_ptrs,
acc,
mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len)
return
@triton.jit
def _fwd_kernel_alibi(
Q,
K,
V,
K_cache,
V_cache,
B_Loc,
sm_scale,
k_scale,
v_scale,
B_Start_Loc,
B_Seqlen,
Alibi_slopes,
block_size,
x,
Out,
stride_b_loc_b,
stride_b_loc_s,
stride_qbs,
stride_qh,
stride_qd,
stride_kbs,
stride_kh,
stride_kd,
stride_vbs,
stride_vh,
stride_vd,
stride_obs,
stride_oh,
stride_od,
stride_k_cache_bs,
stride_k_cache_h,
stride_k_cache_d,
stride_k_cache_bl,
stride_k_cache_x,
stride_v_cache_bs,
stride_v_cache_h,
stride_v_cache_d,
stride_v_cache_bl,
num_queries_per_kv: int,
IN_PRECISION: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_DMODEL: tl.constexpr, # head size
BLOCK_DMODEL_PADDED: tl.constexpr, # head size padded to a power of 2
BLOCK_N: tl.constexpr,
SKIP_DECODE: tl.constexpr,
):
# attn_bias[]
cur_batch = tl.program_id(0)
cur_head = tl.program_id(1)
start_m = tl.program_id(2)
cur_kv_head = cur_head // num_queries_per_kv
# cur_batch_seq_len: the length of prompts
# cur_batch_ctx_len: the length of prefix
# cur_batch_in_all_start_index: the start id of the dim=0
cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)
cur_batch_in_all_stop_index = tl.load(B_Start_Loc + cur_batch + 1)
cur_batch_query_len = (cur_batch_in_all_stop_index -
cur_batch_in_all_start_index)
cur_batch_ctx_len = cur_batch_seq_len - cur_batch_query_len
if SKIP_DECODE and cur_batch_query_len == 1:
return
block_start_loc = BLOCK_M * start_m
# initialize offsets
offs_n = tl.arange(0, BLOCK_N)
offs_d = tl.arange(0, BLOCK_DMODEL_PADDED)
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
off_q = (
(cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs +
cur_head * stride_qh + offs_d[None, :] * stride_qd)
dim_mask = tl.where(
tl.arange(0, BLOCK_DMODEL_PADDED) < BLOCK_DMODEL, 1, 0).to(tl.int1)
q = tl.load(Q + off_q,
mask=dim_mask[None, :] & mask=dim_mask[None, :] &
(offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len), ((start_n + offs_n[:, None]) < cur_batch_query_len),
other=0.0)
p = p.to(v.dtype)
acc = tl.dot(p, v, acc=acc, input_precision=IN_PRECISION)
# update m_i and l_i
l_i = l_i * alpha + l_ij
m_i = m_ij
acc = acc / l_i[:, None]
# initialize pointers to output
off_o = ((cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs +
cur_head * stride_oh + offs_d[None, :] * stride_od)
out_ptrs = Out + off_o
tl.store(out_ptrs,
acc,
mask=dim_mask[None, :] & (offs_m[:, None] < cur_batch_query_len))
return
@triton.jit
def _fwd_kernel_flash_attn_v2(
Q,
K,
V,
K_cache,
V_cache,
B_Loc,
sm_scale,
B_Start_Loc,
B_Seqlen,
B_Ctxlen,
block_size,
x,
Out,
stride_b_loc_b,
stride_b_loc_s,
stride_qbs,
stride_qh,
stride_qd,
stride_kbs,
stride_kh,
stride_kd,
stride_vbs,
stride_vh,
stride_vd,
stride_obs,
stride_oh,
stride_od,
stride_k_cache_bs,
stride_k_cache_h,
stride_k_cache_d,
stride_k_cache_bl,
stride_k_cache_x,
stride_v_cache_bs,
stride_v_cache_h,
stride_v_cache_d,
stride_v_cache_bl,
num_queries_per_kv: int,
BLOCK_M: tl.constexpr,
BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr,
):
cur_batch = tl.program_id(0)
cur_head = tl.program_id(1)
start_m = tl.program_id(2)
cur_kv_head = cur_head // num_queries_per_kv
cur_batch_ctx_len = tl.load(B_Ctxlen + cur_batch)
cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)
block_start_loc = BLOCK_M * start_m
# initialize offsets
offs_n = tl.arange(0, BLOCK_N)
offs_d = tl.arange(0, BLOCK_DMODEL)
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
off_q = ((cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs +
cur_head * stride_qh + offs_d[None, :] * stride_qd)
q = tl.load(Q + off_q,
mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len,
other=0.0)
# # initialize pointer to m and l
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
for start_n in range(0, cur_batch_ctx_len, BLOCK_N):
start_n = tl.multiple_of(start_n, BLOCK_N)
# -- compute qk ----
bn = tl.load(B_Loc + cur_batch * stride_b_loc_b +
((start_n + offs_n) // block_size) * stride_b_loc_s,
mask=(start_n + offs_n) < cur_batch_ctx_len,
other=0)
off_k = (
bn[None, :] * stride_k_cache_bs + cur_kv_head * stride_k_cache_h +
(offs_d[:, None] // x) * stride_k_cache_d +
((start_n + offs_n[None, :]) % block_size) * stride_k_cache_bl +
(offs_d[:, None] % x) * stride_k_cache_x)
off_v = (bn[:, None] * stride_v_cache_bs +
cur_kv_head * stride_v_cache_h +
offs_d[None, :] * stride_v_cache_d +
(start_n + offs_n[:, None]) % block_size * stride_v_cache_bl)
k = tl.load(K_cache + off_k,
mask=(start_n + offs_n[None, :]) < cur_batch_ctx_len,
other=0.0)
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk += tl.dot(q, k)
qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk,
float("-inf"))
qk *= sm_scale
# -- compute m_ij, p, l_ij
m_ij = tl.max(qk, 1)
m_i_new = tl.maximum(m_i, m_ij)
p = tl.math.exp(qk - m_i_new[:, None])
l_ij = tl.sum(p, 1)
# -- update m_i and l_i
alpha = tl.math.exp(m_i - m_i_new)
l_i_new = alpha * l_i + l_ij
# -- update output accumulator --
# scale p
# scale acc
acc_scale = alpha
# acc_scale = l_i / l_i_new * alpha
acc = acc * acc_scale[:, None]
# update acc
v = tl.load(V_cache + off_v,
mask=(start_n + offs_n[:, None]) < cur_batch_ctx_len,
other=0.0)
p = p.to(v.dtype)
acc += tl.dot(p, v)
# update m_i and l_i
l_i = l_i_new
m_i = m_i_new
off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh +
offs_d[:, None] * stride_kd)
off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh +
offs_d[None, :] * stride_vd)
k_ptrs = K + off_k
v_ptrs = V + off_v
block_mask = tl.where(
block_start_loc < cur_batch_seq_len - cur_batch_ctx_len, 1, 0)
for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N):
start_n = tl.multiple_of(start_n, BLOCK_N)
# -- compute qk ----
k = tl.load(k_ptrs +
(cur_batch_in_all_start_index + start_n) * stride_kbs,
mask=(start_n + offs_n[None, :])
< cur_batch_seq_len - cur_batch_ctx_len,
other=0.0) other=0.0)
# # initialize pointer to m and l qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") qk += tl.dot(q, k)
l_i = tl.zeros([BLOCK_M], dtype=tl.float32) qk *= sm_scale
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL_PADDED], dtype=tl.float32) qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk,
float("-inf"))
alibi_slope = tl.load(Alibi_slopes + cur_head)
alibi_start_q = tl.arange( # -- compute m_ij, p, l_ij
0, BLOCK_M) + block_start_loc + cur_batch_ctx_len m_ij = tl.max(qk, 1)
alibi_start_k = 0 m_i_new = tl.maximum(m_i, m_ij)
for start_n in range(0, cur_batch_ctx_len, BLOCK_N): p = tl.math.exp(qk - m_i_new[:, None])
start_n = tl.multiple_of(start_n, BLOCK_N) l_ij = tl.sum(p, 1)
# -- compute qk ---- # -- update m_i and l_i
bn = tl.load(B_Loc + cur_batch * stride_b_loc_b +
((start_n + offs_n) // block_size) * stride_b_loc_s, alpha = tl.math.exp(m_i - m_i_new)
mask=(start_n + offs_n) < cur_batch_ctx_len, l_i_new = alpha * l_i + l_ij
other=0) # -- update output accumulator --
off_k = (bn[None, :] * stride_k_cache_bs + # scale p
cur_kv_head * stride_k_cache_h + # scale acc
(offs_d[:, None] // x) * stride_k_cache_d + acc_scale = alpha
((start_n + offs_n[None, :]) % block_size) * # acc_scale = l_i / l_i_new * alpha
stride_k_cache_bl + acc = acc * acc_scale[:, None]
(offs_d[:, None] % x) * stride_k_cache_x) # update acc
off_v = ( v = tl.load(v_ptrs +
bn[:, None] * stride_v_cache_bs + (cur_batch_in_all_start_index + start_n) * stride_vbs,
cur_kv_head * stride_v_cache_h + mask=(start_n + offs_n[:, None])
offs_d[None, :] * stride_v_cache_d + < cur_batch_seq_len - cur_batch_ctx_len,
(start_n + offs_n[:, None]) % block_size * stride_v_cache_bl) other=0.0)
k_load = tl.load(K_cache + off_k,
mask=dim_mask[:, None] & p = p.to(v.dtype)
((start_n + offs_n[None, :]) < cur_batch_ctx_len), acc += tl.dot(p, v)
other=0.0) # [D,N] # update m_i and l_i
l_i = l_i_new
if k_load.dtype.is_fp8(): m_i = m_i_new
k = (k_load.to(tl.float32) * tl.load(k_scale)).to(q.dtype)
else: # acc /= l_i[:, None]
k = k_load # initialize pointers to output
off_o = ((cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs +
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) cur_head * stride_oh + offs_d[None, :] * stride_od)
qk = tl.dot(q, k, acc=qk, input_precision=IN_PRECISION) out_ptrs = Out + off_o
qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk, tl.store(out_ptrs,
float("-inf")) acc,
qk *= sm_scale mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len)
return
# load alibi
alibi = (tl.arange(0, BLOCK_N)[None, :] + alibi_start_k -
alibi_start_q[:, None]) * alibi_slope @triton.jit
alibi = tl.where( def _fwd_kernel_alibi(
(alibi <= 0) & (alibi_start_q[:, None] < cur_batch_seq_len), Q,
alibi, float("-inf")) K,
qk += alibi V,
alibi_start_k += BLOCK_N K_cache,
V_cache,
# -- compute m_ij, p, l_ij B_Loc,
m_ij = tl.max(qk, 1) sm_scale,
m_i_new = tl.maximum(m_i, m_ij) k_scale,
p = tl.math.exp(qk - m_i_new[:, None]) v_scale,
l_ij = tl.sum(p, 1) B_Start_Loc,
# -- update m_i and l_i B_Seqlen,
Alibi_slopes,
alpha = tl.math.exp(m_i - m_i_new) block_size,
l_i_new = alpha * l_i + l_ij x,
# -- update output accumulator -- Out,
# scale p stride_b_loc_b,
# scale acc stride_b_loc_s,
acc_scale = alpha stride_qbs,
# acc_scale = l_i / l_i_new * alpha stride_qh,
acc = acc * acc_scale[:, None] stride_qd,
# update acc stride_kbs,
v_load = tl.load(V_cache + off_v, stride_kh,
mask=dim_mask[None, :] & stride_kd,
((start_n + offs_n[:, None]) < cur_batch_ctx_len), stride_vbs,
other=0.0) stride_vh,
if v_load.dtype.is_fp8(): stride_vd,
v = (v_load.to(tl.float32) * tl.load(v_scale)).to(q.dtype) stride_obs,
else: stride_oh,
v = v_load stride_od,
p = p.to(v.dtype) stride_k_cache_bs,
stride_k_cache_h,
acc = tl.dot(p, v, acc=acc, input_precision='ieee') stride_k_cache_d,
# update m_i and l_i stride_k_cache_bl,
l_i = l_i_new stride_k_cache_x,
m_i = m_i_new stride_v_cache_bs,
stride_v_cache_h,
off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + stride_v_cache_d,
offs_d[:, None] * stride_kd) stride_v_cache_bl,
off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh + num_queries_per_kv: int,
offs_d[None, :] * stride_vd) IN_PRECISION: tl.constexpr,
k_ptrs = K + off_k BLOCK_M: tl.constexpr,
v_ptrs = V + off_v BLOCK_DMODEL: tl.constexpr, # head size
BLOCK_DMODEL_PADDED: tl.constexpr, # head size padded to a power of 2
block_mask = tl.where( BLOCK_N: tl.constexpr,
block_start_loc < cur_batch_seq_len - cur_batch_ctx_len, 1, 0) SKIP_DECODE: tl.constexpr,
):
# init alibi # attn_bias[]
alibi_slope = tl.load(Alibi_slopes + cur_head) cur_batch = tl.program_id(0)
alibi_start_q = tl.arange( cur_head = tl.program_id(1)
0, BLOCK_M) + block_start_loc + cur_batch_ctx_len start_m = tl.program_id(2)
alibi_start_k = cur_batch_ctx_len
# # init debugger cur_kv_head = cur_head // num_queries_per_kv
# offset_db_q = tl.arange(0, BLOCK_M) + block_start_loc
# offset_db_k = tl.arange(0, BLOCK_N) # cur_batch_seq_len: the length of prompts
# calc q[BLOCK_M, BLOCK_MODEL] mul k[prefix_len: , BLOCK_DMODEL] # cur_batch_ctx_len: the length of prefix
for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N): # cur_batch_in_all_start_index: the start id of the dim=0
start_n = tl.multiple_of(start_n, BLOCK_N) cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
# -- compute qk ---- cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)
k = tl.load(k_ptrs + cur_batch_in_all_stop_index = tl.load(B_Start_Loc + cur_batch + 1)
(cur_batch_in_all_start_index + start_n) * stride_kbs, cur_batch_query_len = (cur_batch_in_all_stop_index -
mask=dim_mask[:, None] & cur_batch_in_all_start_index)
((start_n + offs_n[None, :]) cur_batch_ctx_len = cur_batch_seq_len - cur_batch_query_len
< cur_batch_seq_len - cur_batch_ctx_len),
other=0.0) if SKIP_DECODE and cur_batch_query_len == 1:
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk = tl.dot(q, k, acc=qk, input_precision='ieee')
qk *= sm_scale
qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk,
float("-inf"))
# load alibi
alibi = (tl.arange(0, BLOCK_N)[None, :] + alibi_start_k -
alibi_start_q[:, None]) * alibi_slope
alibi = tl.where(
(alibi <= 0) & (alibi_start_q[:, None] < cur_batch_seq_len),
alibi, float("-inf"))
qk += alibi
alibi_start_k += BLOCK_N
# -- compute m_ij, p, l_ij
m_ij = tl.max(qk, 1)
m_i_new = tl.maximum(m_i, m_ij)
p = tl.math.exp(qk - m_i_new[:, None])
l_ij = tl.sum(p, 1)
# -- update m_i and l_i
alpha = tl.math.exp(m_i - m_i_new)
l_i_new = alpha * l_i + l_ij
# -- update output accumulator --
# scale p
# scale acc
acc_scale = alpha
# acc_scale = l_i / l_i_new * alpha
acc = acc * acc_scale[:, None]
# update acc
v = tl.load(v_ptrs +
(cur_batch_in_all_start_index + start_n) * stride_vbs,
mask=dim_mask[None, :] &
((start_n + offs_n[:, None])
< cur_batch_seq_len - cur_batch_ctx_len),
other=0.0)
p = p.to(v.dtype)
acc = tl.dot(p, v, acc=acc, input_precision='ieee')
# update m_i and l_i
l_i = l_i_new
m_i = m_i_new
acc = acc / l_i[:, None]
# initialize pointers to output
off_o = (
(cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs +
cur_head * stride_oh + offs_d[None, :] * stride_od)
out_ptrs = Out + off_o
tl.store(out_ptrs,
acc,
mask=dim_mask[None, :] &
(offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len))
return return
@torch.inference_mode() block_start_loc = BLOCK_M * start_m
def context_attention_fwd(q,
k, # initialize offsets
v, offs_n = tl.arange(0, BLOCK_N)
o, offs_d = tl.arange(0, BLOCK_DMODEL_PADDED)
kv_cache_dtype: str, offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
k_cache, off_q = ((cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs +
v_cache, cur_head * stride_qh + offs_d[None, :] * stride_qd)
b_loc,
b_start_loc, dim_mask = tl.where(
b_seq_len, tl.arange(0, BLOCK_DMODEL_PADDED) < BLOCK_DMODEL, 1, 0).to(tl.int1)
max_seq_len,
max_input_len, q = tl.load(Q + off_q,
k_scale: torch.Tensor, mask=dim_mask[None, :] &
v_scale: torch.Tensor, (offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len),
alibi_slopes=None, other=0.0)
sliding_window=None,
sm_scale=None, # # initialize pointer to m and l
skip_decode=False): m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
q_dtype_is_f32 = q.dtype is torch.float32 acc = tl.zeros([BLOCK_M, BLOCK_DMODEL_PADDED], dtype=tl.float32)
alibi_slope = tl.load(Alibi_slopes + cur_head)
alibi_start_q = tl.arange(0, BLOCK_M) + block_start_loc + cur_batch_ctx_len
alibi_start_k = 0
for start_n in range(0, cur_batch_ctx_len, BLOCK_N):
start_n = tl.multiple_of(start_n, BLOCK_N)
# -- compute qk ----
bn = tl.load(B_Loc + cur_batch * stride_b_loc_b +
((start_n + offs_n) // block_size) * stride_b_loc_s,
mask=(start_n + offs_n) < cur_batch_ctx_len,
other=0)
off_k = (
bn[None, :] * stride_k_cache_bs + cur_kv_head * stride_k_cache_h +
(offs_d[:, None] // x) * stride_k_cache_d +
((start_n + offs_n[None, :]) % block_size) * stride_k_cache_bl +
(offs_d[:, None] % x) * stride_k_cache_x)
off_v = (bn[:, None] * stride_v_cache_bs +
cur_kv_head * stride_v_cache_h +
offs_d[None, :] * stride_v_cache_d +
(start_n + offs_n[:, None]) % block_size * stride_v_cache_bl)
k_load = tl.load(K_cache + off_k,
mask=dim_mask[:, None] &
((start_n + offs_n[None, :]) < cur_batch_ctx_len),
other=0.0) # [D,N]
if k_load.dtype.is_fp8():
k = (k_load.to(tl.float32) * tl.load(k_scale)).to(q.dtype)
else:
k = k_load
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk = tl.dot(q, k, acc=qk, input_precision=IN_PRECISION)
qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk,
float("-inf"))
qk *= sm_scale
# load alibi
alibi = (tl.arange(0, BLOCK_N)[None, :] + alibi_start_k -
alibi_start_q[:, None]) * alibi_slope
alibi = tl.where(
(alibi <= 0) & (alibi_start_q[:, None] < cur_batch_seq_len), alibi,
float("-inf"))
qk += alibi
alibi_start_k += BLOCK_N
# -- compute m_ij, p, l_ij
m_ij = tl.max(qk, 1)
m_i_new = tl.maximum(m_i, m_ij)
p = tl.math.exp(qk - m_i_new[:, None])
l_ij = tl.sum(p, 1)
# -- update m_i and l_i
alpha = tl.math.exp(m_i - m_i_new)
l_i_new = alpha * l_i + l_ij
# -- update output accumulator --
# scale p
# scale acc
acc_scale = alpha
# acc_scale = l_i / l_i_new * alpha
acc = acc * acc_scale[:, None]
# update acc
v_load = tl.load(V_cache + off_v,
mask=dim_mask[None, :] &
((start_n + offs_n[:, None]) < cur_batch_ctx_len),
other=0.0)
if v_load.dtype.is_fp8():
v = (v_load.to(tl.float32) * tl.load(v_scale)).to(q.dtype)
else:
v = v_load
p = p.to(v.dtype)
acc = tl.dot(p, v, acc=acc, input_precision='ieee')
# update m_i and l_i
l_i = l_i_new
m_i = m_i_new
off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh +
offs_d[:, None] * stride_kd)
off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh +
offs_d[None, :] * stride_vd)
k_ptrs = K + off_k
v_ptrs = V + off_v
block_mask = tl.where(
block_start_loc < cur_batch_seq_len - cur_batch_ctx_len, 1, 0)
# init alibi
alibi_slope = tl.load(Alibi_slopes + cur_head)
alibi_start_q = tl.arange(0, BLOCK_M) + block_start_loc + cur_batch_ctx_len
alibi_start_k = cur_batch_ctx_len
# # init debugger
# offset_db_q = tl.arange(0, BLOCK_M) + block_start_loc
# offset_db_k = tl.arange(0, BLOCK_N)
# calc q[BLOCK_M, BLOCK_MODEL] mul k[prefix_len: , BLOCK_DMODEL]
for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N):
start_n = tl.multiple_of(start_n, BLOCK_N)
# -- compute qk ----
k = tl.load(
k_ptrs + (cur_batch_in_all_start_index + start_n) * stride_kbs,
mask=dim_mask[:, None] & ((start_n + offs_n[None, :])
< cur_batch_seq_len - cur_batch_ctx_len),
other=0.0)
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk = tl.dot(q, k, acc=qk, input_precision='ieee')
qk *= sm_scale
qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk,
float("-inf"))
# load alibi
alibi = (tl.arange(0, BLOCK_N)[None, :] + alibi_start_k -
alibi_start_q[:, None]) * alibi_slope
alibi = tl.where(
(alibi <= 0) & (alibi_start_q[:, None] < cur_batch_seq_len), alibi,
float("-inf"))
qk += alibi
alibi_start_k += BLOCK_N
# -- compute m_ij, p, l_ij
m_ij = tl.max(qk, 1)
m_i_new = tl.maximum(m_i, m_ij)
p = tl.math.exp(qk - m_i_new[:, None])
l_ij = tl.sum(p, 1)
# -- update m_i and l_i
alpha = tl.math.exp(m_i - m_i_new)
l_i_new = alpha * l_i + l_ij
# -- update output accumulator --
# scale p
# scale acc
acc_scale = alpha
# acc_scale = l_i / l_i_new * alpha
acc = acc * acc_scale[:, None]
# update acc
v = tl.load(
v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs,
mask=dim_mask[None, :] & ((start_n + offs_n[:, None])
< cur_batch_seq_len - cur_batch_ctx_len),
other=0.0)
p = p.to(v.dtype)
acc = tl.dot(p, v, acc=acc, input_precision='ieee')
# update m_i and l_i
l_i = l_i_new
m_i = m_i_new
acc = acc / l_i[:, None]
# initialize pointers to output
off_o = ((cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs +
cur_head * stride_oh + offs_d[None, :] * stride_od)
out_ptrs = Out + off_o
tl.store(out_ptrs,
acc,
mask=dim_mask[None, :] &
(offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len))
return
@torch.inference_mode()
def context_attention_fwd(q,
k,
v,
o,
kv_cache_dtype: str,
k_cache,
v_cache,
b_loc,
b_start_loc,
b_seq_len,
max_seq_len,
max_input_len,
k_scale: torch.Tensor,
v_scale: torch.Tensor,
alibi_slopes=None,
sliding_window=None,
sm_scale=None,
skip_decode=False):
q_dtype_is_f32 = q.dtype is torch.float32
# Turing does have tensor core for float32 multiplication
# use ieee as fallback for triton kernels work. There is also
# warning on vllm/config.py to inform users this fallback
# implementation
IN_PRECISION = 'ieee' if IS_TURING and q_dtype_is_f32 else None
# Conversion of FP8 Tensor from uint8 storage to
# appropriate torch.dtype for interpretation by Triton
if "fp8" in kv_cache_dtype:
assert (k_cache.dtype == torch.uint8)
assert (v_cache.dtype == torch.uint8)
if kv_cache_dtype in ("fp8", "fp8_e4m3"):
target_dtype = current_platform.fp8_dtype()
elif kv_cache_dtype == "fp8_e5m2":
target_dtype = torch.float8_e5m2
else:
raise ValueError("Unsupported FP8 dtype:", kv_cache_dtype)
k_cache = k_cache.view(target_dtype)
v_cache = v_cache.view(target_dtype)
if (k_cache.dtype == torch.uint8
or v_cache.dtype == torch.uint8 and kv_cache_dtype == "auto"):
raise ValueError("kv_cache_dtype='auto' unsupported for\
FP8 KV Cache prefill kernel")
# shape constraints
Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
assert Lq == Lk and Lk == Lv
# round up Lk to a power of 2 - this is required for Triton block size
Lk_padded = triton.next_power_of_2(Lk)
if sm_scale is None:
sm_scale = 1.0 / (Lq**0.5)
batch, head = b_seq_len.shape[0], q.shape[1]
num_queries_per_kv = q.shape[1] // k.shape[1]
assert batch + 1 == len(b_start_loc)
# 0 means "disable"
if sliding_window is None or sliding_window <= 0:
sliding_window = 0
if alibi_slopes is not None:
# need to reduce num. blocks when using fp32 # need to reduce num. blocks when using fp32
# due to increased use of GPU shared memory # due to increased use of GPU shared memory
# if q.dtype is torch.float32: # if q.dtype is torch.float32:
BLOCK = BASE_BLOCK // 2 if q_dtype_is_f32 else BASE_BLOCK BLOCK = BASE_BLOCK // 2 if q_dtype_is_f32 else BASE_BLOCK
# batch, head,
# Turing does have tensor core for float32 multiplication grid = (batch, head, triton.cdiv(max_input_len, BLOCK))
# use ieee as fallback for triton kernels work. There is also _fwd_kernel_alibi[grid](
# warning on vllm/config.py to inform users this fallback
# implementation
IN_PRECISION = 'ieee' if IS_TURING and q_dtype_is_f32 else None
# Conversion of FP8 Tensor from uint8 storage to
# appropriate torch.dtype for interpretation by Triton
if "fp8" in kv_cache_dtype:
assert (k_cache.dtype == torch.uint8)
assert (v_cache.dtype == torch.uint8)
if kv_cache_dtype in ("fp8", "fp8_e4m3"):
target_dtype = current_platform.fp8_dtype()
elif kv_cache_dtype == "fp8_e5m2":
target_dtype = torch.float8_e5m2
else:
raise ValueError("Unsupported FP8 dtype:", kv_cache_dtype)
k_cache = k_cache.view(target_dtype)
v_cache = v_cache.view(target_dtype)
if (k_cache.dtype == torch.uint8
or v_cache.dtype == torch.uint8 and kv_cache_dtype == "auto"):
raise ValueError("kv_cache_dtype='auto' unsupported for\
FP8 KV Cache prefill kernel")
# shape constraints
Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
assert Lq == Lk and Lk == Lv
# round up Lk to a power of 2 - this is required for Triton block size
Lk_padded = triton.next_power_of_2(Lk)
if sm_scale is None:
sm_scale = 1.0 / (Lq**0.5)
batch, head = b_seq_len.shape[0], q.shape[1]
num_queries_per_kv = q.shape[1] // k.shape[1]
assert batch + 1 == len(b_start_loc)
grid = (batch, head, triton.cdiv(max_input_len, BLOCK)) # batch, head,
# 0 means "disable"
if sliding_window is None or sliding_window <= 0:
sliding_window = 0
if alibi_slopes is not None:
_fwd_kernel_alibi[grid](
q,
k,
v,
k_cache,
v_cache,
b_loc,
sm_scale,
k_scale,
v_scale,
b_start_loc,
b_seq_len,
alibi_slopes,
v_cache.shape[3],
k_cache.shape[4],
o,
b_loc.stride(0),
b_loc.stride(1),
q.stride(0),
q.stride(1),
q.stride(2),
k.stride(0),
k.stride(1),
k.stride(2),
v.stride(0),
v.stride(1),
v.stride(2),
o.stride(0),
o.stride(1),
o.stride(2),
k_cache.stride(0),
k_cache.stride(1),
k_cache.stride(2),
k_cache.stride(3),
k_cache.stride(
4
), #[num_blocks, num_kv_heads, head_size/x, block_size, x]
v_cache.stride(0),
v_cache.stride(1),
v_cache.stride(2),
v_cache.stride(
3), #[num_blocks, num_kv_heads, head_size, block_size]
num_queries_per_kv=num_queries_per_kv,
IN_PRECISION=IN_PRECISION,
BLOCK_M=BLOCK,
BLOCK_DMODEL=Lk,
BLOCK_DMODEL_PADDED=Lk_padded,
BLOCK_N=BLOCK,
SKIP_DECODE=skip_decode,
num_warps=NUM_WARPS,
num_stages=1,
)
return
_fwd_kernel[grid](
q, q,
k, k,
v, v,
...@@ -852,6 +799,7 @@ if triton.__version__ >= "2.1.0": ...@@ -852,6 +799,7 @@ if triton.__version__ >= "2.1.0":
v_scale, v_scale,
b_start_loc, b_start_loc,
b_seq_len, b_seq_len,
alibi_slopes,
v_cache.shape[3], v_cache.shape[3],
k_cache.shape[4], k_cache.shape[4],
o, o,
...@@ -886,9 +834,69 @@ if triton.__version__ >= "2.1.0": ...@@ -886,9 +834,69 @@ if triton.__version__ >= "2.1.0":
BLOCK_DMODEL=Lk, BLOCK_DMODEL=Lk,
BLOCK_DMODEL_PADDED=Lk_padded, BLOCK_DMODEL_PADDED=Lk_padded,
BLOCK_N=BLOCK, BLOCK_N=BLOCK,
SLIDING_WINDOW=sliding_window,
SKIP_DECODE=skip_decode, SKIP_DECODE=skip_decode,
num_warps=NUM_WARPS, num_warps=NUM_WARPS,
num_stages=1, num_stages=1,
) )
return return
max_seq_len = 0 if max_seq_len is None else max_seq_len
extra_kargs = {}
if current_platform.is_rocm():
extra_kargs = {"kpack": 2, "waves_per_eu": 2}
grid = lambda META: (batch, head,
triton.cdiv(max_input_len, META["BLOCK_M"]))
_fwd_kernel[grid](
q,
k,
v,
k_cache,
v_cache,
b_loc,
sm_scale,
k_scale,
v_scale,
b_start_loc,
b_seq_len,
k_cache.shape[4],
o,
b_loc.stride(0),
b_loc.stride(1),
q.stride(0),
q.stride(1),
q.stride(2),
k.stride(0),
k.stride(1),
k.stride(2),
v.stride(0),
v.stride(1),
v.stride(2),
o.stride(0),
o.stride(1),
o.stride(2),
k_cache.stride(0),
k_cache.stride(1),
k_cache.stride(2),
k_cache.stride(3),
k_cache.stride(
4), #[num_blocks, num_kv_heads, head_size/x, block_size, x]
v_cache.stride(0),
v_cache.stride(1),
v_cache.stride(2),
v_cache.stride(3), #[num_blocks, num_kv_heads, head_size, block_size]
BLOCK_SIZE=v_cache.shape[3],
num_queries_per_kv=num_queries_per_kv,
IN_PRECISION=IN_PRECISION,
BLOCK_DMODEL=Lk,
BLOCK_DMODEL_PADDED=Lk_padded,
SLIDING_WINDOW=sliding_window,
SKIP_DECODE=skip_decode,
BLOCK_M=128,
BLOCK_N=64,
num_unroll_cache=4,
num_unroll_request=1,
num_warps=4,
num_stages=1,
**extra_kargs)
return
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment