"git@developer.sourcefind.cn:change/sglang.git" did not exist on "24247b416875bfedf367c06a8fae6843abe4d11f"
Commit 3cbf8cbc authored by Lei Wang, committed by GitHub

[Example] Implement TileLang Native Sparse Attention Kernel (#121)

* Add DeepSeek MLA decode example with Flash Attention implementation

* Add GEMM SplitK and StreamK example implementations

This commit introduces two new example scripts demonstrating advanced GEMM (matrix multiplication) techniques:
- `example_tilelang_gemm_splitk.py`: Implements a Split-K GEMM kernel using TileLang
- `example_tilelang_gemm_streamk.py`: Implements a Stream-K GEMM kernel using TileLang

Both examples showcase different parallel computation strategies for matrix multiplication, with comprehensive testing using PyTorch reference implementations.
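
For readers unfamiliar with Split-K, the idea can be sketched in a few lines of plain PyTorch (illustrative only, with made-up sizes; the actual examples implement this as TileLang GPU kernels that accumulate partial results with atomic adds):

```python
import torch


def gemm_split_k(a: torch.Tensor, b: torch.Tensor, split_k: int = 4) -> torch.Tensor:
    """Split the K reduction of C = A @ B into `split_k` chunks and sum the partial products."""
    M, K = a.shape
    _, N = b.shape
    chunk = (K + split_k - 1) // split_k  # ceiling division over the reduction axis
    c = torch.zeros(M, N, dtype=torch.float32)
    for i in range(split_k):
        lo, hi = i * chunk, min((i + 1) * chunk, K)
        # Each partial product stands in for the contribution one workgroup
        # would accumulate (atomically) in the GPU kernel.
        c += a[:, lo:hi] @ b[lo:hi, :]
    return c


a, b = torch.randn(128, 256), torch.randn(256, 64)
torch.testing.assert_close(gemm_split_k(a, b), a @ b, atol=1e-3, rtol=1e-3)
```

Stream-K goes one step further and distributes the reduction at the granularity of individual K-iterations rather than fixed splits, which balances load better when the tile count does not divide evenly across SMs.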

* Refactor GEMM SplitK and StreamK example implementations

Clean up and improve code formatting for the SplitK and StreamK GEMM example scripts:
- Remove unused import (Profiler) in splitk example
- Simplify line breaks and improve code readability
- Standardize indentation and remove unnecessary whitespace
- Optimize atomic add and copy operations for better clarity

* Add block sparse attention benchmarks for multiple libraries

This commit introduces comprehensive block sparse attention benchmarks for different libraries:
- TileLang block sparse FMHA implementation
- Triton block sparse FMHA implementation
- PyTorch reference block sparse FMHA implementation
- FlashAttention dense FMHA reference implementation

The benchmarks include:
- Configurable benchmark parameters (batch size, heads, sequence length, etc.)
- Sparse mask generation using top-k and threshold methods (sketched after this list)
- Performance measurement for different sparse attention configurations
- Utility functions for mask generation and benchmarking
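
As a rough sketch of the two mask-generation strategies listed above (plain PyTorch, with hypothetical shapes and helper names, not the benchmark scripts' exact utilities):

```python
import torch


def topk_block_mask(scores: torch.Tensor, k: int) -> torch.Tensor:
    """Keep the k highest-scoring KV blocks per row of block-level scores.

    scores: [batch, heads, num_q_blocks, num_kv_blocks]
    """
    k = min(k, scores.shape[-1])
    idx = scores.topk(k, dim=-1).indices
    mask = torch.zeros_like(scores, dtype=torch.bool)
    return mask.scatter_(-1, idx, True)


def threshold_block_mask(scores: torch.Tensor, threshold: float) -> torch.Tensor:
    """Keep every KV block whose score exceeds a fixed threshold."""
    return scores > threshold


scores = torch.rand(1, 8, 16, 16)     # hypothetical block-importance scores
mask = topk_block_mask(scores, k=4)
print(mask.sum(-1).unique())          # tensor([4]): exactly 4 blocks kept per query block
```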

* Refactor block sparse attention benchmarks with code style improvements

- Add Ruff linter ignore comments to benchmark files
- Improve code formatting and line breaks
- Remove unused imports
- Standardize print statement formatting
- Enhance code readability across multiple library benchmarks

* lint fix

* Add CUDA atomic operations for BFLOAT16 and update function naming

- Implement AtomicAdd functions for BFLOAT16 and BFLOAT16x2 in CUDA common header
- Rename existing atomic add functions to use PascalCase (atomicAdd -> AtomicAdd)
- Add a new __pack_nv_bfloat162 function for packing BFLOAT16 values
- Update kernel and language customization to use new function names
- Add return type annotations in profiler module

* lint fix

* Add example for Group Query Attention (GQA) forward pass using Flash Attention in TileLang

This commit introduces a new example script `example_gqa_fwd_bshd.py` that demonstrates:
- Group Query Attention (GQA) implementation (see the sketch after this list)
- Flash Attention forward pass
- Performance benchmarking
- Configurable parameters for batch, heads, sequence length, and dimension
- Autotuning support
- Reference implementation comparison
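
The grouped-query pattern the example exercises is easy to state in plain PyTorch: each group of query heads attends against a single shared KV head, which a reference can emulate by repeating the KV heads. The sketch below uses made-up sizes and `F.scaled_dot_product_attention`; it is not the contents of `example_gqa_fwd_bshd.py`.

```python
import torch
import torch.nn.functional as F


def gqa_reference(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    # q: [B, S, HQ, D], k/v: [B, S, H, D] in BSHD layout, with HQ a multiple of H.
    groups = q.shape[2] // k.shape[2]
    # Repeat each KV head so every query head has a matching key/value head.
    k = k.repeat_interleave(groups, dim=2)
    v = v.repeat_interleave(groups, dim=2)
    # scaled_dot_product_attention expects [B, H, S, D].
    q, k, v = (x.transpose(1, 2) for x in (q, k, v))
    o = F.scaled_dot_product_attention(q, k, v, is_causal=True)
    return o.transpose(1, 2)  # back to [B, S, HQ, D]


q = torch.randn(1, 64, 16, 32, device='cuda', dtype=torch.float16)
k = torch.randn(1, 64, 4, 32, device='cuda', dtype=torch.float16)
v = torch.randn(1, 64, 4, 32, device='cuda', dtype=torch.float16)
print(gqa_reference(q, k, v).shape)  # torch.Size([1, 64, 16, 32])
```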

* Refactor IR lowering pipeline into modular phases

This commit introduces a new module `phase.py` to modularize the IR lowering process by splitting the complex lowering pipeline into two distinct phases:
- `LowerAndLegalize`: Handles initial IR legalization and transformation
- `OptimizeForTarget`: Applies target-specific optimizations

The changes simplify the lowering logic in multiple files by extracting the transformation steps into reusable functions, improving code readability and maintainability.
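
A conceptual sketch of how the two phases compose (the function names come from this commit, but the import path, argument names, and the pass lists mentioned in the comments are assumptions, not the actual `phase.py` contents):

```python
# Hypothetical composition of the two lowering phases; real signatures may differ.
from tilelang.engine.phase import LowerAndLegalize, OptimizeForTarget  # assumed import path


def lower(mod, target):
    # Phase 1: legalize and normalize the frontend IR so later passes see canonical loops/buffers.
    mod = LowerAndLegalize(mod, target)
    # Phase 2: apply target-specific optimizations (e.g. pipelining, memory planning).
    mod = OptimizeForTarget(mod, target)
    return mod
```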

* lint fix

* NSA kernel

* Enhance Native Sparse Attention Examples with Code Improvements and Parameter Updates

- Updated example_tilelang_nsa.py and example_triton_nsa.py with code formatting and style improvements
- Increased default number of heads and selected blocks in TileLang NSA example
- Added Ruff linter ignore comments to reference.py
- Standardized function signatures and improved code readability across NSA implementations

* Add utility math functions for integer operations

- Implement `next_power_of_2()` to calculate the next power of 2 for an integer
- Add `cdiv()` function for ceiling division of integers
parent 2b97e98a
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import torch
from reference import naive_nsa
import tilelang
from tilelang import language as T
import tilelang.testing

tilelang.testing.set_random_seed(0)


def native_sparse_attention(batch,
                            heads,
                            seq_len,
                            dim,
                            is_causal,
                            scale=None,
                            groups=1,
                            selected_blocks=16):
    if scale is None:
        scale = (1.0 / dim)**0.5 * 1.44269504  # log2(e)
    head_kv = heads // groups
    q_shape = [batch, seq_len, heads, dim]
    kv_shape = [batch, seq_len, head_kv, dim]
    block_indices_shape = [batch, seq_len, head_kv, selected_blocks]
    block_indices_dtype = "int32"
    dtype = "float16"
    accum_dtype = "float"
    block_S = 64
    block_T = min(128, tilelang.math.next_power_of_2(dim))

    S = selected_blocks
    NS = S
    G = groups
    BS = block_S
    BK = BV = block_T
    num_stages = 0
    threads = 32

    def kernel_func(block_S, block_T, num_stages, threads):

        @T.macro
        def MMA0(
            K: T.Buffer(kv_shape, dtype),
            Q_shared: T.Buffer([G, BK], dtype),
            K_shared: T.Buffer([BS, BK], dtype),
            acc_s: T.Buffer([G, BS], accum_dtype),
            i_s: T.int32,
            i_b: T.int32,
            i_h: T.int32,
            i_t: T.int32,
        ):
            T.copy(K[i_b, i_s * BS:(i_s + 1) * BS, i_h, :], K_shared)
            if is_causal:
                for i, j in T.Parallel(G, BS):
                    acc_s[i, j] = T.if_then_else(i_t >= (i_s * BS + j), 0,
                                                 -T.infinity(acc_s.dtype))
            else:
                T.clear(acc_s)
            T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)

        @T.macro
        def MMA1(
            V: T.Buffer(kv_shape, dtype),
            V_shared: T.Buffer([G, BV], dtype),
            acc_s_cast: T.Buffer([G, BS], dtype),
            acc_o: T.Buffer([G, BV], accum_dtype),
            i_s: T.int32,
            i_b: T.int32,
            i_h: T.int32,
        ):
            T.copy(V[i_b, i_s * BS:(i_s + 1) * BS, i_h, :], V_shared)
            T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)

        @T.macro
        def Softmax(
            acc_s: T.Buffer([G, BS], accum_dtype),
            acc_s_cast: T.Buffer([G, BS], dtype),
            scores_max: T.Buffer([G], accum_dtype),
            scores_max_prev: T.Buffer([G], accum_dtype),
            scores_scale: T.Buffer([G], accum_dtype),
            scores_sum: T.Buffer([G], accum_dtype),
            logsum: T.Buffer([G], accum_dtype),
        ):
            T.copy(scores_max, scores_max_prev)
            T.fill(scores_max, -T.infinity(accum_dtype))
            T.reduce_max(acc_s, scores_max, dim=1, clear=True)
            # To do causal softmax, we need to set the scores_max to 0 if it is -inf.
            # This process is called Check_inf in the FlashAttention3 code, and it only needs
            # to be done in the first ceil_div(kBlockM, kBlockN) steps.
            # for i in T.Parallel(block_M):
            #     scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i])
            for i in T.Parallel(G):
                scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
            for i, j in T.Parallel(G, BS):
                # Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
                # max * log_2(e)). This allows the compiler to use the ffma
                # instruction instead of fadd and fmul separately.
                acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
            T.reduce_sum(acc_s, scores_sum, dim=1)
            for i in T.Parallel(G):
                logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
            T.copy(acc_s, acc_s_cast)

        @T.macro
        def Rescale(
            acc_o: T.Buffer([G, BV], accum_dtype),
            scores_scale: T.Buffer([G], accum_dtype),
        ):
            for i, j in T.Parallel(G, BV):
                acc_o[i, j] *= scores_scale[i]

        @T.prim_func
        def main(
            Q: T.Buffer(q_shape, dtype),
            K: T.Buffer(kv_shape, dtype),
            V: T.Buffer(kv_shape, dtype),
            BlockIndices: T.Buffer(block_indices_shape, block_indices_dtype),
            Output: T.Buffer(q_shape, dtype),
        ):
            with T.Kernel(
                    dim // block_T, seq_len, batch * head_kv, threads=threads) as (bx, by, bz):
                Q_shared = T.alloc_shared([G, BK], dtype)
                K_shared = T.alloc_shared([BS, BK], dtype)
                V_shared = T.alloc_shared([BS, BV], dtype)
                O_shared = T.alloc_shared([G, BV], dtype)
                acc_s = T.alloc_fragment([G, BS], accum_dtype)
                acc_s_cast = T.alloc_fragment([G, BS], dtype)
                acc_o = T.alloc_fragment([G, BV], accum_dtype)
                scores_max = T.alloc_fragment([G], accum_dtype)
                scores_max_prev = T.alloc_fragment([G], accum_dtype)
                scores_scale = T.alloc_fragment([G], accum_dtype)
                scores_sum = T.alloc_fragment([G], accum_dtype)
                logsum = T.alloc_fragment([G], accum_dtype)

                # bx: value-dim tile, by: query token, bz: (batch, kv-head) pair
                i_v, i_t, i_bh = bx, by, bz
                i_b, i_h = i_bh // heads, i_bh % heads

                T.copy(Q[i_b, i_t, i_h * G:(i_h + 1) * G, :], Q_shared)
                T.fill(acc_o, 0)
                T.fill(logsum, 0)
                T.fill(scores_max, -T.infinity(accum_dtype))

                # Iterate over the selected KV blocks for this query token; indices that
                # fall after the current position (e.g. padding entries) are skipped.
                for i in T.Pipelined(NS, num_stages=num_stages):
                    i_s = BlockIndices[i_b, i_t, i_h, i]
                    if i_s <= i_t:
                        MMA0(K, Q_shared, K_shared, acc_s, i_s, i_b, i_h, i_t)
                        Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale,
                                scores_sum, logsum)
                        Rescale(acc_o, scores_scale)
                        MMA1(V, V_shared, acc_s_cast, acc_o, i_s, i_b, i_h)
                for i, j in T.Parallel(G, BV):
                    acc_o[i, j] /= logsum[i]
                T.copy(acc_o, O_shared)
                T.copy(O_shared, Output[i_b, i_t, i_h * G:(i_h + 1) * G, :])

        return main

    def kernel(block_S, block_T, num_stages, threads):
        return kernel_func(block_S, block_T, num_stages, threads)

    return kernel(block_S, block_T, num_stages, threads)


if __name__ == "__main__":
    B, SEQ_LEN, H, HQ, D, S, block_size, dtype, scale = 1, 64, 4, 64, 32, 16, 64, torch.float16, None
    program = native_sparse_attention(
        batch=B,
        heads=HQ,
        seq_len=SEQ_LEN,
        dim=D,
        is_causal=True,
        scale=scale,
        groups=HQ // H,
        selected_blocks=S,
    )
    kernel = tilelang.compile(program, out_idx=[4])

    Q = torch.randn((B, SEQ_LEN, HQ, D), dtype=dtype, device='cuda').requires_grad_(True)
    K = torch.randn((B, SEQ_LEN, H, D), dtype=dtype, device='cuda').requires_grad_(True)
    V = torch.randn((B, SEQ_LEN, H, D), dtype=dtype, device='cuda').requires_grad_(True)
    DO = torch.randn((B, SEQ_LEN, HQ, D), dtype=dtype, device='cuda')

    block_indices = torch.full((B, SEQ_LEN, H, S), SEQ_LEN, dtype=torch.long, device='cuda')
    for b in range(B):
        for t in range(SEQ_LEN):
            for h in range(H):
                i_i = torch.randperm(max(1, (t // block_size)))[:S]
                block_indices[b, t, h, :len(i_i)] = i_i
    block_indices = block_indices.sort(-1)[0]
    block_counts = torch.randint(1, S + 1, (B, SEQ_LEN, H), device='cuda')

    out = kernel(Q, K, V, block_indices.to(torch.int32))
    print(out)
    ref = naive_nsa(
        q=Q,
        k=K,
        v=V,
        block_indices=block_indices,
        block_counts=block_counts,
        block_size=block_size,
        scale=scale)
    print(ref)
    torch.testing.assert_close(ref, out, atol=1e-2, rtol=1e-2)
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ruff: noqa
from typing import Optional

import torch
import triton
import triton.language as tl
from einops import rearrange
from fla.ops.common.utils import (prepare_chunk_indices, prepare_lens, prepare_token_indices)
from fla.utils import autocast_custom_bwd, autocast_custom_fwd, contiguous

from reference import naive_nsa


@triton.autotune(
    configs=[triton.Config({}, num_warps=num_warps) for num_warps in [1, 2, 4, 8, 16]],
    key=['BS', 'BK', 'BV'],
)
@triton.jit
def parallel_nsa_fwd_kernel(
    q,
    k,
    v,
    o,
    lse,
    scale,
    block_indices,
    T,
    H: tl.constexpr,
    HQ: tl.constexpr,
    G: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    S: tl.constexpr,
    BS: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
):
    i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_b, i_h = i_bh // H, i_bh % H

    bos, eos = i_b * T, i_b * T + T

    k += (bos * H + i_h) * K
    v += (bos * H + i_h) * V
    block_indices += (bos + i_t) * H * S + i_h * S

    NS = S

    p_q = tl.make_block_ptr(q + (bos + i_t) * HQ * K, (HQ, K), (K, 1), (i_h * G, 0), (G, BK),
                            (1, 0))
    p_o = tl.make_block_ptr(o + (bos + i_t) * HQ * V, (HQ, V), (V, 1), (i_h * G, i_v * BV), (G, BV),
                            (1, 0))
    p_lse = lse + (bos + i_t) * HQ + i_h * G + tl.arange(0, G)

    # the Q block is kept in the shared memory throughout the whole kernel
    # [G, BK]
    b_q = tl.load(p_q, boundary_check=(0, 1))
    b_q = (b_q * scale).to(b_q.dtype)

    # [G, BV]
    b_o = tl.zeros([G, BV], dtype=tl.float32)

    b_m = tl.full([G], float('-inf'), dtype=tl.float32)
    b_acc = tl.zeros([G], dtype=tl.float32)
    for i in range(NS):
        i_s = tl.load(block_indices + i).to(tl.int32) * BS
        if i_s <= i_t:
            p_k = tl.make_block_ptr(k, (K, T), (1, H * K), (0, i_s), (BK, BS), (0, 1))
            p_v = tl.make_block_ptr(v, (T, V), (H * V, 1), (i_s, i_v * BV), (BS, BV), (1, 0))
            # [BK, BS]
            b_k = tl.load(p_k, boundary_check=(0, 1))
            # [BS, BV]
            b_v = tl.load(p_v, boundary_check=(0, 1))
            # [G, BS]
            b_s = tl.dot(b_q, b_k)
            b_s = tl.where((i_t >= (i_s + tl.arange(0, BS)))[None, :], b_s, float('-inf'))

            # [G]
            b_m, b_mp = tl.maximum(b_m, tl.max(b_s, 1)), b_m
            b_r = tl.exp(b_mp - b_m)
            # [G, BS]
            b_p = tl.exp(b_s - b_m[:, None])
            # [G]
            b_acc = b_acc * b_r + tl.sum(b_p, 1)
            # [G, BV]
            b_o = b_o * b_r[:, None] + tl.dot(b_p.to(b_q.dtype), b_v)

            b_mp = b_m
    b_o = b_o / b_acc[:, None]
    b_m += tl.log(b_acc)

    tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
    tl.store(p_lse, b_m.to(p_lse.dtype.element_ty))


def parallel_nsa_fwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    block_indices: torch.Tensor,
    block_size: int,
    scale: float,
):
    B, T, H, K, V, S = *k.shape, v.shape[-1], block_indices.shape[-1]
    HQ = q.shape[2]
    G = HQ // H
    BS = block_size
    if torch.cuda.get_device_capability()[0] >= 9:
        BK = min(256, triton.next_power_of_2(K))
        BV = min(256, triton.next_power_of_2(V))
    else:
        BK = min(128, triton.next_power_of_2(K))
        BV = min(128, triton.next_power_of_2(V))
    NK = triton.cdiv(K, BK)
    NV = triton.cdiv(V, BV)
    assert NK == 1, "The key dimension can not be larger than 256"

    grid = (NV, T, B * H)
    o = torch.empty(B, T, HQ, V, dtype=v.dtype, device=q.device)
    lse = torch.empty(B, T, HQ, dtype=torch.float32, device=q.device)
    print("grid", grid)
    parallel_nsa_fwd_kernel[grid](
        q=q,
        k=k,
        v=v,
        o=o,
        lse=lse,
        scale=scale,
        block_indices=block_indices,
        H=H,
        HQ=HQ,
        G=G,
        T=T,
        K=K,
        V=V,
        S=S,
        BS=BS,
        BK=BK,
        BV=BV,
    )
    return o, lse


class ParallelNSAFunction(torch.autograd.Function):

    @staticmethod
    @contiguous
    @autocast_custom_fwd
    def forward(ctx, q, k, v, block_indices, block_size, scale, offsets):
        ctx.dtype = q.dtype

        # 2-d sequence indices denoting the offsets of tokens in each sequence
        # for example, if the passed `offsets` is [0, 2, 6],
        # then there are 2 and 4 tokens in the 1st and 2nd sequences respectively, and `token_indices` will be
        # [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]]
        token_indices = prepare_token_indices(offsets) if offsets is not None else None

        o, lse = parallel_nsa_fwd(
            q=q, k=k, v=v, block_indices=block_indices, block_size=block_size, scale=scale)
        ctx.save_for_backward(q, k, v, o, lse)
        ctx.block_indices = block_indices
        ctx.block_size = block_size
        ctx.scale = scale
        return o.to(q.dtype)


def parallel_nsa(q: torch.Tensor,
                 k: torch.Tensor,
                 v: torch.Tensor,
                 block_indices: torch.LongTensor,
                 block_size: int = 64,
                 scale: Optional[float] = None,
                 cu_seqlens: Optional[torch.LongTensor] = None,
                 head_first: bool = False) -> torch.Tensor:
    r"""
    Args:
        q (torch.Tensor):
            queries of shape `[B, T, HQ, K]` if `head_first=False` else `[B, HQ, T, K]`.
        k (torch.Tensor):
            keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
            GQA is enforced here. The ratio of query heads (HQ) to key/value heads (H) must be a power of 2 and >=16.
        v (torch.Tensor):
            values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
        block_indices (torch.LongTensor):
            Block indices of shape `[B, T, H, S]` if `head_first=False` else `[B, H, T, S]`.
            `S` is the number of selected blocks for each query token, which is set to 16 in the paper.
        block_size (int):
            Selected block size. Default: 64.
        scale (Optional[int]):
            Scale factor for attention scores.
            If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
        head_first (Optional[bool]):
            Whether the inputs are in the head-first format. Default: `False`.
        cu_seqlens (torch.LongTensor):
            Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
            consistent with the FlashAttention API.
    Returns:
        o (torch.Tensor):
            Outputs of shape `[B, T, HQ, V]` if `head_first=False` else `[B, HQ, T, V]`.
    """
    if scale is None:
        scale = k.shape[-1]**-0.5
    if cu_seqlens is not None:
        assert q.shape[0] == 1, "batch size must be 1 when cu_seqlens are provided"
    if head_first:
        q, k, v, block_indices = map(lambda x: rearrange(x, 'b h t d -> b t h d'),
                                     (q, k, v, block_indices))
    o = ParallelNSAFunction.apply(q, k, v, block_indices, block_size, scale, cu_seqlens)
    if head_first:
        o = rearrange(o, 'b t h d -> b h t d')
    return o


if __name__ == "__main__":
    B, T, H, HQ, D, S, block_size, dtype, scale = 1, 64, 1, 16, 32, 1, 64, torch.float16, 0.1
    q = torch.randn((B, T, HQ, D), dtype=dtype, device='cuda').requires_grad_(True)
    k = torch.randn((B, T, H, D), dtype=dtype, device='cuda').requires_grad_(True)
    v = torch.randn((B, T, H, D), dtype=dtype, device='cuda').requires_grad_(True)
    do = torch.randn((B, T, HQ, D), dtype=dtype, device='cuda')

    block_indices = torch.full((B, T, H, S), T, dtype=torch.long, device='cuda')
    for b in range(B):
        for t in range(T):
            for h in range(H):
                i_i = torch.randperm(max(1, (t // block_size)))[:S]
                block_indices[b, t, h, :len(i_i)] = i_i
    block_indices = block_indices.sort(-1)[0]
    block_counts = torch.randint(1, S + 1, (B, T, H), device='cuda')

    ref = naive_nsa(
        q=q,
        k=k,
        v=v,
        block_indices=block_indices,
        block_counts=block_counts,
        block_size=block_size,
        scale=scale)
    # print(ref)

    tri = parallel_nsa(
        q=q, k=k, v=v, block_indices=block_indices, block_size=block_size, scale=scale)
    # print(tri)
    torch.testing.assert_close(ref, tri, atol=1e-2, rtol=1e-2)

    # import flash_attn
    # # gqa
    # o_gqa = flash_attn.flash_attn_func(
    #     q,
    #     k,
    #     v,
    #     softmax_scale=scale,
    # )
    # print(o_gqa)
    # torch.testing.assert_close(o_gqa, tri, atol=1e-2, rtol=1e-2)
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ruff: noqa
from typing import Optional

import torch
import torch.nn.functional as F
from einops import rearrange, repeat


def naive_nsa(q: torch.Tensor,
              k: torch.Tensor,
              v: torch.Tensor,
              block_indices: torch.LongTensor,
              block_counts: torch.LongTensor,
              block_size: int = 64,
              scale: Optional[float] = None,
              head_first: bool = False,
              cu_seqlens: Optional[torch.LongTensor] = None) -> torch.Tensor:
    r"""
    Args:
        q (torch.Tensor):
            queries of shape `[B, T, HQ, K]` if `head_first=False` else `[B, HQ, T, K]`.
        k (torch.Tensor):
            keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
            GQA is enforced here. The ratio of query heads (HQ) to key/value heads (H) must be a power of 2 and >=16.
        v (torch.Tensor):
            values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
        block_indices (torch.LongTensor):
            Block indices of shape `[B, T, H, S]` if `head_first=False` else `[B, H, T, S]`.
            `S` is the maximum number of selected blocks for each query token, which is set to 16 in the paper.
        block_counts (torch.LongTensor):
            Block counts of shape `[B, T, H]` if `head_first=False` else `[B, H, T]`.
        block_size (int):
            Selected block size. Default: 64.
        scale (Optional[int]):
            Scale factor for attention scores.
            If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
        head_first (Optional[bool]):
            Whether the inputs are in the head-first format. Default: `False`.
        cu_seqlens (torch.LongTensor):
            Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
            consistent with the FlashAttention API.
    Returns:
        o (torch.Tensor):
            Outputs of shape `[B, T, HQ, V]` if `head_first=False` else `[B, HQ, T, V]`.
    """
    if scale is None:
        scale = k.shape[-1]**-0.5
    if cu_seqlens is not None:
        if head_first:
            raise RuntimeError(
                "Sequences with variable lengths are not supported for head-first mode")
    if head_first:
        q, k, v, block_indices = map(lambda x: rearrange(x, 'b h t d -> b t h d'),
                                     (q, k, v, block_indices))
        block_counts = rearrange(block_counts, 'b h t -> b t h')

    dtype = q.dtype
    G = q.shape[2] // k.shape[2]
    BS = block_size
    S = block_indices.shape[-1]
    k, v, block_indices = (repeat(x, 'b t h d -> b t (h g) d', g=G) for x in (k, v, block_indices))
    block_counts = repeat(block_counts, 'b t h -> b t (h g)', g=G)
    c = torch.arange(S).repeat_interleave(BS).unsqueeze(1).expand(-1, q.shape[2]).to(q.device)
    q, k, v = map(lambda x: x.float(), (q, k, v))

    o = torch.zeros_like(v)
    varlen = True
    if cu_seqlens is None:
        varlen = False
        B, T = q.shape[:2]
        cu_seqlens = torch.cat(
            [block_indices.new_tensor(range(0, B * T, T)),
             block_indices.new_tensor([B * T])])

    for i in range(len(cu_seqlens) - 1):
        if not varlen:
            q_b, k_b, v_b, i_b, s_b = q[i], k[i], v[i], block_indices[i], block_counts[i]
        else:
            T = cu_seqlens[i + 1] - cu_seqlens[i]
            q_b, k_b, v_b, i_b, s_b = map(lambda x: x[0][cu_seqlens[i]:cu_seqlens[i + 1]],
                                          (q, k, v, block_indices, block_counts))

        i_b = i_b.unsqueeze(-1) * BS + i_b.new_tensor(range(BS))
        # [T, S*BS, HQ]
        i_b = i_b.view(T, block_indices.shape[2], -1).transpose(1, 2)
        for i_q in range(T):
            # [HQ, D]
            q_i = q_b[i_q] * scale
            # [S*BS, HQ]
            i_i = i_b[i_q]
            # [1, HQ]
            s_i = s_b[i_q]
            # [S*BS, HQ, -1]
            k_i, v_i = map(
                lambda x: x.gather(
                    0,
                    i_i.clamp(0, T - 1).unsqueeze(-1).expand(*i_i.shape, x.shape[-1])), (k_b, v_b))
            # [S*BS, HQ]
            attn = torch.einsum('h d, n h d -> n h', q_i, k_i).masked_fill((i_i > i_q) | (c >= s_i),
                                                                           float('-inf')).softmax(0)
            if not varlen:
                o[i, i_q] = torch.einsum('n h, n h v -> h v', attn, v_i)
            else:
                o[0][cu_seqlens[i] + i_q] = torch.einsum('n h, n h v -> h v', attn, v_i)

    if head_first:
        o = rearrange(o, 'b t h d -> b h t d')
    return o.to(dtype)
git+https://github.com/fla-org/flash-linear-attention
@@ -125,3 +125,5 @@ from . import (
 from .engine import lower  # noqa: F401
 from .version import __version__  # noqa: F401
+from .math import *  # noqa: F403
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.


def next_power_of_2(x: int) -> int:
    return 1 << (x - 1).bit_length()


def cdiv(a: int, b: int) -> int:
    return (a + b - 1) // b
@@ -80,7 +80,7 @@ def torch_assert_close(tensor_a,
     return True


-def set_random_seed(seed: int) -> None:
+def set_random_seed(seed: int = 42) -> None:
     random.seed(seed)
     np.random.seed(seed)
     torch.manual_seed(seed)