"examples/vscode:/vscode.git/clone" did not exist on "0052f1212b2f57a6f659cc38201186e63bbd4c76"
Unverified commit c2f212d6, authored by Xiaoyu Zhang, committed by GitHub
Browse files

optimize MiniMax-Text-01 lightning_attn_decode triton (#2966)

parent e2cdc8a5
...@@ -23,7 +23,10 @@ def _decode_kernel( ...@@ -23,7 +23,10 @@ def _decode_kernel(
h: tl.constexpr, h: tl.constexpr,
n: tl.constexpr, n: tl.constexpr,
d: tl.constexpr, d: tl.constexpr,
d_original: tl.constexpr,
e: tl.constexpr, e: tl.constexpr,
e_original: tl.constexpr,
BLOCK_SIZE: tl.constexpr = 32,
): ):
off_bh = tl.program_id(0) off_bh = tl.program_id(0)
off_h = off_bh % h off_h = off_bh % h
...@@ -39,21 +42,38 @@ def _decode_kernel( ...@@ -39,21 +42,38 @@ def _decode_kernel(
d_idx = tl.arange(0, d) d_idx = tl.arange(0, d)
e_idx = tl.arange(0, e) e_idx = tl.arange(0, e)
q = tl.load(Q + qk_offset + d_idx) # Create masks for original dimensions
k = tl.load(K + qk_offset + d_idx) d_mask = d_idx < d_original
v = tl.load(V + v_offset + e_idx) e_mask = e_idx < e_original
kv = tl.load(KV + kv_offset + d_idx[:, None] * e + e_idx[None, :]) # Load with masking
q = tl.load(Q + qk_offset + d_idx, mask=d_mask, other=0.0)
k = tl.load(K + qk_offset + d_idx, mask=d_mask, other=0.0)
v = tl.load(V + v_offset + e_idx, mask=e_mask, other=0.0)
# Load KV with 2D masking
kv = tl.load(
KV + kv_offset + d_idx[:, None] * e + e_idx[None, :],
mask=(d_mask[:, None] & e_mask[None, :]),
other=0.0,
)
# Compute outer product using element-wise operations
k_v_prod = k[:, None] * v[None, :] k_v_prod = k[:, None] * v[None, :]
kv = ratio * kv + k_v_prod kv = ratio * kv + k_v_prod
# Store KV with 2D masking
tl.store( tl.store(
KV + kv_offset + d_idx[:, None] * e + e_idx[None, :], kv.to(KV.dtype.element_ty) KV + kv_offset + d_idx[:, None] * e + e_idx[None, :],
kv.to(KV.dtype.element_ty),
mask=(d_mask[:, None] & e_mask[None, :]),
) )
# Compute matrix-vector multiplication using element-wise operations and reduction
o = tl.sum(q[:, None] * kv, axis=0) o = tl.sum(q[:, None] * kv, axis=0)
tl.store(Out + o_offset + e_idx, o.to(Out.dtype.element_ty))
# Store output with masking
tl.store(Out + o_offset + e_idx, o.to(Out.dtype.element_ty), mask=e_mask)
def lightning_attn_decode(q, k, v, kv, s): def lightning_attn_decode(q, k, v, kv, s):
...@@ -62,26 +82,27 @@ def lightning_attn_decode(q, k, v, kv, s): ...@@ -62,26 +82,27 @@ def lightning_attn_decode(q, k, v, kv, s):
e = v.shape[-1] e = v.shape[-1]
assert n == 1, "Sequence length must be 1 in decode mode" assert n == 1, "Sequence length must be 1 in decode mode"
# Pad dimensions to power of 2 # Get padded dimensions (power of 2)
d_padded = next_power_of_2(d) d_padded = next_power_of_2(d)
e_padded = next_power_of_2(e) e_padded = next_power_of_2(e)
# Pad inputs
q_padded = F.pad(q, (0, d_padded - d))
k_padded = F.pad(k, (0, d_padded - d))
v_padded = F.pad(v, (0, e_padded - e))
kv_padded = F.pad(kv, (0, e_padded - e, 0, d_padded - d))
# Ensure inputs are contiguous
q_padded = q_padded.contiguous()
k_padded = k_padded.contiguous()
v_padded = v_padded.contiguous()
kv_padded = kv_padded.contiguous().to(torch.float32)
s = s.contiguous()
# Create output tensor (padded) # Create output tensor (padded)
o_padded = torch.empty(b, h, n, e_padded, dtype=v.dtype, device=v.device) o_padded = torch.empty(b, h, n, e_padded, dtype=v.dtype, device=v.device)
# Create padded tensors without actually padding the data
q_padded = torch.empty(b, h, n, d_padded, dtype=q.dtype, device=q.device)
k_padded = torch.empty(b, h, n, d_padded, dtype=k.dtype, device=k.device)
v_padded = torch.empty(b, h, n, e_padded, dtype=v.dtype, device=v.device)
kv_padded = torch.empty(
b, h, d_padded, e_padded, dtype=torch.float32, device=kv.device
)
# Copy data to padded tensors
q_padded[..., :d] = q
k_padded[..., :d] = k
v_padded[..., :e] = v
kv_padded[..., :d, :e] = kv
# Launch kernel # Launch kernel
grid = (b * h, 1) grid = (b * h, 1)
_decode_kernel[grid]( _decode_kernel[grid](
...@@ -95,10 +116,12 @@ def lightning_attn_decode(q, k, v, kv, s): ...@@ -95,10 +116,12 @@ def lightning_attn_decode(q, k, v, kv, s):
h=h, h=h,
n=n, n=n,
d=d_padded, d=d_padded,
d_original=d,
e=e_padded, e=e_padded,
e_original=e,
) )
# Remove padding # Get unpadded outputs
o = o_padded[..., :e] o = o_padded[..., :e]
kv_out = kv_padded[..., :d, :e] kv_out = kv_padded[..., :d, :e]
...@@ -351,6 +374,8 @@ def test_lightning_attention_implementations(model_params): ...@@ -351,6 +374,8 @@ def test_lightning_attention_implementations(model_params):
msg="Lightning attention implementations produce different kv results", msg="Lightning attention implementations produce different kv results",
) )
print("✅ Two implementations match")
def _build_slope_tensor(n_attention_heads: int): def _build_slope_tensor(n_attention_heads: int):
def get_slopes(n): def get_slopes(n):
...@@ -375,7 +400,7 @@ def _build_slope_tensor(n_attention_heads: int): ...@@ -375,7 +400,7 @@ def _build_slope_tensor(n_attention_heads: int):
def get_benchmark(): def get_benchmark():
batch_size_range = [2**i for i in range(0, 12)] # max 2048 batch_size_range = [i for i in range(1, 33)] # max 32
seq_length_range = [1] # decode mode sequence length is fixed to 1 seq_length_range = [1] # decode mode sequence length is fixed to 1
configs = list(itertools.product(batch_size_range, seq_length_range)) configs = list(itertools.product(batch_size_range, seq_length_range))
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment