Commit 6891d3ec authored by Lei Wang, committed by GitHub

[Examples] Implement NSA Backward kernels (#180)


* Update native sparse attention example with scale parameter handling

- Add scale parameter processing in native_sparse_attention function (see the sketch after this list)
- Modify example script to include custom scale value
- Update function calls to pass scale parameter
- Enhance flexibility of sparse attention implementation
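
The scale handling described above amounts to folding log2(e) into the softmax scale so the kernel can use exp2 rather than exp. A minimal sketch of that logic, assuming a hypothetical resolve_scale helper (not part of the diff; the constant and formula match it):

LOG2_E = 1.44269504  # log2(e): pre-multiplying lets the kernel evaluate 2^x instead of e^x

def resolve_scale(scale, dim):
    # Default to 1/sqrt(dim) when no scale is given; either way bake in log2(e).
    if scale is None:
        return (1.0 / dim) ** 0.5 * LOG2_E
    return scale * LOG2_E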

* Refactor Triton Native Sparse Attention Example

- Improve code formatting and readability in example_triton_nsa_bwd.py
- Standardize function and parameter alignment
- Remove unnecessary whitespace and optimize imports
- Enhance code style consistency with previous commits
parent c39e540a
@@ -19,6 +19,9 @@ def native_sparse_attention(batch,
                            selected_blocks=16):
     if scale is None:
         scale = (1.0 / dim)**0.5 * 1.44269504  # log2(e)
+    else:
+        scale = scale * 1.44269504  # log2(e)
     head_kv = heads // groups
     q_shape = [batch, seq_len, heads, dim]
     kv_shape = [batch, seq_len, head_kv, dim]
@@ -123,7 +126,7 @@ def native_sparse_attention(batch,
 if __name__ == "__main__":
-    B, SEQ_LEN, H, HQ, D, S, block_size, dtype = 2, 64, 1, 16, 32, 1, 32, torch.float16
+    B, SEQ_LEN, H, HQ, D, S, block_size, dtype, scale = 2, 64, 1, 16, 32, 1, 32, torch.float16, 0.1
     program = native_sparse_attention(
         batch=B,
@@ -134,6 +137,7 @@ if __name__ == "__main__":
         block_size=block_size,
         groups=HQ // H,
         selected_blocks=S,
+        scale=scale,
     )
     kernel = tilelang.compile(program, out_idx=-1)
     torch.random.manual_seed(0)
@@ -163,7 +167,9 @@ if __name__ == "__main__":
         g_swa=g_swa,
         block_indices=block_indices,
         block_counts=block_counts,
-        block_size=block_size)
+        block_size=block_size,
+        scale=scale,
+    )
     print("out", out)
     print("ref", ref)
...
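
To see why the scale is multiplied by log2(e): once the factor is folded in, softmax can be computed with exp2 and still match the plain-exp reference. A hedged PyTorch sketch, assuming small illustrative shapes and the 0.1 scale from the example's __main__ block (everything else here is made up for demonstration, not taken from the kernel):

import torch

LOG2_E = 1.44269504
q, k = torch.randn(8, 32), torch.randn(8, 32)
scale = 0.1

# Reference: ordinary softmax of scale * q @ k^T.
ref = torch.softmax((q @ k.T) * scale, dim=-1)

# exp2 formulation: fold log2(e) into the scale, then use 2^x in place of e^x.
s = (q @ k.T) * (scale * LOG2_E)
num = torch.exp2(s - s.max(dim=-1, keepdim=True).values)
out = num / num.sum(dim=-1, keepdim=True)

torch.testing.assert_close(out, ref, rtol=1e-4, atol=1e-5)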
@@ -5,6 +5,7 @@
 from typing import Union, List, Optional
 from tvm import tir
 from tvm.script import tir as T
+import tvm.ir


 def region(buffer: tir.BufferLoad, access_type: str, *args: tir.PrimExpr):
@@ -33,6 +34,8 @@ def copy(
     dst: Union[tir.Buffer, tir.BufferLoad],
     coalesced_width: Optional[int] = None,
 ):
+    if isinstance(src, tir.Buffer) and isinstance(dst, tir.Buffer):
+        tvm.ir.assert_structural_equal(src.shape, dst.shape)

     def get_extent(data):
         if isinstance(data, tir.Buffer):
@@ -44,8 +47,7 @@ def copy(
     src_extent = get_extent(src)
     dst_extent = get_extent(dst)

-    # if src_extent and dst_extent:
-    #     ir.assert_structural_equal(src_extent, dst_extent)
     if src_extent:
         extent = src_extent
     elif dst_extent:
...
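
The new guard in copy relies on tvm.ir.assert_structural_equal, which raises when the two shape arrays are not structurally identical. A small standalone sketch of that behaviour (the buffer names here are hypothetical, chosen only for illustration):

import tvm
from tvm import tir

a = tir.decl_buffer((128, 64), "float16", name="A")
b = tir.decl_buffer((128, 64), "float16", name="B")
c = tir.decl_buffer((128, 32), "float16", name="C")

# Matching shapes pass silently, so existing T.copy call sites are unaffected.
tvm.ir.assert_structural_equal(a.shape, b.shape)

# A mismatch raises ValueError up front, before the copy is lowered.
try:
    tvm.ir.assert_structural_equal(a.shape, c.shape)
except ValueError as err:
    print("shape mismatch caught:", err)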