Unverified Commit c2f212d6 authored by Xiaoyu Zhang, committed by GitHub

optimize MiniMax-Text-01 lightning_attn_decode triton (#2966)

parent e2cdc8a5
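For orientation, the kernel below computes one step of the lightning-attention linear recurrence: the running per-head state kv is decayed by a per-head ratio and updated with the rank-1 product of k and v, then the output is q contracted against the new state. A minimal PyTorch sketch of that step, assuming (the definition is outside the visible hunks) that the decay is ratio = exp(-s) for the per-head slope tensor s; the helper name and shapes here are illustrative:

```python
import torch

def lightning_attn_decode_ref(q, k, v, kv, s):
    """Reference decode step (sketch).
    Assumed shapes: q, k: (b, h, 1, d); v: (b, h, 1, e);
    kv: (b, h, d, e); s: per-head slopes, h entries."""
    ratio = torch.exp(-s).view(1, -1, 1, 1)  # assumed per-head decay
    # Decay the running state and add the rank-1 update k^T v.
    kv_new = ratio * kv + k.transpose(-1, -2) @ v
    # Contract q against the updated state: o = q @ kv.
    o = q @ kv_new
    return o, kv_new
```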
@@ -23,7 +23,10 @@ def _decode_kernel(
     h: tl.constexpr,
     n: tl.constexpr,
     d: tl.constexpr,
+    d_original: tl.constexpr,
     e: tl.constexpr,
+    e_original: tl.constexpr,
+    BLOCK_SIZE: tl.constexpr = 32,
 ):
     off_bh = tl.program_id(0)
     off_h = off_bh % h
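The split into d/e (padded, constexpr) and d_original/e_original (true head dims) exists because tl.arange(0, n) in Triton only accepts power-of-two extents, so the kernel must index over padded sizes and mask the tail. A small illustration, assuming next_power_of_2 here is triton.next_power_of_2:

```python
import triton

# tl.arange(0, n) requires a power-of-two n, hence the padding:
print(triton.next_power_of_2(96))   # 128
print(triton.next_power_of_2(128))  # 128 (already a power of two)
```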
@@ -39,21 +42,38 @@ def _decode_kernel(
     d_idx = tl.arange(0, d)
     e_idx = tl.arange(0, e)

-    q = tl.load(Q + qk_offset + d_idx)
-    k = tl.load(K + qk_offset + d_idx)
-    v = tl.load(V + v_offset + e_idx)
+    # Create masks for original dimensions
+    d_mask = d_idx < d_original
+    e_mask = e_idx < e_original

-    kv = tl.load(KV + kv_offset + d_idx[:, None] * e + e_idx[None, :])
+    # Load with masking
+    q = tl.load(Q + qk_offset + d_idx, mask=d_mask, other=0.0)
+    k = tl.load(K + qk_offset + d_idx, mask=d_mask, other=0.0)
+    v = tl.load(V + v_offset + e_idx, mask=e_mask, other=0.0)
+
+    # Load KV with 2D masking
+    kv = tl.load(
+        KV + kv_offset + d_idx[:, None] * e + e_idx[None, :],
+        mask=(d_mask[:, None] & e_mask[None, :]),
+        other=0.0,
+    )

     # Compute outer product using element-wise operations
     k_v_prod = k[:, None] * v[None, :]
     kv = ratio * kv + k_v_prod

+    # Store KV with 2D masking
     tl.store(
-        KV + kv_offset + d_idx[:, None] * e + e_idx[None, :], kv.to(KV.dtype.element_ty)
+        KV + kv_offset + d_idx[:, None] * e + e_idx[None, :],
+        kv.to(KV.dtype.element_ty),
+        mask=(d_mask[:, None] & e_mask[None, :]),
     )

     # Compute matrix-vector multiplication using element-wise operations and reduction
     o = tl.sum(q[:, None] * kv, axis=0)
-    tl.store(Out + o_offset + e_idx, o.to(Out.dtype.element_ty))
+    # Store output with masking
+    tl.store(Out + o_offset + e_idx, o.to(Out.dtype.element_ty), mask=e_mask)
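With these masks, lanes beyond the original dimensions load as 0.0 and are never stored, so the power-of-two padding cannot leak into kv or the output. A small illustration of what the 2D mask d_mask[:, None] & e_mask[None, :] selects, using hypothetical sizes d_original=3, e_original=2 padded to 4x4:

```python
import torch

d_idx = torch.arange(4)
e_idx = torch.arange(4)
mask = (d_idx[:, None] < 3) & (e_idx[None, :] < 2)
print(mask.int())
# tensor([[1, 1, 0, 0],
#         [1, 1, 0, 0],
#         [1, 1, 0, 0],
#         [0, 0, 0, 0]])
```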
@@ -62,26 +82,27 @@ def lightning_attn_decode(q, k, v, kv, s):
     e = v.shape[-1]
     assert n == 1, "Sequence length must be 1 in decode mode"

-    # Pad dimensions to power of 2
+    # Get padded dimensions (power of 2)
     d_padded = next_power_of_2(d)
     e_padded = next_power_of_2(e)

-    # Pad inputs
-    q_padded = F.pad(q, (0, d_padded - d))
-    k_padded = F.pad(k, (0, d_padded - d))
-    v_padded = F.pad(v, (0, e_padded - e))
-    kv_padded = F.pad(kv, (0, e_padded - e, 0, d_padded - d))
-    # Ensure inputs are contiguous
-    q_padded = q_padded.contiguous()
-    k_padded = k_padded.contiguous()
-    v_padded = v_padded.contiguous()
-    kv_padded = kv_padded.contiguous().to(torch.float32)
-    s = s.contiguous()

     # Create output tensor (padded)
     o_padded = torch.empty(b, h, n, e_padded, dtype=v.dtype, device=v.device)

+    # Create padded tensors without actually padding the data
+    q_padded = torch.empty(b, h, n, d_padded, dtype=q.dtype, device=q.device)
+    k_padded = torch.empty(b, h, n, d_padded, dtype=k.dtype, device=k.device)
+    v_padded = torch.empty(b, h, n, e_padded, dtype=v.dtype, device=v.device)
+    kv_padded = torch.empty(
+        b, h, d_padded, e_padded, dtype=torch.float32, device=kv.device
+    )
+
+    # Copy data to padded tensors
+    q_padded[..., :d] = q
+    k_padded[..., :d] = k
+    v_padded[..., :e] = v
+    kv_padded[..., :d, :e] = kv

     # Launch kernel
     grid = (b * h, 1)
     _decode_kernel[grid](
@@ -95,10 +116,12 @@ def lightning_attn_decode(q, k, v, kv, s):
         h=h,
         n=n,
         d=d_padded,
+        d_original=d,
         e=e_padded,
+        e_original=e,
     )

-    # Remove padding
+    # Get unpadded outputs
     o = o_padded[..., :e]
     kv_out = kv_padded[..., :d, :e]
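A hypothetical call site for the updated wrapper; the sizes are MiniMax-Text-01-like but illustrative, and the dtypes are assumptions rather than requirements stated in this diff:

```python
import torch

b, h, n, d, e = 2, 64, 1, 96, 96  # n must be 1 in decode mode
q = torch.randn(b, h, n, d, device="cuda", dtype=torch.bfloat16)
k = torch.randn(b, h, n, d, device="cuda", dtype=torch.bfloat16)
v = torch.randn(b, h, n, e, device="cuda", dtype=torch.bfloat16)
kv = torch.zeros(b, h, d, e, device="cuda", dtype=torch.float32)
s = _build_slope_tensor(h).to(q.device)  # per-head decay slopes

o, kv_out = lightning_attn_decode(q, k, v, kv, s)
assert o.shape == (b, h, n, e) and kv_out.shape == (b, h, d, e)
```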
@@ -351,6 +374,8 @@ def test_lightning_attention_implementations(model_params):
         msg="Lightning attention implementations produce different kv results",
     )
     print("✅ Two implementations match")
+
+
 def _build_slope_tensor(n_attention_heads: int):
     def get_slopes(n):
@@ -375,7 +400,7 @@ def _build_slope_tensor(n_attention_heads: int):
 def get_benchmark():
-    batch_size_range = [2**i for i in range(0, 12)]  # max 2048
+    batch_size_range = [i for i in range(1, 33)]  # max 32
     seq_length_range = [1]  # decode mode sequence length is fixed to 1
     configs = list(itertools.product(batch_size_range, seq_length_range))
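Configs like these are typically fed to triton.testing.Benchmark through the perf_report decorator. A sketch of that pattern under those assumptions; the provider names, labels, plot name, and benchmark body below are illustrative, not taken from this file:

```python
import itertools
import triton
import triton.testing

batch_size_range = [i for i in range(1, 33)]  # max 32
seq_length_range = [1]  # decode mode sequence length is fixed to 1
configs = list(itertools.product(batch_size_range, seq_length_range))

@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["batch_size", "seq_len"],  # swept parameters
        x_vals=configs,
        line_arg="provider",                # one line per implementation
        line_vals=["triton"],               # assumed provider list
        line_names=["Triton"],
        ylabel="latency (ms)",
        plot_name="lightning-attn-decode",  # assumed plot name
        args={},
    )
)
def bench(batch_size, seq_len, provider):
    # Placeholder body: the real script would time lightning_attn_decode here.
    return triton.testing.do_bench(lambda: None)
```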