"git@developer.sourcefind.cn:OpenDAS/nni.git" did not exist on "4dc9eb9385f1cb3df07f01b35c19a02963ddcfa8"
Unverified Commit 37051417 authored by Wenhao Xie's avatar Wenhao Xie Committed by GitHub
Browse files

[Example] Add vertical slash sparse attention pattern (#762)

* upd sparse attn

* lint

* rename

* update test file

* update benchmark

* lint

* update benchmark
parent 1774a1aa
# Performance Benchmark
## Hardware & Environment
- **Hardware**: NVIDIA H100 PCIe
- **CUDA version**: 12.8.1
- **PyTorch Version**: 2.7.1+cu128
- **Triton Version**: 3.3.1
## Performance Results
BATCH_SIZE=1, HEAD=1, DIM=64
| SEQ_LEN | VS_LIST | Triton Time | TileLang Time | Speedup |
|---------|--------------|-------------|---------------|---------|
| 8192 | [1000, 200] | 0.168 ms | 0.105 ms | 1.60x |
| 8192 | [1000, 600] | 0.207 ms | 0.119 ms | 1.74x |
| 8192 | [800, 600] | 0.207 ms | 0.122 ms | 1.70x |
| | | | | |
| 16384 | [1000, 200] | 0.261 ms | 0.167 ms | 1.56x |
| 16384 | [1000, 600] | 0.419 ms | 0.258 ms | 1.62x |
| 16384 | [800, 600] | 0.422 ms | 0.255 ms | 1.65x |
| | | | | |
| 32768 | [1000, 200] | 0.374 ms | 0.248 ms | 1.51x |
| 32768 | [1000, 600] | 0.823 ms | 0.554 ms | 1.49x |
| 32768 | [800, 600] | 0.826 ms | 0.558 ms | 1.48x |
| | | | | |
| 65536 | [1000, 200] | 0.637 ms | 0.524 ms | 1.22x |
| 65536 | [1000, 600] | 1.758 ms | 1.501 ms | 1.17x |
| 65536 | [800, 600] | 1.783 ms | 1.489 ms | 1.20x |
This diff is collapsed.
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
#include "torch/extension.h"
#include <vector>
std::vector<at::Tensor> convert_vertical_slash_indexes(
torch::Tensor seqlens, // [BATCH, ]
torch::Tensor vertical_indexes, // [BATCH, N_HEADS, NNZ_V]
torch::Tensor slash_indexes, // [BATCH, N_HEADS, NNZ_S]
int context_size, int block_size_M, int block_size_N);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("convert_vertical_slash_indexes", &convert_vertical_slash_indexes,
"dynamic sparse index function");
}
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
#include <assert.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <pybind11/numpy.h>
#include <torch/extension.h>
#include <cuda.h>
__device__ void save_blocks(int* block_offset, int range_start, int range_end, int block_size, int& block_count) {
for (int idx = range_start; idx < range_end; idx += block_size) {
block_offset[block_count++] = idx;
}
}
__global__ void convert_vertical_slash_indexes_kernel(
const int* seqlens, // [BATCH, ]
const int* vertical_indexes, // [BATCH, N_HEADS, NNZ_V]
const int* slash_indexes, // [BATCH, N_HEADS, NNZ_S]
int* block_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
int* block_offset, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_S]
int* column_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
int* column_index, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_V]
int N_HEADS,
int N_ROWS,
int BLOCK_SIZE_M,
int BLOCK_SIZE_N,
int NNZ_V,
int NNZ_S
) {
const int batch_idx = blockIdx.y;
const int head_idx = blockIdx.x;
const int group_idx = blockIdx.z;
int seqlen = seqlens[batch_idx];
int block_idx_m = group_idx * blockDim.x + threadIdx.x;
int start_m = block_idx_m * BLOCK_SIZE_M;
if (start_m >= seqlen) {
return;
}
int end_m = start_m + BLOCK_SIZE_M;
vertical_indexes += (batch_idx * N_HEADS + head_idx) * NNZ_V;
slash_indexes += (batch_idx * N_HEADS + head_idx) * NNZ_S;
int row_offset = (batch_idx * N_HEADS + head_idx) * N_ROWS + block_idx_m;
block_count += row_offset;
block_offset += row_offset * NNZ_S;
column_count += row_offset;
column_index += row_offset * NNZ_V;
int tmp_col_cnt = 0, tmp_blk_cnt = 0;
int s = 0, v = 0;
int v_idx = vertical_indexes[v++];
int s_idx = slash_indexes[s++];
while (s_idx >= end_m) {
s_idx = slash_indexes[s++];
}
s_idx = max(end_m - s_idx, BLOCK_SIZE_M);
int range_start = s_idx - BLOCK_SIZE_M, range_end = s_idx;
while (1) {
if (v_idx < range_end) {
if (v_idx < range_start) {
column_index[tmp_col_cnt++] = v_idx;
}
if (v < NNZ_V) {
v_idx = vertical_indexes[v++];
} else {
v_idx = end_m + BLOCK_SIZE_M;
}
} else {
if (s < NNZ_S) {
s_idx = max(end_m - slash_indexes[s++], BLOCK_SIZE_M);
} else {
save_blocks(block_offset, range_start, range_end, BLOCK_SIZE_N, tmp_blk_cnt);
break;
}
if (s_idx > range_end + BLOCK_SIZE_M) {
save_blocks(block_offset, range_start, range_end, BLOCK_SIZE_N, tmp_blk_cnt);
range_start = s_idx - BLOCK_SIZE_M;
range_end = s_idx;
} else if (s_idx > range_end) {
range_end += BLOCK_SIZE_M;
}
}
}
block_count[0] = tmp_blk_cnt;
column_count[0] = tmp_col_cnt;
}
void convert_vertical_slash_indexes_64x64(
const int* seqlens, // [BATCH, ]
const int* vertical_indexes, // [BATCH, N_HEADS, NNZ_V]
const int* slash_indexes, // [BATCH, N_HEADS, NNZ_S]
int* block_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
int* block_offset, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_S]
int* column_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
int* column_index, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_V]
int BATCH_SIZE,
int N_HEADS,
int N_ROWS,
int NNZ_V,
int NNZ_S
) {
const int BLOCK_SIZE_M = 64;
const int BLOCK_SIZE_N = 64;
const int N_THREADS = 64;
const dim3 dimBlock(N_THREADS);
const dim3 dimGrid(N_HEADS, BATCH_SIZE, (N_ROWS + N_THREADS - 1) / N_THREADS);
convert_vertical_slash_indexes_kernel<<<dimGrid, dimBlock>>>(
seqlens, vertical_indexes, slash_indexes,
block_count, block_offset, column_count, column_index,
N_HEADS, N_ROWS, BLOCK_SIZE_M, BLOCK_SIZE_N, NNZ_V, NNZ_S
);
}
std::vector<at::Tensor> convert_vertical_slash_indexes(
torch::Tensor seqlens, // [BATCH, ]
torch::Tensor vertical_indexes, // [BATCH, N_HEADS, NNZ_V]
torch::Tensor slash_indexes, // [BATCH, N_HEADS, NNZ_S]
int context_size,
int block_size_M,
int block_size_N
) {
assert(block_size_M == 64);
assert(block_size_N == 64);
cudaSetDevice(seqlens.get_device());
int batch_size = slash_indexes.size(0);
int num_heads = slash_indexes.size(1);
int nnz_slash = slash_indexes.size(2);
int nnz_vertical = vertical_indexes.size(2);
int num_rows = (context_size + block_size_M - 1) / block_size_M;
torch::Tensor block_count = torch::zeros({batch_size, num_heads, num_rows}, seqlens.options());
torch::Tensor block_offset = torch::zeros({batch_size, num_heads, num_rows, nnz_slash}, seqlens.options());
torch::Tensor column_count = torch::zeros({batch_size, num_heads, num_rows}, seqlens.options());
torch::Tensor column_index = torch::zeros({batch_size, num_heads, num_rows, nnz_vertical}, seqlens.options());
convert_vertical_slash_indexes_64x64(
seqlens.data_ptr<int>(),
vertical_indexes.data_ptr<int>(),
slash_indexes.data_ptr<int>(),
block_count.data_ptr<int>(),
block_offset.data_ptr<int>(),
column_count.data_ptr<int>(),
column_index.data_ptr<int>(),
batch_size,
num_heads,
num_rows,
nnz_vertical,
nnz_slash
);
return { block_count, block_offset, column_count, column_index };
}
import tilelang.testing
import example_vertical_slash_sparse_attn
@tilelang.testing.requires_cuda
def test_vs_sparse_attn():
example_vertical_slash_sparse_attn.main()
if __name__ == "__main__":
tilelang.testing.main()
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment