Commit 6891d3ec authored by Lei Wang's avatar Lei Wang Committed by GitHub
Browse files

[Examples] Implement NSA Backward kernels (#180)


* Update native sparse attention example with scale parameter handling

- Add scale parameter processing in native_sparse_attention function
- Modify example script to include custom scale value
- Update function calls to pass scale parameter
- Enhance flexibility of sparse attention implementation

* Refactor Triton Native Sparse Attention Example

- Improve code formatting and readability in example_triton_nsa_bwd.py
- Standardize function and parameter alignment
- Remove unnecessary whitespaces and optimize imports
- Enhance code style consistency with previous commits
parent c39e540a
This diff is collapsed.
......@@ -19,6 +19,9 @@ def native_sparse_attention(batch,
selected_blocks=16):
if scale is None:
scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e)
else:
scale = scale * 1.44269504 # log2(e)
head_kv = heads // groups
q_shape = [batch, seq_len, heads, dim]
kv_shape = [batch, seq_len, head_kv, dim]
......@@ -123,7 +126,7 @@ def native_sparse_attention(batch,
if __name__ == "__main__":
B, SEQ_LEN, H, HQ, D, S, block_size, dtype = 2, 64, 1, 16, 32, 1, 32, torch.float16
B, SEQ_LEN, H, HQ, D, S, block_size, dtype, scale = 2, 64, 1, 16, 32, 1, 32, torch.float16, 0.1
program = native_sparse_attention(
batch=B,
......@@ -134,6 +137,7 @@ if __name__ == "__main__":
block_size=block_size,
groups=HQ // H,
selected_blocks=S,
scale=scale,
)
kernel = tilelang.compile(program, out_idx=-1)
torch.random.manual_seed(0)
......@@ -163,7 +167,9 @@ if __name__ == "__main__":
g_swa=g_swa,
block_indices=block_indices,
block_counts=block_counts,
block_size=block_size)
block_size=block_size,
scale=scale,
)
print("out", out)
print("ref", ref)
......
This diff is collapsed.
......@@ -5,6 +5,7 @@
from typing import Union, List, Optional
from tvm import tir
from tvm.script import tir as T
import tvm.ir
def region(buffer: tir.BufferLoad, access_type: str, *args: tir.PrimExpr):
......@@ -33,6 +34,8 @@ def copy(
dst: Union[tir.Buffer, tir.BufferLoad],
coalesced_width: Optional[int] = None,
):
if isinstance(src, tir.Buffer) and isinstance(dst, tir.Buffer):
tvm.ir.assert_structural_equal(src.shape, dst.shape)
def get_extent(data):
if isinstance(data, tir.Buffer):
......@@ -44,8 +47,7 @@ def copy(
src_extent = get_extent(src)
dst_extent = get_extent(dst)
# if src_extent and dst_extent:
# ir.assert_structural_equal(src_extent, dst_extent)
if src_extent:
extent = src_extent
elif dst_extent:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment