Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
tilelang
Commits
29051439
Unverified
Commit
29051439
authored
Dec 12, 2025
by
Lei Wang
Committed by
GitHub
Dec 12, 2025
Browse files
[Lint] Phaseout Yapf format and embrace ruff format (#1417)
parent
e84b24bc
Changes
467
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
627 additions
and
647 deletions
+627
-647
examples/hadamard_transform/example_hadamard.py
examples/hadamard_transform/example_hadamard.py
+15
-20
examples/lazy_jit/lazyjit.en.ipynb
examples/lazy_jit/lazyjit.en.ipynb
+41
-31
examples/lazy_jit/lazyjit.zh.ipynb
examples/lazy_jit/lazyjit.zh.ipynb
+41
-31
examples/linear_attention/example_linear_attn_bwd.py
examples/linear_attention/example_linear_attn_bwd.py
+55
-74
examples/linear_attention/example_linear_attn_fwd.py
examples/linear_attention/example_linear_attn_fwd.py
+40
-46
examples/linear_attention/example_mamba_chunk_scan.py
examples/linear_attention/example_mamba_chunk_scan.py
+110
-79
examples/linear_attention/example_mamba_chunk_state.py
examples/linear_attention/example_mamba_chunk_state.py
+57
-62
examples/linear_attention/example_retention_fwd.py
examples/linear_attention/example_retention_fwd.py
+22
-27
examples/minference/example_vertical_slash_sparse_attn.py
examples/minference/example_vertical_slash_sparse_attn.py
+123
-100
examples/norm/rms_norm.py
examples/norm/rms_norm.py
+2
-2
examples/norm/test_rms_norm.py
examples/norm/test_rms_norm.py
+2
-2
examples/online_softmax/online_softmax.py
examples/online_softmax/online_softmax.py
+6
-6
examples/plot_layout/fragment_mfma_load_a.py
examples/plot_layout/fragment_mfma_load_a.py
+7
-13
examples/plot_layout/fragment_mma_load_a.py
examples/plot_layout/fragment_mma_load_a.py
+4
-7
examples/quickstart.py
examples/quickstart.py
+3
-4
examples/seer_attention/block_sparse_attn_tilelang.py
examples/seer_attention/block_sparse_attn_tilelang.py
+48
-64
examples/seer_attention/block_sparse_attn_triton.py
examples/seer_attention/block_sparse_attn_triton.py
+24
-46
examples/sparse_tensorcore/tilelang_example_sparse_tensorcore.py
...s/sparse_tensorcore/tilelang_example_sparse_tensorcore.py
+14
-16
examples/topk/example_topk.py
examples/topk/example_topk.py
+5
-9
examples/visual_layout_inference/visual_layout_inference.py
examples/visual_layout_inference/visual_layout_inference.py
+8
-8
No files found.
examples/hadamard_transform/example_hadamard.py
View file @
29051439
...
...
@@ -17,7 +17,7 @@ def is_pow_of_2(n):
def
hadamard
(
b
,
n
,
dtype
):
assert
is_pow_of_2
(
n
),
"n must be a power of 2"
assert
2
<=
n
<=
32768
,
"n must be in [2, 32768]"
elem_size
=
{
'
float32
'
:
4
,
'
float16
'
:
2
,
'
bfloat16
'
:
2
}[
dtype
]
elem_size
=
{
"
float32
"
:
4
,
"
float16
"
:
2
,
"
bfloat16
"
:
2
}[
dtype
]
logN
=
int
(
math
.
log2
(
n
))
threads
=
[
0
,
1
,
1
,
1
,
2
,
4
,
8
,
16
,
32
,
32
,
128
,
256
,
256
,
256
,
256
,
256
][
logN
]
...
...
@@ -40,23 +40,21 @@ def hadamard(b, n, dtype):
# print(f'{exchange_round=}')
@
T
.
macro
def
warp_shfl
(
local
:
T
.
Tensor
((
thread_elem
,),
dtype
),
buf
:
T
.
Tensor
((
thread_elem
,),
dtype
),
round
:
int
):
def
warp_shfl
(
local
:
T
.
Tensor
((
thread_elem
,),
dtype
),
buf
:
T
.
Tensor
((
thread_elem
,),
dtype
),
round
:
int
):
tx
=
T
.
get_thread_binding
(
0
)
for
i
in
T
.
serial
(
round
):
tx_stride
=
1
<<
i
another_tx
=
tx
^
tx_stride
sign
=
(
tx
>>
i
)
&
1
# get i-th lowest bit of tx, which determines the operation type for shared[tx, :]
sign
=
(
tx
>>
i
)
&
1
# get i-th lowest bit of tx, which determines the operation type for shared[tx, :]
for
j
in
T
.
Pipelined
(
thread_elem
,
num_stages
=
1
):
buf
[
j
]
=
T
.
tvm_warp_shuffle
(
0x
ffffffff
,
# mask of all threads
0x
FFFFFFFF
,
# mask of all threads
local
[
j
],
another_tx
%
warp_size
,
warp_size
,
warp_size
)
warp_size
,
)
local
[
j
]
=
T
.
if_then_else
(
sign
==
0
,
local
[
j
]
+
buf
[
j
],
buf
[
j
]
-
local
[
j
])
@
T
.
prim_func
...
...
@@ -78,10 +76,8 @@ def hadamard(b, n, dtype):
for
j
in
T
.
serial
(
chunknum
):
chunkbase
=
j
*
chunksize
for
k
in
T
.
serial
(
chunksize
//
2
):
local
[
chunkbase
+
k
]
=
local
[
chunkbase
+
k
]
+
local
[
chunkbase
+
k
+
chunksize
//
2
]
local
[
chunkbase
+
k
+
chunksize
//
2
]
=
local
[
chunkbase
+
k
]
-
2
*
local
[
chunkbase
+
k
+
chunksize
//
2
]
local
[
chunkbase
+
k
]
=
local
[
chunkbase
+
k
]
+
local
[
chunkbase
+
k
+
chunksize
//
2
]
local
[
chunkbase
+
k
+
chunksize
//
2
]
=
local
[
chunkbase
+
k
]
-
2
*
local
[
chunkbase
+
k
+
chunksize
//
2
]
# 3. Hadamard inside warp, n<=512
# In warp level, we rely on warp shuffle to exchange data inside each warp, without using shared memory
...
...
@@ -131,28 +127,27 @@ def ref_program(x: torch.Tensor):
assert
x
.
ndim
==
2
dim
=
x
.
shape
[
-
1
]
assert
is_pow_of_2
(
dim
)
return
F
.
linear
(
x
,
torch
.
tensor
(
scipy
.
linalg
.
hadamard
(
dim
,
dtype
=
float
),
dtype
=
x
.
dtype
,
device
=
x
.
device
))
return
F
.
linear
(
x
,
torch
.
tensor
(
scipy
.
linalg
.
hadamard
(
dim
,
dtype
=
float
),
dtype
=
x
.
dtype
,
device
=
x
.
device
))
def
main
():
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
'
--batch
'
,
type
=
int
,
default
=
64
,
help
=
'
Batch size
'
)
parser
.
add_argument
(
'
--dim
'
,
type
=
int
,
default
=
32768
,
help
=
'
Dimension
'
)
parser
.
add_argument
(
"
--batch
"
,
type
=
int
,
default
=
64
,
help
=
"
Batch size
"
)
parser
.
add_argument
(
"
--dim
"
,
type
=
int
,
default
=
32768
,
help
=
"
Dimension
"
)
args
=
parser
.
parse_args
()
B
,
D
=
args
.
batch
,
args
.
dim
x
=
torch
.
randn
((
B
,
D
),
device
=
'
cuda
'
)
kernel
=
hadamard
(
B
,
D
,
'
float32
'
)
x
=
torch
.
randn
((
B
,
D
),
device
=
"
cuda
"
)
kernel
=
hadamard
(
B
,
D
,
"
float32
"
)
y
=
kernel
(
x
)
y_ref
=
ref_program
(
x
)
torch
.
testing
.
assert_close
(
y
,
y_ref
,
atol
=
1e-2
,
rtol
=
1e-2
)
print
(
'
All tests passed.
'
)
print
(
"
All tests passed.
"
)
profiler
=
kernel
.
get_profiler
(
tensor_supply_type
=
tilelang
.
TensorSupplyType
.
Auto
)
latency
=
profiler
.
do_bench
(
warmup
=
100
)
print
(
"Tile-lang: {:.2f} ms"
.
format
(
latency
))
if
__name__
==
'
__main__
'
:
if
__name__
==
"
__main__
"
:
main
()
examples/lazy_jit/lazyjit.en.ipynb
View file @
29051439
...
...
@@ -9,6 +9,7 @@
"source": [
"import sys\n",
"from pathlib import Path\n",
"\n",
"sys.path.insert(0, str(Path.cwd().parent.parent.absolute()))\n",
"import tilelang\n",
"import torch\n",
...
...
@@ -61,7 +62,7 @@
" out_dtype: T.dtype = T.float32,\n",
" block_M: int = 128,\n",
" block_N: int = 128,\n",
" block_K: int = 32\n",
" block_K: int = 32
,
\n",
"):\n",
" M, K = A.shape\n",
" K, N = B.shape\n",
...
...
@@ -94,8 +95,8 @@
"metadata": {},
"outputs": [],
"source": [
"A = torch.randn(1024, 512, dtype=torch.float16, device=
'
cuda
'
)\n",
"B = torch.randn(512, 256, dtype=torch.float16, device=
'
cuda
'
)\n",
"A = torch.randn(1024, 512, dtype=torch.float16, device=
\"
cuda
\"
)\n",
"B = torch.randn(512, 256, dtype=torch.float16, device=
\"
cuda
\"
)\n",
"C = gemm(A, B)\n",
"\n",
"# check output is correct\n",
...
...
@@ -118,8 +119,8 @@
"metadata": {},
"outputs": [],
"source": [
"A = torch.randn(1024, 512, dtype=torch.float16, device=
'
cuda
'
)\n",
"B = torch.randn(512, 1024, dtype=torch.float16, device=
'
cuda
'
)\n",
"A = torch.randn(1024, 512, dtype=torch.float16, device=
\"
cuda
\"
)\n",
"B = torch.randn(512, 1024, dtype=torch.float16, device=
\"
cuda
\"
)\n",
"C = gemm(A, B, block_M=64, block_N=64)"
]
},
...
...
@@ -218,8 +219,8 @@
"source": [
"@tilelang.lazy_jit\n",
"def gemm_dyn_K(\n",
" A: T.Tensor[[int, T.dyn[
'K'
]], T.float16], # noqa: F821\n",
" B: T.Tensor[[T.dyn[
'K'
], int], T.float16], # noqa: F821\n",
" A: T.Tensor[[int, T.dyn[
\"K\"
]], T.float16],
# noqa: F821\n",
" B: T.Tensor[[T.dyn[
\"K\"
], int], T.float16],
# noqa: F821\n",
"):\n",
" M, K = A.shape\n",
" K, N = B.shape\n",
...
...
@@ -265,8 +266,8 @@
"metadata": {},
"outputs": [],
"source": [
"A = torch.randn(1024, 512, dtype=torch.float16, device=
'
cuda
'
)\n",
"B = torch.randn(512, 256, dtype=torch.float16, device=
'
cuda
'
)\n",
"A = torch.randn(1024, 512, dtype=torch.float16, device=
\"
cuda
\"
)\n",
"B = torch.randn(512, 256, dtype=torch.float16, device=
\"
cuda
\"
)\n",
"C = gemm_dyn_K(A, B)\n",
"C_ref = (A @ B).float()\n",
"torch.testing.assert_close(C, C_ref, rtol=1e-2, atol=1e-2)"
...
...
@@ -295,18 +296,17 @@
"source": [
"from typing import Any\n",
"\n",
"\n",
"@tilelang.lazy_jit\n",
"def as_contingious(\n",
" A: T.StridedTensor[[T.dyn, T.dyn], [T.dyn, T.dyn], Any]\n",
"):\n",
"def as_contingious(A: T.StridedTensor[[T.dyn, T.dyn], [T.dyn, T.dyn], Any]):\n",
" M, N = A.shape\n",
" B = T.empty((M, N), A.dtype)\n",
" block_M = 128\n",
" block_N = 128\n",
" with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), threads=128) as (bx, by):\n",
" T.copy(\n",
" A[bx * block_M: (bx + 1) * block_M, by * block_N: (by + 1) * block_N],\n",
" B[bx * block_M: (bx + 1) * block_M, by * block_N: (by + 1) * block_N]\n",
" A[bx * block_M
: (bx + 1) * block_M, by * block_N
: (by + 1) * block_N],\n",
" B[bx * block_M
: (bx + 1) * block_M, by * block_N
: (by + 1) * block_N]
,
\n",
" )\n",
" return B"
]
...
...
@@ -318,7 +318,7 @@
"metadata": {},
"outputs": [],
"source": [
"A = torch.randn(1024, 1024, device=
'
cuda
'
)\n",
"A = torch.randn(1024, 1024, device=
\"
cuda
\"
)\n",
"B = as_contingious(A[::2, ::2])\n",
"B_ref = A[::2, ::2].contiguous()\n",
"torch.testing.assert_close(B, B_ref)"
...
...
@@ -370,8 +370,8 @@
"metadata": {},
"outputs": [],
"source": [
"A = torch.randn(1024, 512, dtype=torch.float16, device=
'
cuda
'
)\n",
"B = torch.randn(512, 256, dtype=torch.float16, device=
'
cuda
'
)\n",
"A = torch.randn(1024, 512, dtype=torch.float16, device=
\"
cuda
\"
)\n",
"B = torch.randn(512, 256, dtype=torch.float16, device=
\"
cuda
\"
)\n",
"C = gemm_ptr(A, B, 1024, 256, 512)\n",
"C_ref = (A @ B).float()\n",
"torch.testing.assert_close(C, C_ref, rtol=1e-2, atol=1e-2)"
...
...
@@ -416,8 +416,8 @@
"metadata": {},
"outputs": [],
"source": [
"A = torch.randn(1024, 512, dtype=torch.float16, device=
'
cuda
'
)\n",
"B = torch.randn(512, 256, dtype=torch.float16, device=
'
cuda
'
)\n",
"A = torch.randn(1024, 512, dtype=torch.float16, device=
\"
cuda
\"
)\n",
"B = torch.randn(512, 256, dtype=torch.float16, device=
\"
cuda
\"
)\n",
"C = gemm_ptr_dyn(A, B, 1024, 256, 512)\n",
"C_ref = (A @ B).float()\n",
"torch.testing.assert_close(C, C_ref, rtol=1e-2, atol=1e-2)"
...
...
@@ -496,18 +496,20 @@
"source": [
"from itertools import product\n",
"\n",
"\n",
"def get_configs():\n",
" return [\n",
" {\n",
"
'A'
: T.Tensor((1024, 1024), T.float32),\n",
"
'B'
: T.Tensor((1024, 1024), T.float32),\n",
"
'
block_M
'
: block_M,\n",
"
'
block_N
'
: block_N,\n",
"
'
block_K
'
: block_K,\n",
"
\"A\"
: T.Tensor((1024, 1024), T.float32),\n",
"
\"B\"
: T.Tensor((1024, 1024), T.float32),\n",
"
\"
block_M
\"
: block_M,\n",
"
\"
block_N
\"
: block_N,\n",
"
\"
block_K
\"
: block_K,\n",
" }\n",
" for block_M, block_N, block_K in product([32, 64], repeat=3)\n",
" ]\n",
"\n",
"\n",
"gemm.par_compile(get_configs())"
]
},
...
...
@@ -581,6 +583,7 @@
"def macro_with_ref(x: T.Ref):\n",
" x = 1 # noqa: F841\n",
"\n",
"\n",
"@T.prim_func\n",
"def foo(x: T.Tensor((2,))):\n",
" with T.Kernel(1) as _:\n",
...
...
@@ -591,6 +594,7 @@
" idx = T.alloc_var(T.int32, 0)\n",
" macro_with_ref(x[idx])\n",
"\n",
"\n",
"foo"
]
},
...
...
@@ -616,7 +620,7 @@
" A: T.Tensor[[T.dyn], Any],\n",
" fn,\n",
"):\n",
" N, = A.shape\n",
"
(
N,
)
= A.shape\n",
" B = T.empty((N,), dtype=A.dtype)\n",
" block_N = 128\n",
" with T.Kernel(T.ceildiv(N, block_N), threads=128) as bx:\n",
...
...
@@ -624,6 +628,8 @@
" idx = bx * block_N + i\n",
" B[idx] = fn(A[idx])\n",
" return B\n",
"\n",
"\n",
"@T.macro\n",
"def add_one(x):\n",
" return x + 1"
...
...
@@ -636,7 +642,7 @@
"metadata": {},
"outputs": [],
"source": [
"A = torch.randn(1024, device=
'
cuda
'
)\n",
"A = torch.randn(1024, device=
\"
cuda
\"
)\n",
"B = element_wise(A, add_one)\n",
"B_ref = A + 1\n",
"torch.testing.assert_close(B, B_ref)"
...
...
@@ -670,10 +676,11 @@
" var = var * 3 + 1\n",
" n31(x * 3 + 1, var)\n",
"\n",
"\n",
"@tilelang.lazy_jit\n",
"def foo(A: T.Tensor[[1], T.int32], n: int):\n",
" with T.Kernel(1) as _:\n",
" n31(n, A[0])
\n
"
" n31(n, A[0])"
]
},
{
...
...
@@ -694,7 +701,7 @@
}
],
"source": [
"A = torch.tensor([100], dtype=torch.int32, device=
'
cuda
'
)\n",
"A = torch.tensor([100], dtype=torch.int32, device=
\"
cuda
\"
)\n",
"foo(A, 5)\n",
"A"
]
...
...
@@ -745,12 +752,15 @@
"def sincos(x):\n",
" return T.sin(x), T.cos(x)\n",
"\n",
"\n",
"@T.prim_func\n",
"def foo():\n",
" with T.Kernel(32) as x:\n",
" s, c = sincos(x)\n",
" a = s + c # noqa: F841\n",
" b = s - c # noqa: F841\n",
"\n",
"\n",
"foo"
]
}
...
...
examples/lazy_jit/lazyjit.zh.ipynb
View file @
29051439
...
...
@@ -9,6 +9,7 @@
"source": [
"import sys\n",
"from pathlib import Path\n",
"\n",
"sys.path.insert(0, str(Path.cwd().parent.parent.absolute()))\n",
"import tilelang\n",
"import torch\n",
...
...
@@ -61,7 +62,7 @@
" out_dtype: T.dtype = T.float32,\n",
" block_M: int = 128,\n",
" block_N: int = 128,\n",
" block_K: int = 32\n",
" block_K: int = 32
,
\n",
"):\n",
" M, K = A.shape\n",
" K, N = B.shape\n",
...
...
@@ -94,8 +95,8 @@
"metadata": {},
"outputs": [],
"source": [
"A = torch.randn(1024, 512, dtype=torch.float16, device=
'
cuda
'
)\n",
"B = torch.randn(512, 256, dtype=torch.float16, device=
'
cuda
'
)\n",
"A = torch.randn(1024, 512, dtype=torch.float16, device=
\"
cuda
\"
)\n",
"B = torch.randn(512, 256, dtype=torch.float16, device=
\"
cuda
\"
)\n",
"C = gemm(A, B)\n",
"\n",
"# check output is correct\n",
...
...
@@ -118,8 +119,8 @@
"metadata": {},
"outputs": [],
"source": [
"A = torch.randn(1024, 512, dtype=torch.float16, device=
'
cuda
'
)\n",
"B = torch.randn(512, 1024, dtype=torch.float16, device=
'
cuda
'
)\n",
"A = torch.randn(1024, 512, dtype=torch.float16, device=
\"
cuda
\"
)\n",
"B = torch.randn(512, 1024, dtype=torch.float16, device=
\"
cuda
\"
)\n",
"C = gemm(A, B, block_M=64, block_N=64)"
]
},
...
...
@@ -218,8 +219,8 @@
"source": [
"@tilelang.lazy_jit\n",
"def gemm_dyn_K(\n",
" A: T.Tensor[[int, T.dyn[
'K'
]], T.float16], # noqa: F821\n",
" B: T.Tensor[[T.dyn[
'K'
], int], T.float16], # noqa: F821\n",
" A: T.Tensor[[int, T.dyn[
\"K\"
]], T.float16],
# noqa: F821\n",
" B: T.Tensor[[T.dyn[
\"K\"
], int], T.float16],
# noqa: F821\n",
"):\n",
" M, K = A.shape\n",
" K, N = B.shape\n",
...
...
@@ -265,8 +266,8 @@
"metadata": {},
"outputs": [],
"source": [
"A = torch.randn(1024, 512, dtype=torch.float16, device=
'
cuda
'
)\n",
"B = torch.randn(512, 256, dtype=torch.float16, device=
'
cuda
'
)\n",
"A = torch.randn(1024, 512, dtype=torch.float16, device=
\"
cuda
\"
)\n",
"B = torch.randn(512, 256, dtype=torch.float16, device=
\"
cuda
\"
)\n",
"C = gemm_dyn_K(A, B)\n",
"C_ref = (A @ B).float()\n",
"torch.testing.assert_close(C, C_ref, rtol=1e-2, atol=1e-2)"
...
...
@@ -295,18 +296,17 @@
"source": [
"from typing import Any\n",
"\n",
"\n",
"@tilelang.lazy_jit\n",
"def as_contingious(\n",
" A: T.StridedTensor[[T.dyn, T.dyn], [T.dyn, T.dyn], Any]\n",
"):\n",
"def as_contingious(A: T.StridedTensor[[T.dyn, T.dyn], [T.dyn, T.dyn], Any]):\n",
" M, N = A.shape\n",
" B = T.empty((M, N), A.dtype)\n",
" block_M = 128\n",
" block_N = 128\n",
" with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), threads=128) as (bx, by):\n",
" T.copy(\n",
" A[bx * block_M: (bx + 1) * block_M, by * block_N: (by + 1) * block_N],\n",
" B[bx * block_M: (bx + 1) * block_M, by * block_N: (by + 1) * block_N]\n",
" A[bx * block_M
: (bx + 1) * block_M, by * block_N
: (by + 1) * block_N],\n",
" B[bx * block_M
: (bx + 1) * block_M, by * block_N
: (by + 1) * block_N]
,
\n",
" )\n",
" return B"
]
...
...
@@ -318,7 +318,7 @@
"metadata": {},
"outputs": [],
"source": [
"A = torch.randn(1024, 1024, device=
'
cuda
'
)\n",
"A = torch.randn(1024, 1024, device=
\"
cuda
\"
)\n",
"B = as_contingious(A[::2, ::2])\n",
"B_ref = A[::2, ::2].contiguous()\n",
"torch.testing.assert_close(B, B_ref)"
...
...
@@ -370,8 +370,8 @@
"metadata": {},
"outputs": [],
"source": [
"A = torch.randn(1024, 512, dtype=torch.float16, device=
'
cuda
'
)\n",
"B = torch.randn(512, 256, dtype=torch.float16, device=
'
cuda
'
)\n",
"A = torch.randn(1024, 512, dtype=torch.float16, device=
\"
cuda
\"
)\n",
"B = torch.randn(512, 256, dtype=torch.float16, device=
\"
cuda
\"
)\n",
"C = gemm_ptr(A, B, 1024, 256, 512)\n",
"C_ref = (A @ B).float()\n",
"torch.testing.assert_close(C, C_ref, rtol=1e-2, atol=1e-2)"
...
...
@@ -416,8 +416,8 @@
"metadata": {},
"outputs": [],
"source": [
"A = torch.randn(1024, 512, dtype=torch.float16, device=
'
cuda
'
)\n",
"B = torch.randn(512, 256, dtype=torch.float16, device=
'
cuda
'
)\n",
"A = torch.randn(1024, 512, dtype=torch.float16, device=
\"
cuda
\"
)\n",
"B = torch.randn(512, 256, dtype=torch.float16, device=
\"
cuda
\"
)\n",
"C = gemm_ptr_dyn(A, B, 1024, 256, 512)\n",
"C_ref = (A @ B).float()\n",
"torch.testing.assert_close(C, C_ref, rtol=1e-2, atol=1e-2)"
...
...
@@ -496,18 +496,20 @@
"source": [
"from itertools import product\n",
"\n",
"\n",
"def get_configs():\n",
" return [\n",
" {\n",
"
'A'
: T.Tensor((1024, 1024), T.float32),\n",
"
'B'
: T.Tensor((1024, 1024), T.float32),\n",
"
'
block_M
'
: block_M,\n",
"
'
block_N
'
: block_N,\n",
"
'
block_K
'
: block_K,\n",
"
\"A\"
: T.Tensor((1024, 1024), T.float32),\n",
"
\"B\"
: T.Tensor((1024, 1024), T.float32),\n",
"
\"
block_M
\"
: block_M,\n",
"
\"
block_N
\"
: block_N,\n",
"
\"
block_K
\"
: block_K,\n",
" }\n",
" for block_M, block_N, block_K in product([32, 64], repeat=3)\n",
" ]\n",
"\n",
"\n",
"gemm.par_compile(get_configs())"
]
},
...
...
@@ -581,6 +583,7 @@
"def macro_with_ref(x: T.Ref):\n",
" x = 1 # noqa: F841\n",
"\n",
"\n",
"@T.prim_func\n",
"def foo(x: T.Tensor((2,))):\n",
" with T.Kernel(1) as _:\n",
...
...
@@ -591,6 +594,7 @@
" idx = T.alloc_var(T.int32, 0)\n",
" macro_with_ref(x[idx])\n",
"\n",
"\n",
"foo"
]
},
...
...
@@ -616,7 +620,7 @@
" A: T.Tensor[[T.dyn], Any],\n",
" fn,\n",
"):\n",
" N, = A.shape\n",
"
(
N,
)
= A.shape\n",
" B = T.empty((N,), dtype=A.dtype)\n",
" block_N = 128\n",
" with T.Kernel(T.ceildiv(N, block_N), threads=128) as bx:\n",
...
...
@@ -624,6 +628,8 @@
" idx = bx * block_N + i\n",
" B[idx] = fn(A[idx])\n",
" return B\n",
"\n",
"\n",
"@T.macro\n",
"def add_one(x):\n",
" return x + 1"
...
...
@@ -636,7 +642,7 @@
"metadata": {},
"outputs": [],
"source": [
"A = torch.randn(1024, device=
'
cuda
'
)\n",
"A = torch.randn(1024, device=
\"
cuda
\"
)\n",
"B = element_wise(A, add_one)\n",
"B_ref = A + 1\n",
"torch.testing.assert_close(B, B_ref)"
...
...
@@ -670,10 +676,11 @@
" var = var * 3 + 1\n",
" n31(x * 3 + 1, var)\n",
"\n",
"\n",
"@tilelang.lazy_jit\n",
"def foo(A: T.Tensor[[1], T.int32], n: int):\n",
" with T.Kernel(1) as _:\n",
" n31(n, A[0])
\n
"
" n31(n, A[0])"
]
},
{
...
...
@@ -694,7 +701,7 @@
}
],
"source": [
"A = torch.tensor([100], dtype=torch.int32, device=
'
cuda
'
)\n",
"A = torch.tensor([100], dtype=torch.int32, device=
\"
cuda
\"
)\n",
"foo(A, 5)\n",
"A"
]
...
...
@@ -745,12 +752,15 @@
"def sincos(x):\n",
" return T.sin(x), T.cos(x)\n",
"\n",
"\n",
"@T.prim_func\n",
"def foo():\n",
" with T.Kernel(32) as x:\n",
" s, c = sincos(x)\n",
" a = s + c # noqa: F841\n",
" b = s - c # noqa: F841\n",
"\n",
"\n",
"foo"
]
}
...
...
examples/linear_attention/example_linear_attn_bwd.py
View file @
29051439
...
...
@@ -13,20 +13,20 @@ from typing import Optional, Tuple
pass_configs
=
{
tilelang
.
PassConfigKey
.
TL_DISABLE_TMA_LOWER
:
True
,
tilelang
.
PassConfigKey
.
TL_DISABLE_WARP_SPECIALIZED
:
True
,
})
}
)
def
tl_fused_chunk_bwd_kernel
(
B
,
S
,
H
,
DK
,
DV
,
dtype
:
str
=
'
float16
'
,
dtype
:
str
=
"
float16
"
,
scale
:
float
=
None
,
)
->
torch
.
Tensor
:
if
scale
is
None
:
scale
=
DK
**-
0.5
accum_dtype
=
'
float
'
accum_dtype
=
"
float
"
chunk_size
=
64
BK
=
BV
=
64
# Set to 128 can be faster, but has some numerical differences with FLA
...
...
@@ -66,11 +66,13 @@ def tl_fused_chunk_bwd_kernel(
dh
=
T
.
alloc_fragment
([
BK
,
BV
],
accum_dtype
)
dh_shared
=
T
.
alloc_shared
([
BK
,
BV
],
dtype
)
T
.
annotate_layout
({
T
.
annotate_layout
(
{
dq_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
dq_shared
),
dk_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
dk_shared
),
dv_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
dv_shared
)
})
dv_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
dv_shared
),
}
)
T
.
use_swizzle
(
10
)
T
.
clear
(
h
)
...
...
@@ -78,10 +80,9 @@ def tl_fused_chunk_bwd_kernel(
# Calculate dQ
for
i
in
T
.
Pipelined
(
0
,
NT
):
T
.
copy
(
K
[
i_b
,
i
*
chunk_size
:(
i
+
1
)
*
chunk_size
,
i_h
,
i_k
*
BK
:(
i_k
+
1
)
*
BK
],
k
)
T
.
copy
(
V
[
i_b
,
i
*
chunk_size
:(
i
+
1
)
*
chunk_size
,
i_h
,
i_v
*
BV
:(
i_v
+
1
)
*
BV
],
v
)
T
.
copy
(
dO
[
i_b
,
i
*
chunk_size
:(
i
+
1
)
*
chunk_size
,
i_h
,
i_v
*
BV
:(
i_v
+
1
)
*
BV
],
do
)
T
.
copy
(
K
[
i_b
,
i
*
chunk_size
:
(
i
+
1
)
*
chunk_size
,
i_h
,
i_k
*
BK
:
(
i_k
+
1
)
*
BK
],
k
)
T
.
copy
(
V
[
i_b
,
i
*
chunk_size
:
(
i
+
1
)
*
chunk_size
,
i_h
,
i_v
*
BV
:
(
i_v
+
1
)
*
BV
],
v
)
T
.
copy
(
dO
[
i_b
,
i
*
chunk_size
:
(
i
+
1
)
*
chunk_size
,
i_h
,
i_v
*
BV
:
(
i_v
+
1
)
*
BV
],
do
)
T
.
gemm
(
do
,
v
,
ds
,
transpose_B
=
True
,
clear_accum
=
True
)
for
row
,
col
in
T
.
Parallel
(
chunk_size
,
chunk_size
):
...
...
@@ -94,29 +95,19 @@ def tl_fused_chunk_bwd_kernel(
for
row
,
col
in
T
.
Parallel
(
chunk_size
,
BK
):
dq
[
row
,
col
]
*=
scale
T
.
copy
(
dq
,
dq_shared
)
T
.
atomic_add
(
dQ
[
i_b
,
i
*
chunk_size
:(
i
+
1
)
*
chunk_size
,
i_h
,
i_k
*
BK
:(
i_k
+
1
)
*
BK
],
dq_shared
)
T
.
atomic_add
(
dQ
[
i_b
,
i
*
chunk_size
:
(
i
+
1
)
*
chunk_size
,
i_h
,
i_k
*
BK
:
(
i_k
+
1
)
*
BK
],
dq_shared
)
# Calculate dK, dV (reversely)
for
i
in
T
.
Pipelined
(
1
,
NT
+
1
):
start
=
NT
-
i
for
row
,
col
in
T
.
Parallel
(
chunk_size
,
BK
):
q
[
row
,
col
]
=
Q
[
i_b
,
start
*
chunk_size
+
row
,
i_h
,
i_k
*
BK
+
col
]
*
scale
T
.
copy
(
K
[
i_b
,
start
*
chunk_size
:(
start
+
1
)
*
chunk_size
,
i_h
,
i_k
*
BK
:(
i_k
+
1
)
*
BK
],
k
)
T
.
copy
(
V
[
i_b
,
start
*
chunk_size
:(
start
+
1
)
*
chunk_size
,
i_h
,
i_v
*
BV
:(
i_v
+
1
)
*
BV
],
v
)
T
.
copy
(
dO
[
i_b
,
start
*
chunk_size
:(
start
+
1
)
*
chunk_size
,
i_h
,
i_v
*
BV
:(
i_v
+
1
)
*
BV
],
do
)
T
.
copy
(
K
[
i_b
,
start
*
chunk_size
:
(
start
+
1
)
*
chunk_size
,
i_h
,
i_k
*
BK
:
(
i_k
+
1
)
*
BK
],
k
)
T
.
copy
(
V
[
i_b
,
start
*
chunk_size
:
(
start
+
1
)
*
chunk_size
,
i_h
,
i_v
*
BV
:
(
i_v
+
1
)
*
BV
],
v
)
T
.
copy
(
dO
[
i_b
,
start
*
chunk_size
:
(
start
+
1
)
*
chunk_size
,
i_h
,
i_v
*
BV
:
(
i_v
+
1
)
*
BV
],
do
)
# Calculate dk
T
.
gemm
(
v
,
do
,
ds
,
transpose_B
=
True
,
clear_accum
=
True
)
# ds here actually means `s`, but we simply reuse the buffer `ds`
T
.
gemm
(
v
,
do
,
ds
,
transpose_B
=
True
,
clear_accum
=
True
)
# ds here actually means `s`, but we simply reuse the buffer `ds`
for
row
,
col
in
T
.
Parallel
(
chunk_size
,
chunk_size
):
ds_shared
[
row
,
col
]
=
T
.
if_then_else
(
row
<=
col
,
ds
[
row
,
col
],
0
)
T
.
gemm
(
ds_shared
,
q
,
dk
,
clear_accum
=
True
)
...
...
@@ -134,13 +125,9 @@ def tl_fused_chunk_bwd_kernel(
T
.
gemm
(
q
,
do
,
dh
,
transpose_A
=
True
)
T
.
copy
(
dk
,
dk_shared
)
T
.
atomic_add
(
dK
[
i_b
,
start
*
chunk_size
:(
start
+
1
)
*
chunk_size
,
i_h
,
i_k
*
BK
:(
i_k
+
1
)
*
BK
],
dk_shared
)
T
.
atomic_add
(
dK
[
i_b
,
start
*
chunk_size
:
(
start
+
1
)
*
chunk_size
,
i_h
,
i_k
*
BK
:
(
i_k
+
1
)
*
BK
],
dk_shared
)
T
.
copy
(
dv
,
dv_shared
)
T
.
atomic_add
(
dV
[
i_b
,
start
*
chunk_size
:(
start
+
1
)
*
chunk_size
,
i_h
,
i_v
*
BV
:(
i_v
+
1
)
*
BV
],
dv_shared
)
T
.
atomic_add
(
dV
[
i_b
,
start
*
chunk_size
:
(
start
+
1
)
*
chunk_size
,
i_h
,
i_v
*
BV
:
(
i_v
+
1
)
*
BV
],
dv_shared
)
return
fused_chunk_linear_attn_bwd
...
...
@@ -155,34 +142,31 @@ def tl_fused_chunk_bwd(Q, K, V, dO):
return
dQ
.
to
(
torch
.
float16
),
dK
.
to
(
torch
.
float16
),
dV
.
to
(
torch
.
float16
)
def
ref_program
(
q
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
v
:
torch
.
Tensor
,
scale
:
Optional
[
float
]
=
None
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
def
ref_program
(
q
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
v
:
torch
.
Tensor
,
scale
:
Optional
[
float
]
=
None
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
q
,
k
,
v
=
q
.
float
(),
k
.
float
(),
v
.
float
()
if
scale
is
None
:
scale
=
q
.
shape
[
-
1
]
**-
0.5
scale
=
q
.
shape
[
-
1
]
**
-
0.5
chunk_size
=
64
q
=
rearrange
(
q
,
'
b (n c) h d -> b h n c d
'
,
c
=
chunk_size
)
*
scale
k
=
rearrange
(
k
,
'
b (n c) h d -> b h n c d
'
,
c
=
chunk_size
)
v
=
rearrange
(
v
,
'
b (n c) h d -> b h n c d
'
,
c
=
chunk_size
)
q
=
rearrange
(
q
,
"
b (n c) h d -> b h n c d
"
,
c
=
chunk_size
)
*
scale
k
=
rearrange
(
k
,
"
b (n c) h d -> b h n c d
"
,
c
=
chunk_size
)
v
=
rearrange
(
v
,
"
b (n c) h d -> b h n c d
"
,
c
=
chunk_size
)
kv
=
k
.
transpose
(
-
1
,
-
2
)
@
v
kv
=
kv
.
cumsum
(
2
)
h
=
kv
[:,
:,
-
1
,
:,
:]
kv
=
torch
.
cat
([
torch
.
zeros_like
(
kv
[:,
:,
:
1
]),
kv
[:,
:,
:
-
1
]],
dim
=
2
)
inter
=
q
@
kv
intra
=
(
(
q
@
k
.
transpose
(
-
1
,
-
2
)).
masked_fill_
(
torch
.
triu
(
torch
.
ones
(
chunk_size
,
chunk_size
,
dtype
=
bool
,
device
=
q
.
device
),
diagonal
=
1
),
0
)
)
@
v
intra
=
(
(
q
@
k
.
transpose
(
-
1
,
-
2
)).
masked_fill_
(
torch
.
triu
(
torch
.
ones
(
chunk_size
,
chunk_size
,
dtype
=
bool
,
device
=
q
.
device
),
diagonal
=
1
),
0
)
)
@
v
o
=
inter
+
intra
return
rearrange
(
o
,
'
b h n c d -> b (n c) h d
'
),
h
return
rearrange
(
o
,
"
b h n c d -> b (n c) h d
"
),
h
def
main
(
B
=
1
,
S
=
1024
,
H
=
16
,
D
=
128
):
q
=
torch
.
randn
((
B
,
S
,
H
,
D
),
device
=
'
cuda
'
,
dtype
=
torch
.
float16
,
requires_grad
=
True
)
k
=
torch
.
randn
((
B
,
S
,
H
,
D
),
device
=
'
cuda
'
,
dtype
=
torch
.
float16
,
requires_grad
=
True
)
v
=
torch
.
randn
((
B
,
S
,
H
,
D
),
device
=
'
cuda
'
,
dtype
=
torch
.
float16
,
requires_grad
=
True
)
do
=
torch
.
randn
((
B
,
S
,
H
,
D
),
device
=
'
cuda
'
,
dtype
=
torch
.
float16
)
q
=
torch
.
randn
((
B
,
S
,
H
,
D
),
device
=
"
cuda
"
,
dtype
=
torch
.
float16
,
requires_grad
=
True
)
k
=
torch
.
randn
((
B
,
S
,
H
,
D
),
device
=
"
cuda
"
,
dtype
=
torch
.
float16
,
requires_grad
=
True
)
v
=
torch
.
randn
((
B
,
S
,
H
,
D
),
device
=
"
cuda
"
,
dtype
=
torch
.
float16
,
requires_grad
=
True
)
do
=
torch
.
randn
((
B
,
S
,
H
,
D
),
device
=
"
cuda
"
,
dtype
=
torch
.
float16
)
# qk norm is necessary for linear attn
q
=
l2norm_fwd
(
q
)[
0
].
requires_grad_
(
True
)
...
...
@@ -193,30 +177,27 @@ def main(B=1, S=1024, H=16, D=128):
o_ref
,
_
=
ref_program
(
q
,
k
,
v
)
o_ref
.
backward
(
do
,
retain_graph
=
True
)
assert
torch
.
allclose
(
dq
,
q
.
grad
,
atol
=
1e-2
,
rtol
=
1e-2
),
f
'dq max err:
{
(
dq
-
q
.
grad
).
abs
().
max
()
}
'
assert
torch
.
allclose
(
dk
,
k
.
grad
,
atol
=
1e-2
,
rtol
=
1e-2
),
f
'dk max err:
{
(
dk
-
k
.
grad
).
abs
().
max
()
}
'
assert
torch
.
allclose
(
dv
,
v
.
grad
,
atol
=
1e-2
,
rtol
=
1e-2
),
f
'dv max err:
{
(
dv
-
v
.
grad
).
abs
().
max
()
}
'
print
(
'Passed all tests!✅'
)
assert
torch
.
allclose
(
dq
,
q
.
grad
,
atol
=
1e-2
,
rtol
=
1e-2
),
f
"dq max err:
{
(
dq
-
q
.
grad
).
abs
().
max
()
}
"
assert
torch
.
allclose
(
dk
,
k
.
grad
,
atol
=
1e-2
,
rtol
=
1e-2
),
f
"dk max err:
{
(
dk
-
k
.
grad
).
abs
().
max
()
}
"
assert
torch
.
allclose
(
dv
,
v
.
grad
,
atol
=
1e-2
,
rtol
=
1e-2
),
f
"dv max err:
{
(
dv
-
v
.
grad
).
abs
().
max
()
}
"
print
(
"Passed all tests!✅"
)
# Benchmark
q
.
grad
=
k
.
grad
=
v
.
grad
=
None
o_ref
,
_
=
fused_chunk_linear_attn
(
q
,
k
,
v
,
output_final_state
=
True
,
normalize
=
False
)
t1
=
do_bench
(
lambda
:
o_ref
.
backward
(
do
,
retain_graph
=
True
),
backend
=
'
cupti
'
)
t2
=
do_bench
(
lambda
:
tl_fused_chunk_bwd
(
q
,
k
,
v
,
do
),
backend
=
'
cupti
'
)
print
(
f
'
Triton latency:
{
t1
:.
3
f
}
ms
'
)
print
(
f
'
TileLang latency:
{
t2
:.
3
f
}
ms
'
)
print
(
f
'
Speedup:
{
t1
/
t2
:.
3
f
}
x
'
)
t1
=
do_bench
(
lambda
:
o_ref
.
backward
(
do
,
retain_graph
=
True
),
backend
=
"
cupti
"
)
t2
=
do_bench
(
lambda
:
tl_fused_chunk_bwd
(
q
,
k
,
v
,
do
),
backend
=
"
cupti
"
)
print
(
f
"
Triton latency:
{
t1
:.
3
f
}
ms
"
)
print
(
f
"
TileLang latency:
{
t2
:.
3
f
}
ms
"
)
print
(
f
"
Speedup:
{
t1
/
t2
:.
3
f
}
x
"
)
if
__name__
==
'
__main__
'
:
if
__name__
==
"
__main__
"
:
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
'
--B
'
,
type
=
int
,
default
=
8
,
help
=
'
Batch size
'
)
parser
.
add_argument
(
'
--S
'
,
type
=
int
,
default
=
1024
,
help
=
'
Seq len
'
)
parser
.
add_argument
(
'
--H
'
,
type
=
int
,
default
=
32
,
help
=
'
Num heads
'
)
parser
.
add_argument
(
'
--D
'
,
type
=
int
,
default
=
128
,
help
=
'
Head dim
'
)
parser
.
add_argument
(
"
--B
"
,
type
=
int
,
default
=
8
,
help
=
"
Batch size
"
)
parser
.
add_argument
(
"
--S
"
,
type
=
int
,
default
=
1024
,
help
=
"
Seq len
"
)
parser
.
add_argument
(
"
--H
"
,
type
=
int
,
default
=
32
,
help
=
"
Num heads
"
)
parser
.
add_argument
(
"
--D
"
,
type
=
int
,
default
=
128
,
help
=
"
Head dim
"
)
args
=
parser
.
parse_args
()
main
(
args
.
B
,
args
.
S
,
args
.
H
,
args
.
D
)
examples/linear_attention/example_linear_attn_fwd.py
View file @
29051439
...
...
@@ -14,20 +14,20 @@ from typing import Optional, Tuple
pass_configs
=
{
tilelang
.
PassConfigKey
.
TL_DISABLE_TMA_LOWER
:
True
,
tilelang
.
PassConfigKey
.
TL_DISABLE_WARP_SPECIALIZED
:
True
,
})
},
)
def
tl_fused_chunk_fwd_kernel
(
B
,
S
,
H
,
DK
,
DV
,
dtype
:
str
=
'
float16
'
,
dtype
:
str
=
"
float16
"
,
scale
:
float
=
None
,
)
->
torch
.
Tensor
:
if
scale
is
None
:
scale
=
DK
**-
0.5
accum_dtype
=
'
float
'
accum_dtype
=
"
float
"
chunk_size
=
64
BK
=
BV
=
64
# Set to 128 can be faster, but has some numerical differences with FLA
...
...
@@ -42,7 +42,8 @@ def tl_fused_chunk_fwd_kernel(
K
:
T
.
Tensor
([
B
,
S
,
H
,
DK
],
dtype
),
# type: ignore
V
:
T
.
Tensor
([
B
,
S
,
H
,
DV
],
dtype
),
# type: ignore
O
:
T
.
Tensor
([
B
,
S
,
H
,
DV
],
accum_dtype
),
# type: ignore
final_state
:
T
.
Tensor
([
B
,
H
,
DK
,
DV
],
accum_dtype
)):
# type: ignore
final_state
:
T
.
Tensor
([
B
,
H
,
DK
,
DV
],
accum_dtype
),
):
# type: ignore
with
T
.
Kernel
(
NV
,
NK
,
B
*
H
)
as
(
i_v
,
i_k
,
i_bh
):
i_b
=
i_bh
//
H
i_h
=
i_bh
%
H
...
...
@@ -65,8 +66,8 @@ def tl_fused_chunk_fwd_kernel(
for
i
in
T
.
Pipelined
(
0
,
NT
):
for
row
,
col
in
T
.
Parallel
(
chunk_size
,
BK
):
q
[
row
,
col
]
=
Q
[
i_b
,
i
*
chunk_size
+
row
,
i_h
,
i_k
*
BK
+
col
]
*
scale
T
.
copy
(
K
[
i_b
,
i
*
chunk_size
:
(
i
+
1
)
*
chunk_size
,
i_h
,
i_k
*
BK
:
(
i_k
+
1
)
*
BK
],
k
)
T
.
copy
(
V
[
i_b
,
i
*
chunk_size
:
(
i
+
1
)
*
chunk_size
,
i_h
,
i_v
*
BV
:
(
i_v
+
1
)
*
BV
],
v
)
T
.
copy
(
K
[
i_b
,
i
*
chunk_size
:
(
i
+
1
)
*
chunk_size
,
i_h
,
i_k
*
BK
:
(
i_k
+
1
)
*
BK
],
k
)
T
.
copy
(
V
[
i_b
,
i
*
chunk_size
:
(
i
+
1
)
*
chunk_size
,
i_h
,
i_v
*
BV
:
(
i_v
+
1
)
*
BV
],
v
)
T
.
gemm
(
q
,
k
,
s
,
clear_accum
=
True
,
transpose_B
=
True
)
for
row
,
col
in
T
.
Parallel
(
chunk_size
,
chunk_size
):
...
...
@@ -77,12 +78,10 @@ def tl_fused_chunk_fwd_kernel(
T
.
gemm
(
k
,
v
,
h
,
transpose_A
=
True
)
T
.
gemm
(
q
,
h_shared
,
o
)
T
.
copy
(
o
,
o_shared
)
T
.
atomic_add
(
O
[
i_b
,
i
*
chunk_size
:(
i
+
1
)
*
chunk_size
,
i_h
,
i_v
*
BV
:(
i_v
+
1
)
*
BV
],
o_shared
)
T
.
atomic_add
(
O
[
i_b
,
i
*
chunk_size
:
(
i
+
1
)
*
chunk_size
,
i_h
,
i_v
*
BV
:
(
i_v
+
1
)
*
BV
],
o_shared
)
# Output final state
T
.
copy
(
h
,
final_state
[
i_b
,
i_h
,
i_k
*
BK
:
(
i_k
+
1
)
*
BK
,
i_v
*
BV
:
(
i_v
+
1
)
*
BV
])
T
.
copy
(
h
,
final_state
[
i_b
,
i_h
,
i_k
*
BK
:
(
i_k
+
1
)
*
BK
,
i_v
*
BV
:
(
i_v
+
1
)
*
BV
])
return
fused_chunk_linear_attn_fwd
...
...
@@ -91,38 +90,35 @@ def tl_fused_chunk_fwd(q, k, v):
B
,
S
,
H
,
D
=
q
.
shape
kernel
=
tl_fused_chunk_fwd_kernel
(
B
,
S
,
H
,
D
,
D
)
print
(
kernel
.
get_kernel_source
())
o
=
torch
.
zeros
((
B
,
S
,
H
,
D
),
device
=
'
cuda
'
,
dtype
=
torch
.
float32
)
o
=
torch
.
zeros
((
B
,
S
,
H
,
D
),
device
=
"
cuda
"
,
dtype
=
torch
.
float32
)
h
=
kernel
(
q
,
k
,
v
,
o
)
return
o
,
h
def
ref_program
(
q
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
v
:
torch
.
Tensor
,
scale
:
Optional
[
float
]
=
None
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
def
ref_program
(
q
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
v
:
torch
.
Tensor
,
scale
:
Optional
[
float
]
=
None
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
q
,
k
,
v
=
q
.
float
(),
k
.
float
(),
v
.
float
()
if
scale
is
None
:
scale
=
q
.
shape
[
-
1
]
**-
0.5
scale
=
q
.
shape
[
-
1
]
**
-
0.5
chunk_size
=
64
q
=
rearrange
(
q
,
'
b (n c) h d -> b h n c d
'
,
c
=
chunk_size
)
*
scale
k
=
rearrange
(
k
,
'
b (n c) h d -> b h n c d
'
,
c
=
chunk_size
)
v
=
rearrange
(
v
,
'
b (n c) h d -> b h n c d
'
,
c
=
chunk_size
)
q
=
rearrange
(
q
,
"
b (n c) h d -> b h n c d
"
,
c
=
chunk_size
)
*
scale
k
=
rearrange
(
k
,
"
b (n c) h d -> b h n c d
"
,
c
=
chunk_size
)
v
=
rearrange
(
v
,
"
b (n c) h d -> b h n c d
"
,
c
=
chunk_size
)
kv
=
k
.
transpose
(
-
1
,
-
2
)
@
v
kv
=
kv
.
cumsum
(
2
)
h
=
kv
[:,
:,
-
1
,
:,
:]
kv
=
torch
.
cat
([
torch
.
zeros_like
(
kv
[:,
:,
:
1
]),
kv
[:,
:,
:
-
1
]],
dim
=
2
)
inter
=
q
@
kv
intra
=
(
(
q
@
k
.
transpose
(
-
1
,
-
2
)).
masked_fill_
(
torch
.
triu
(
torch
.
ones
(
chunk_size
,
chunk_size
,
dtype
=
bool
,
device
=
q
.
device
),
diagonal
=
1
),
0
)
)
@
v
intra
=
(
(
q
@
k
.
transpose
(
-
1
,
-
2
)).
masked_fill_
(
torch
.
triu
(
torch
.
ones
(
chunk_size
,
chunk_size
,
dtype
=
bool
,
device
=
q
.
device
),
diagonal
=
1
),
0
)
)
@
v
o
=
inter
+
intra
return
rearrange
(
o
,
'
b h n c d -> b (n c) h d
'
),
h
return
rearrange
(
o
,
"
b h n c d -> b (n c) h d
"
),
h
def
main
(
B
=
1
,
S
=
512
,
H
=
16
,
D
=
128
):
q
=
torch
.
randn
((
B
,
S
,
H
,
D
),
device
=
'
cuda
'
,
dtype
=
torch
.
float16
)
k
=
torch
.
randn
((
B
,
S
,
H
,
D
),
device
=
'
cuda
'
,
dtype
=
torch
.
float16
)
v
=
torch
.
randn
((
B
,
S
,
H
,
D
),
device
=
'
cuda
'
,
dtype
=
torch
.
float16
)
q
=
torch
.
randn
((
B
,
S
,
H
,
D
),
device
=
"
cuda
"
,
dtype
=
torch
.
float16
)
k
=
torch
.
randn
((
B
,
S
,
H
,
D
),
device
=
"
cuda
"
,
dtype
=
torch
.
float16
)
v
=
torch
.
randn
((
B
,
S
,
H
,
D
),
device
=
"
cuda
"
,
dtype
=
torch
.
float16
)
# qk norm is necessary for linear attn
q
,
_
=
l2norm_fwd
(
q
)
...
...
@@ -131,25 +127,23 @@ def main(B=1, S=512, H=16, D=128):
o
,
h
=
tl_fused_chunk_fwd
(
q
,
k
,
v
)
o_ref
,
h_ref
=
ref_program
(
q
,
k
,
v
)
assert
torch
.
allclose
(
o
,
o_ref
,
atol
=
1e-2
,
rtol
=
1e-2
),
f
'
o max err:
{
(
o
-
o_ref
).
abs
().
max
()
}
'
assert
torch
.
allclose
(
h
,
h_ref
,
atol
=
1e-2
,
rtol
=
1e-2
),
f
'
h max err:
{
(
h
-
h_ref
).
abs
().
max
()
}
'
print
(
'
Passed all tests!✅
'
)
assert
torch
.
allclose
(
o
,
o_ref
,
atol
=
1e-2
,
rtol
=
1e-2
),
f
"
o max err:
{
(
o
-
o_ref
).
abs
().
max
()
}
"
assert
torch
.
allclose
(
h
,
h_ref
,
atol
=
1e-2
,
rtol
=
1e-2
),
f
"
h max err:
{
(
h
-
h_ref
).
abs
().
max
()
}
"
print
(
"
Passed all tests!✅
"
)
t1
=
do_bench
(
lambda
:
fused_chunk_linear_attn
(
q
,
k
,
v
,
output_final_state
=
True
,
normalize
=
False
),
backend
=
'cupti'
)
t2
=
do_bench
(
lambda
:
tl_fused_chunk_fwd
(
q
,
k
,
v
),
backend
=
'cupti'
)
print
(
f
'Triton latency:
{
t1
:.
3
f
}
ms'
)
print
(
f
'TileLang latency:
{
t2
:.
3
f
}
ms'
)
print
(
f
'Speedup:
{
t1
/
t2
:.
3
f
}
x'
)
t1
=
do_bench
(
lambda
:
fused_chunk_linear_attn
(
q
,
k
,
v
,
output_final_state
=
True
,
normalize
=
False
),
backend
=
"cupti"
)
t2
=
do_bench
(
lambda
:
tl_fused_chunk_fwd
(
q
,
k
,
v
),
backend
=
"cupti"
)
print
(
f
"Triton latency:
{
t1
:.
3
f
}
ms"
)
print
(
f
"TileLang latency:
{
t2
:.
3
f
}
ms"
)
print
(
f
"Speedup:
{
t1
/
t2
:.
3
f
}
x"
)
if
__name__
==
'
__main__
'
:
if
__name__
==
"
__main__
"
:
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
'
--B
'
,
type
=
int
,
default
=
8
,
help
=
'
Batch size
'
)
parser
.
add_argument
(
'
--S
'
,
type
=
int
,
default
=
1024
,
help
=
'
Seq len
'
)
parser
.
add_argument
(
'
--H
'
,
type
=
int
,
default
=
32
,
help
=
'
Num heads
'
)
parser
.
add_argument
(
'
--D
'
,
type
=
int
,
default
=
128
,
help
=
'
Head dim
'
)
parser
.
add_argument
(
"
--B
"
,
type
=
int
,
default
=
8
,
help
=
"
Batch size
"
)
parser
.
add_argument
(
"
--S
"
,
type
=
int
,
default
=
1024
,
help
=
"
Seq len
"
)
parser
.
add_argument
(
"
--H
"
,
type
=
int
,
default
=
32
,
help
=
"
Num heads
"
)
parser
.
add_argument
(
"
--D
"
,
type
=
int
,
default
=
128
,
help
=
"
Head dim
"
)
args
=
parser
.
parse_args
()
main
(
args
.
B
,
args
.
S
,
args
.
H
,
args
.
D
)
examples/linear_attention/example_mamba_chunk_scan.py
View file @
29051439
...
...
@@ -9,6 +9,7 @@ import itertools
def
chunk_scan_triton
(
cb
,
x
,
dt
,
dA_cumsum
,
C
,
states
,
D
):
from
mamba_ssm.ops.triton.ssd_chunk_scan
import
_chunk_scan_fwd
out
,
_
=
_chunk_scan_fwd
(
cb
,
x
,
dt
,
dA_cumsum
,
C
,
states
,
D
)
return
out
...
...
@@ -43,14 +44,15 @@ def ref_program(cb, x, dt, dA_cumsum, C, prev_states, D):
dt_segment_sum
=
dA_cumsum
[:,
:,
:,
:,
None
]
-
dA_cumsum
[:,
:,
:,
None
,
:]
decay
=
torch
.
exp
(
dt_segment_sum
)
scores_decay
=
cb
*
rearrange
(
decay
,
"b h c l s -> b c h l s"
)
causal_mask
=
torch
.
tril
(
torch
.
ones
(
chunk_size
,
chunk_size
,
device
=
x
.
device
,
dtype
=
bool
),
diagonal
=
0
)
causal_mask
=
torch
.
tril
(
torch
.
ones
(
chunk_size
,
chunk_size
,
device
=
x
.
device
,
dtype
=
bool
),
diagonal
=
0
)
scores_decay
=
scores_decay
.
masked_fill
(
~
causal_mask
,
0
)
out
=
torch
.
einsum
(
'bchls,bhcs,bcshp->bclhp'
,
scores_decay
.
to
(
x
.
dtype
),
dt
.
to
(
x
.
dtype
),
rearrange
(
x
,
"b (c s) h p -> b c s h p"
,
c
=
nchunks
))
out
=
torch
.
einsum
(
"bchls,bhcs,bcshp->bclhp"
,
scores_decay
.
to
(
x
.
dtype
),
dt
.
to
(
x
.
dtype
),
rearrange
(
x
,
"b (c s) h p -> b c s h p"
,
c
=
nchunks
)
)
state_decay_out
=
torch
.
exp
(
rearrange
(
dA_cumsum
,
"b h c l -> b c l h 1"
))
out_prev
=
torch
.
einsum
(
'bclhn,bchpn->bclhp'
,
rearrange
(
C
,
"b (c l) h n -> b c l h n"
,
c
=
nchunks
),
prev_states
.
to
(
C
.
dtype
))
*
state_decay_out
out_prev
=
(
torch
.
einsum
(
"bclhn,bchpn->bclhp"
,
rearrange
(
C
,
"b (c l) h n -> b c l h n"
,
c
=
nchunks
),
prev_states
.
to
(
C
.
dtype
))
*
state_decay_out
)
out
=
out
+
out_prev
out
=
rearrange
(
out
,
"b c l h p -> b (c l) h p"
)
if
D
is
not
None
:
...
...
@@ -61,12 +63,7 @@ def ref_program(cb, x, dt, dA_cumsum, C, prev_states, D):
def
get_configs
():
iter_params
=
dict
(
block_M
=
[
64
,
128
,
256
],
block_N
=
[
32
,
64
],
block_K
=
[
64
,
128
,
256
],
block_Dstate
=
[
128
],
num_stages
=
[
1
,
2
,
3
,
4
,
5
])
iter_params
=
dict
(
block_M
=
[
64
,
128
,
256
],
block_N
=
[
32
,
64
],
block_K
=
[
64
,
128
,
256
],
block_Dstate
=
[
128
],
num_stages
=
[
1
,
2
,
3
,
4
,
5
])
return
[
dict
(
zip
(
iter_params
,
values
))
for
values
in
itertools
.
product
(
*
iter_params
.
values
())]
...
...
@@ -77,7 +74,8 @@ def get_configs():
tilelang
.
PassConfigKey
.
TL_ENABLE_FAST_MATH
:
True
,
},
)
def
chunk_scan_fwd
(
batch
,
def
chunk_scan_fwd
(
batch
,
seqlen
,
chunk_size
,
ngroups
,
...
...
@@ -89,7 +87,8 @@ def chunk_scan_fwd(batch,
block_K
=
64
,
block_Dstate
=
128
,
num_stages
=
2
,
threads
=
128
):
threads
=
128
,
):
dtype
=
"float16"
accum_dtype
=
"float"
nchunks
=
T
.
ceildiv
(
seqlen
,
chunk_size
)
...
...
@@ -104,13 +103,13 @@ def chunk_scan_fwd(batch,
C
:
T
.
Tensor
((
batch
,
seqlen
,
ngroups
,
dstate
),
dtype
),
# type: ignore
prev_states
:
T
.
Tensor
((
batch
,
nchunks
,
nheads
,
headdim
,
dstate
),
dtype
),
# type: ignore
D
:
T
.
Tensor
((
nheads
),
dtype
),
# type: ignore
Output
:
T
.
Tensor
((
batch
,
seqlen
,
nheads
,
headdim
),
dtype
)
# type: ignore
Output
:
T
.
Tensor
((
batch
,
seqlen
,
nheads
,
headdim
),
dtype
),
# type: ignore
):
with
T
.
Kernel
(
nheads
,
T
.
ceildiv
(
chunk_size
,
block_M
)
*
T
.
ceildiv
(
headdim
,
block_N
),
batch
*
nchunks
,
threads
=
threads
)
as
(
bz
,
bx
,
by
,
):
with
T
.
Kernel
(
nheads
,
T
.
ceildiv
(
chunk_size
,
block_M
)
*
T
.
ceildiv
(
headdim
,
block_N
),
batch
*
nchunks
,
threads
=
threads
)
as
(
bz
,
bx
,
by
):
acc_o
=
T
.
alloc_fragment
((
block_M
,
block_N
),
accum_dtype
)
acc_o_shared
=
T
.
alloc_shared
((
block_M
,
block_N
),
dtype
)
cb_shared
=
T
.
alloc_shared
((
block_M
,
block_K
),
dtype
,
scope
=
"shared.dyn"
)
...
...
@@ -136,27 +135,32 @@ def chunk_scan_fwd(batch,
m_idx
=
bx
//
T
.
ceildiv
(
headdim
,
block_N
)
n_idx
=
bx
%
T
.
ceildiv
(
headdim
,
block_N
)
T
.
annotate_layout
({
T
.
annotate_layout
(
{
acc_o_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
acc_o_shared
),
cb_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
cb_shared
),
x_residual_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
x_residual_shared
)
})
x_residual_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
x_residual_shared
),
}
)
T
.
no_set_max_nreg
()
T
.
copy
(
dA_cumsum
[
batch_idx
,
bz
,
chunk_idx
,
m_idx
*
block_M
:(
m_idx
+
1
)
*
block_M
],
dA_cs_m_shared
)
T
.
copy
(
dA_cumsum
[
batch_idx
,
bz
,
chunk_idx
,
m_idx
*
block_M
:
(
m_idx
+
1
)
*
block_M
],
dA_cs_m_shared
)
T
.
copy
(
dA_cs_m_shared
,
dA_cs_m_local
)
T
.
clear
(
acc_o
)
for
i
in
T
.
Parallel
(
block_M
):
scale_m_local
[
i
]
=
T
.
exp2
(
dA_cs_m_local
[
i
]
*
p
)
T
.
copy
(
C
[
batch_idx
,
chunk_idx
*
chunk_size
+
m_idx
*
block_M
:
chunk_idx
*
chunk_size
+
(
m_idx
+
1
)
*
block_M
,
bz
//
(
nheads
//
ngroups
),
0
:
block_Dstate
],
C_shared
)
T
.
copy
(
prev_states
[
batch_idx
,
chunk_idx
,
bz
,
n_idx
*
block_N
:(
n_idx
+
1
)
*
block_N
,
0
:
block_Dstate
],
prev_state_shared
)
C
[
batch_idx
,
chunk_idx
*
chunk_size
+
m_idx
*
block_M
:
chunk_idx
*
chunk_size
+
(
m_idx
+
1
)
*
block_M
,
bz
//
(
nheads
//
ngroups
),
0
:
block_Dstate
,
],
C_shared
,
)
T
.
copy
(
prev_states
[
batch_idx
,
chunk_idx
,
bz
,
n_idx
*
block_N
:
(
n_idx
+
1
)
*
block_N
,
0
:
block_Dstate
],
prev_state_shared
)
T
.
gemm
(
C_shared
,
prev_state_shared
,
acc_o
,
transpose_B
=
True
)
for
i
,
j
in
T
.
Parallel
(
block_M
,
block_N
):
acc_o
[
i
,
j
]
*=
scale_m_local
[
i
]
...
...
@@ -165,34 +169,47 @@ def chunk_scan_fwd(batch,
for
k
in
T
.
Pipelined
(
loop_range
,
num_stages
=
num_stages
):
T
.
copy
(
cb
[
batch_idx
,
chunk_idx
,
bz
//
(
nheads
//
ngroups
),
m_idx
*
block_M
:(
m_idx
+
1
)
*
block_M
,
k
*
block_K
:(
k
+
1
)
*
block_K
],
cb_shared
)
cb
[
batch_idx
,
chunk_idx
,
bz
//
(
nheads
//
ngroups
),
m_idx
*
block_M
:
(
m_idx
+
1
)
*
block_M
,
k
*
block_K
:
(
k
+
1
)
*
block_K
,
],
cb_shared
,
)
T
.
copy
(
cb_shared
,
cb_local
)
T
.
copy
(
dA_cumsum
[
batch_idx
,
bz
,
chunk_idx
,
k
*
block_K
:(
k
+
1
)
*
block_K
],
dA_cs_k_shared
)
T
.
copy
(
dA_cumsum
[
batch_idx
,
bz
,
chunk_idx
,
k
*
block_K
:
(
k
+
1
)
*
block_K
],
dA_cs_k_shared
)
T
.
copy
(
dA_cs_k_shared
,
dA_cs_k_local
)
for
i
,
j
in
T
.
Parallel
(
block_M
,
block_K
):
cb_local
[
i
,
j
]
=
cb_local
[
i
,
j
]
*
T
.
exp2
(
dA_cs_m_local
[
i
]
*
p
-
dA_cs_k_local
[
j
]
*
p
)
T
.
copy
(
dt
[
batch_idx
,
bz
,
chunk_idx
,
k
*
block_K
:(
k
+
1
)
*
block_K
],
dt_shared
)
cb_local
[
i
,
j
]
=
cb_local
[
i
,
j
]
*
T
.
exp2
(
dA_cs_m_local
[
i
]
*
p
-
dA_cs_k_local
[
j
]
*
p
)
T
.
copy
(
dt
[
batch_idx
,
bz
,
chunk_idx
,
k
*
block_K
:
(
k
+
1
)
*
block_K
],
dt_shared
)
T
.
copy
(
dt_shared
,
dt_local
)
for
i
,
j
in
T
.
Parallel
(
block_M
,
block_K
):
cb_local
[
i
,
j
]
*=
dt_local
[
j
]
for
i
,
j
in
T
.
Parallel
(
block_M
,
block_K
):
cb_local
[
i
,
j
]
=
T
.
if_then_else
(
m_idx
*
block_M
+
i
>=
k
*
block_K
+
j
,
cb_local
[
i
,
j
],
0
)
cb_local
[
i
,
j
]
=
T
.
if_then_else
(
m_idx
*
block_M
+
i
>=
k
*
block_K
+
j
,
cb_local
[
i
,
j
],
0
)
T
.
copy
(
x
[
batch_idx
,
chunk_idx
*
chunk_size
+
k
*
block_K
:
chunk_idx
*
chunk_size
+
(
k
+
1
)
*
block_K
,
bz
,
n_idx
*
block_N
:(
n_idx
+
1
)
*
block_N
],
x_shared
)
x
[
batch_idx
,
chunk_idx
*
chunk_size
+
k
*
block_K
:
chunk_idx
*
chunk_size
+
(
k
+
1
)
*
block_K
,
bz
,
n_idx
*
block_N
:
(
n_idx
+
1
)
*
block_N
,
],
x_shared
,
)
T
.
gemm
(
cb_local
,
x_shared
,
acc_o
)
D_local
[
0
]
=
D
[
bz
]
T
.
copy
(
x
[
batch_idx
,
chunk_idx
*
chunk_size
+
m_idx
*
block_M
:
chunk_idx
*
chunk_size
+
(
m_idx
+
1
)
*
block_M
,
bz
,
n_idx
*
block_N
:(
n_idx
+
1
)
*
block_N
],
x_residual_shared
)
x
[
batch_idx
,
chunk_idx
*
chunk_size
+
m_idx
*
block_M
:
chunk_idx
*
chunk_size
+
(
m_idx
+
1
)
*
block_M
,
bz
,
n_idx
*
block_N
:
(
n_idx
+
1
)
*
block_N
,
],
x_residual_shared
,
)
T
.
copy
(
x_residual_shared
,
x_residual_local
)
for
i
,
j
in
T
.
Parallel
(
block_M
,
block_N
):
acc_o
[
i
,
j
]
+=
x_residual_local
[
i
,
j
]
*
D_local
[
0
]
...
...
@@ -200,27 +217,40 @@ def chunk_scan_fwd(batch,
T
.
copy
(
acc_o
,
acc_o_shared
)
T
.
copy
(
acc_o_shared
,
Output
[
batch_idx
,
chunk_idx
*
chunk_size
+
m_idx
*
block_M
:
chunk_idx
*
chunk_size
+
(
m_idx
+
1
)
*
block_M
,
bz
,
n_idx
*
block_N
:(
n_idx
+
1
)
*
block_N
])
Output
[
batch_idx
,
chunk_idx
*
chunk_size
+
m_idx
*
block_M
:
chunk_idx
*
chunk_size
+
(
m_idx
+
1
)
*
block_M
,
bz
,
n_idx
*
block_N
:
(
n_idx
+
1
)
*
block_N
,
],
)
return
main
if
__name__
==
"__main__"
:
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
'
--batch
'
,
type
=
int
,
default
=
8
,
help
=
'
batch size
'
)
parser
.
add_argument
(
'
--heads
'
,
type
=
int
,
default
=
80
,
help
=
'
heads
'
)
parser
.
add_argument
(
'
--groups
'
,
type
=
int
,
default
=
1
,
help
=
'
groups
'
)
parser
.
add_argument
(
'
--seq_len
'
,
type
=
int
,
default
=
4096
,
help
=
'
sequence length
'
)
parser
.
add_argument
(
'
--chunk_size
'
,
type
=
int
,
default
=
256
,
help
=
'
chunk size
'
)
parser
.
add_argument
(
'
--dim
'
,
type
=
int
,
default
=
64
,
help
=
'
dim
'
)
parser
.
add_argument
(
'
--dstate
'
,
type
=
int
,
default
=
128
,
help
=
'
dstate
'
)
parser
.
add_argument
(
'
--tune
'
,
action
=
'
store_true
'
,
help
=
'
tune configs
'
)
parser
.
add_argument
(
"
--batch
"
,
type
=
int
,
default
=
8
,
help
=
"
batch size
"
)
parser
.
add_argument
(
"
--heads
"
,
type
=
int
,
default
=
80
,
help
=
"
heads
"
)
parser
.
add_argument
(
"
--groups
"
,
type
=
int
,
default
=
1
,
help
=
"
groups
"
)
parser
.
add_argument
(
"
--seq_len
"
,
type
=
int
,
default
=
4096
,
help
=
"
sequence length
"
)
parser
.
add_argument
(
"
--chunk_size
"
,
type
=
int
,
default
=
256
,
help
=
"
chunk size
"
)
parser
.
add_argument
(
"
--dim
"
,
type
=
int
,
default
=
64
,
help
=
"
dim
"
)
parser
.
add_argument
(
"
--dstate
"
,
type
=
int
,
default
=
128
,
help
=
"
dstate
"
)
parser
.
add_argument
(
"
--tune
"
,
action
=
"
store_true
"
,
help
=
"
tune configs
"
)
args
=
parser
.
parse_args
()
batch
,
heads
,
groups
,
seq_len
,
chunk_size
,
dim
,
dstate
=
args
.
batch
,
args
.
heads
,
args
.
groups
,
args
.
seq_len
,
args
.
chunk_size
,
args
.
dim
,
args
.
dstate
batch
,
heads
,
groups
,
seq_len
,
chunk_size
,
dim
,
dstate
=
(
args
.
batch
,
args
.
heads
,
args
.
groups
,
args
.
seq_len
,
args
.
chunk_size
,
args
.
dim
,
args
.
dstate
,
)
total_flops
=
2
*
batch
*
seq_len
*
chunk_size
*
heads
*
dim
*
0.5
+
2
*
batch
*
seq_len
*
heads
*
dim
*
dstate
if
(
not
args
.
tune
)
:
if
not
args
.
tune
:
kernel
=
chunk_scan_fwd
(
batch
,
seq_len
,
...
...
@@ -234,7 +264,8 @@ if __name__ == "__main__":
block_K
=
64
,
block_Dstate
=
128
,
num_stages
=
2
,
threads
=
128
)
threads
=
128
,
)
profiler
=
kernel
.
get_profiler
(
tilelang
.
TensorSupplyType
.
Normal
)
profiler
.
assert_allclose
(
ref_program
,
rtol
=
0.01
,
atol
=
0.01
)
print
(
"All checks pass."
)
...
...
examples/linear_attention/example_mamba_chunk_state.py
View file @
29051439
...
...
@@ -10,6 +10,7 @@ import itertools
def
chunk_state_triton
(
B
,
x
,
dt
,
dA_cumsum
):
from
mamba_ssm.ops.triton.ssd_chunk_state
import
_chunk_state_fwd
return
_chunk_state_fwd
(
B
,
x
,
dt
,
dA_cumsum
,
states_in_fp32
=
False
)
...
...
@@ -41,46 +42,33 @@ def ref_program(B, x, dt, dA_cumsum):
x
=
rearrange
(
x
,
"b (c l) h p -> b c l h p"
,
l
=
chunk_size
)
B
=
rearrange
(
B
,
"b (c l) ... -> b c l ..."
,
l
=
chunk_size
)
decay_states
=
torch
.
exp
((
dA_cumsum
[:,
:,
:,
-
1
:]
-
dA_cumsum
))
return
torch
.
einsum
(
"bclhn,bhcl,bhcl,bclhp->bchpn"
,
B
.
to
(
x
.
dtype
),
decay_states
.
to
(
x
.
dtype
),
dt
.
to
(
x
.
dtype
),
x
)
return
torch
.
einsum
(
"bclhn,bhcl,bhcl,bclhp->bchpn"
,
B
.
to
(
x
.
dtype
),
decay_states
.
to
(
x
.
dtype
),
dt
.
to
(
x
.
dtype
),
x
)
def
get_configs
():
iter_params
=
dict
(
block_M
=
[
64
,
128
],
block_N
=
[
32
,
64
,
128
],
block_K
=
[
32
,
64
],
num_stages
=
[
1
,
2
,
3
,
4
,
5
])
iter_params
=
dict
(
block_M
=
[
64
,
128
],
block_N
=
[
32
,
64
,
128
],
block_K
=
[
32
,
64
],
num_stages
=
[
1
,
2
,
3
,
4
,
5
])
return
[
dict
(
zip
(
iter_params
,
values
))
for
values
in
itertools
.
product
(
*
iter_params
.
values
())]
@
autotune
(
configs
=
get_configs
(),
warmup
=
10
,
rep
=
10
)
@
tilelang
.
jit
(
out_idx
=
[
4
])
def
chunk_state_fwd
(
batch
,
seqlen
,
chunk_size
,
ngroups
,
nheads
,
headdim
,
dstate
,
block_M
=
64
,
block_N
=
64
,
block_K
=
64
,
num_stages
=
2
,
threads
=
128
):
def
chunk_state_fwd
(
batch
,
seqlen
,
chunk_size
,
ngroups
,
nheads
,
headdim
,
dstate
,
block_M
=
64
,
block_N
=
64
,
block_K
=
64
,
num_stages
=
2
,
threads
=
128
):
dtype
=
"float16"
accum_dtype
=
"float"
nchunks
=
T
.
ceildiv
(
seqlen
,
chunk_size
)
p
=
1.44269504
@
T
.
prim_func
def
main
(
B
:
T
.
Tensor
((
batch
,
seqlen
,
ngroups
,
dstate
),
dtype
),
x
:
T
.
Tensor
(
(
batch
,
seqlen
,
nheads
,
headdim
),
dtype
),
dt
:
T
.
Tensor
(
(
batch
,
nheads
,
nchunks
,
chunk_size
),
dtype
),
dA_cumsum
:
T
.
Tensor
(
(
batch
,
nheads
,
nchunks
,
chunk_size
),
dtype
),
Output
:
T
.
Tensor
(
(
batch
,
nchunks
,
nheads
,
headdim
,
dstate
),
dtype
)):
with
T
.
Kernel
(
nheads
,
T
.
ceildiv
(
headdim
,
block_M
)
*
T
.
ceildiv
(
dstate
,
block_N
),
batch
*
nchunks
,
threads
=
threads
)
as
(
bz
,
bx
,
by
):
def
main
(
B
:
T
.
Tensor
((
batch
,
seqlen
,
ngroups
,
dstate
),
dtype
),
x
:
T
.
Tensor
((
batch
,
seqlen
,
nheads
,
headdim
),
dtype
),
dt
:
T
.
Tensor
((
batch
,
nheads
,
nchunks
,
chunk_size
),
dtype
),
dA_cumsum
:
T
.
Tensor
((
batch
,
nheads
,
nchunks
,
chunk_size
),
dtype
),
Output
:
T
.
Tensor
((
batch
,
nchunks
,
nheads
,
headdim
,
dstate
),
dtype
),
):
with
T
.
Kernel
(
nheads
,
T
.
ceildiv
(
headdim
,
block_M
)
*
T
.
ceildiv
(
dstate
,
block_N
),
batch
*
nchunks
,
threads
=
threads
)
as
(
bz
,
bx
,
by
):
x_shared
=
T
.
alloc_shared
((
block_K
,
block_M
),
dtype
)
x_local
=
T
.
alloc_fragment
((
block_K
,
block_M
),
dtype
)
xt_local
=
T
.
alloc_fragment
((
block_M
,
block_K
),
dtype
)
...
...
@@ -101,20 +89,24 @@ def chunk_state_fwd(batch,
m_idx
=
bx
//
T
.
ceildiv
(
dstate
,
block_N
)
n_idx
=
bx
%
T
.
ceildiv
(
dstate
,
block_N
)
T
.
annotate_layout
({
x_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
x_shared
),
acc_o_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
acc_o_shared
)
})
T
.
annotate_layout
(
{
x_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
x_shared
),
acc_o_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
acc_o_shared
)}
)
dA_cs_last
[
0
]
=
dA_cumsum
[
batch_idx
,
bz
,
chunk_idx
,
chunk_size
-
1
]
T
.
clear
(
acc_o
)
for
k
in
T
.
Pipelined
(
loop_range
,
num_stages
=
num_stages
):
T
.
copy
(
x
[
batch_idx
,
chunk_idx
*
chunk_size
+
k
*
block_K
:
chunk_idx
*
chunk_size
+
(
k
+
1
)
*
block_K
,
bz
,
m_idx
*
block_M
:(
m_idx
+
1
)
*
block_M
],
x_shared
)
T
.
copy
(
dA_cumsum
[
batch_idx
,
bz
,
chunk_idx
,
k
*
block_K
:(
k
+
1
)
*
block_K
],
dA_cumsum_shared
)
T
.
copy
(
dt
[
batch_idx
,
bz
,
chunk_idx
,
k
*
block_K
:(
k
+
1
)
*
block_K
],
dt_shared
)
x
[
batch_idx
,
chunk_idx
*
chunk_size
+
k
*
block_K
:
chunk_idx
*
chunk_size
+
(
k
+
1
)
*
block_K
,
bz
,
m_idx
*
block_M
:
(
m_idx
+
1
)
*
block_M
,
],
x_shared
,
)
T
.
copy
(
dA_cumsum
[
batch_idx
,
bz
,
chunk_idx
,
k
*
block_K
:
(
k
+
1
)
*
block_K
],
dA_cumsum_shared
)
T
.
copy
(
dt
[
batch_idx
,
bz
,
chunk_idx
,
k
*
block_K
:
(
k
+
1
)
*
block_K
],
dt_shared
)
T
.
copy
(
dA_cumsum_shared
,
dA_cumsum_local
)
T
.
copy
(
dt_shared
,
dt_local
)
for
i
in
T
.
Parallel
(
block_K
):
...
...
@@ -123,47 +115,50 @@ def chunk_state_fwd(batch,
for
i
,
j
in
T
.
Parallel
(
block_M
,
block_K
):
xt_local
[
i
,
j
]
=
x_local
[
j
,
i
]
*
scale
[
j
]
T
.
copy
(
B
[
batch_idx
,
chunk_idx
*
chunk_size
+
k
*
block_K
:
chunk_idx
*
chunk_size
+
(
k
+
1
)
*
block_K
,
bz
//
(
nheads
//
ngroups
),
n_idx
*
block_N
:(
n_idx
+
1
)
*
block_N
],
B_shared
)
B
[
batch_idx
,
chunk_idx
*
chunk_size
+
k
*
block_K
:
chunk_idx
*
chunk_size
+
(
k
+
1
)
*
block_K
,
bz
//
(
nheads
//
ngroups
),
n_idx
*
block_N
:
(
n_idx
+
1
)
*
block_N
,
],
B_shared
,
)
T
.
gemm
(
xt_local
,
B_shared
,
acc_o
)
T
.
copy
(
acc_o
,
acc_o_shared
)
T
.
copy
(
acc_o_shared
,
Output
[
batch_idx
,
chunk_idx
,
bz
,
m_idx
*
block_M
:
(
m_idx
+
1
)
*
block_M
,
n_idx
*
block_N
:(
n_idx
+
1
)
*
block_N
]
)
Output
[
batch_idx
,
chunk_idx
,
bz
,
m_idx
*
block_M
:
(
m_idx
+
1
)
*
block_M
,
n_idx
*
block_N
:
(
n_idx
+
1
)
*
block_N
],
)
return
main
if
__name__
==
"__main__"
:
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
'
--batch
'
,
type
=
int
,
default
=
8
,
help
=
'
batch size
'
)
parser
.
add_argument
(
'
--heads
'
,
type
=
int
,
default
=
80
,
help
=
'
heads
'
)
parser
.
add_argument
(
'
--groups
'
,
type
=
int
,
default
=
1
,
help
=
'
groups
'
)
parser
.
add_argument
(
'
--seq_len
'
,
type
=
int
,
default
=
4096
,
help
=
'
sequence length
'
)
parser
.
add_argument
(
'
--chunk_size
'
,
type
=
int
,
default
=
256
,
help
=
'
chunk size
'
)
parser
.
add_argument
(
'
--dim
'
,
type
=
int
,
default
=
64
,
help
=
'
dim
'
)
parser
.
add_argument
(
'
--dstate
'
,
type
=
int
,
default
=
128
,
help
=
'
dstate
'
)
parser
.
add_argument
(
'
--tune
'
,
action
=
'
store_true
'
,
help
=
'
tune configs
'
)
parser
.
add_argument
(
"
--batch
"
,
type
=
int
,
default
=
8
,
help
=
"
batch size
"
)
parser
.
add_argument
(
"
--heads
"
,
type
=
int
,
default
=
80
,
help
=
"
heads
"
)
parser
.
add_argument
(
"
--groups
"
,
type
=
int
,
default
=
1
,
help
=
"
groups
"
)
parser
.
add_argument
(
"
--seq_len
"
,
type
=
int
,
default
=
4096
,
help
=
"
sequence length
"
)
parser
.
add_argument
(
"
--chunk_size
"
,
type
=
int
,
default
=
256
,
help
=
"
chunk size
"
)
parser
.
add_argument
(
"
--dim
"
,
type
=
int
,
default
=
64
,
help
=
"
dim
"
)
parser
.
add_argument
(
"
--dstate
"
,
type
=
int
,
default
=
128
,
help
=
"
dstate
"
)
parser
.
add_argument
(
"
--tune
"
,
action
=
"
store_true
"
,
help
=
"
tune configs
"
)
args
=
parser
.
parse_args
()
batch
,
heads
,
groups
,
seq_len
,
chunk_size
,
dim
,
dstate
=
args
.
batch
,
args
.
heads
,
args
.
groups
,
args
.
seq_len
,
args
.
chunk_size
,
args
.
dim
,
args
.
dstate
batch
,
heads
,
groups
,
seq_len
,
chunk_size
,
dim
,
dstate
=
(
args
.
batch
,
args
.
heads
,
args
.
groups
,
args
.
seq_len
,
args
.
chunk_size
,
args
.
dim
,
args
.
dstate
,
)
total_flops
=
2
*
batch
*
seq_len
*
heads
*
dim
*
dstate
if
(
not
args
.
tune
)
:
if
not
args
.
tune
:
kernel
=
chunk_state_fwd
(
batch
,
seq_len
,
chunk_size
,
groups
,
heads
,
dim
,
dstate
,
block_M
=
64
,
block_N
=
128
,
block_K
=
64
,
num_stages
=
4
,
threads
=
128
)
batch
,
seq_len
,
chunk_size
,
groups
,
heads
,
dim
,
dstate
,
block_M
=
64
,
block_N
=
128
,
block_K
=
64
,
num_stages
=
4
,
threads
=
128
)
profiler
=
kernel
.
get_profiler
(
tilelang
.
TensorSupplyType
.
Normal
)
profiler
.
assert_allclose
(
ref_program
,
rtol
=
0.01
,
atol
=
0.01
)
print
(
"All checks pass."
)
...
...
examples/linear_attention/example_retention_fwd.py
View file @
29051439
...
...
@@ -13,13 +13,12 @@ def chunk_retention_fwd_kernel(
H
,
DK
,
DV
,
dtype
:
str
=
'
float16
'
,
dtype
:
str
=
"
float16
"
,
scale
:
float
=
None
,
)
->
torch
.
Tensor
:
if
scale
is
None
:
scale
=
DK
**-
0.5
accum_dtype
=
'
float
'
accum_dtype
=
"
float
"
chunk_size
=
64
BK
=
BV
=
64
# Set to 128 can be faster, but has some numerical differences with FLA
...
...
@@ -38,8 +37,8 @@ def chunk_retention_fwd_kernel(
with
T
.
Kernel
(
NV
,
NK
,
B
*
H
)
as
(
i_v
,
i_k
,
i_bh
):
i_b
=
i_bh
//
H
i_h
=
i_bh
%
H
log_decay
=
T
.
alloc_var
(
'
float32
'
)
log_decay
=
T
.
log2
(
1
-
T
.
exp2
(
-
5.
-
1.
*
i_h
))
# Head-specific log decay
log_decay
=
T
.
alloc_var
(
"
float32
"
)
log_decay
=
T
.
log2
(
1
-
T
.
exp2
(
-
5.
0
-
1.
0
*
i_h
))
# Head-specific log decay
q
=
T
.
alloc_shared
([
chunk_size
,
BK
],
dtype
)
k
=
T
.
alloc_shared
([
chunk_size
,
BK
],
dtype
)
...
...
@@ -56,14 +55,12 @@ def chunk_retention_fwd_kernel(
for
i
in
T
.
Pipelined
(
0
,
NT
):
for
row
,
col
in
T
.
Parallel
(
chunk_size
,
BK
):
q
[
row
,
col
]
=
Q
[
i_b
,
i
*
chunk_size
+
row
,
i_h
,
i_k
*
BK
+
col
]
*
scale
T
.
copy
(
K
[
i_b
,
i
*
chunk_size
:
(
i
+
1
)
*
chunk_size
,
i_h
,
i_k
*
BK
:
(
i_k
+
1
)
*
BK
],
k
)
T
.
copy
(
V
[
i_b
,
i
*
chunk_size
:
(
i
+
1
)
*
chunk_size
,
i_h
,
i_v
*
BV
:
(
i_v
+
1
)
*
BV
],
v
)
T
.
copy
(
K
[
i_b
,
i
*
chunk_size
:
(
i
+
1
)
*
chunk_size
,
i_h
,
i_k
*
BK
:
(
i_k
+
1
)
*
BK
],
k
)
T
.
copy
(
V
[
i_b
,
i
*
chunk_size
:
(
i
+
1
)
*
chunk_size
,
i_h
,
i_v
*
BV
:
(
i_v
+
1
)
*
BV
],
v
)
T
.
gemm
(
q
,
k
,
s
,
clear_accum
=
True
,
transpose_B
=
True
)
for
row
,
col
in
T
.
Parallel
(
chunk_size
,
chunk_size
):
s_shared
[
row
,
col
]
=
T
.
if_then_else
(
row
>=
col
,
s
[
row
,
col
]
*
T
.
exp2
(
(
row
-
col
)
*
log_decay
),
0
)
s_shared
[
row
,
col
]
=
T
.
if_then_else
(
row
>=
col
,
s
[
row
,
col
]
*
T
.
exp2
((
row
-
col
)
*
log_decay
),
0
)
T
.
copy
(
h
,
h_shared
)
T
.
gemm
(
q
,
h_shared
,
o
,
clear_accum
=
True
)
...
...
@@ -75,9 +72,7 @@ def chunk_retention_fwd_kernel(
v
[
row
,
col
]
=
v
[
row
,
col
]
*
T
.
exp2
((
chunk_size
-
row
-
1
)
*
log_decay
)
for
row
,
col
in
T
.
Parallel
(
BK
,
BV
):
h
[
row
,
col
]
=
T
.
exp2
(
chunk_size
*
log_decay
)
*
h
[
row
,
col
]
T
.
copy
(
o
,
O
[
i_k
,
i_b
,
i
*
chunk_size
:(
i
+
1
)
*
chunk_size
,
i_h
,
i_v
*
BV
:(
i_v
+
1
)
*
BV
])
T
.
copy
(
o
,
O
[
i_k
,
i_b
,
i
*
chunk_size
:
(
i
+
1
)
*
chunk_size
,
i_h
,
i_v
*
BV
:
(
i_v
+
1
)
*
BV
])
T
.
gemm
(
k
,
v
,
h
,
transpose_A
=
True
)
return
chunk_retention_fwd
...
...
@@ -89,24 +84,24 @@ def postprocess(o):
def
main
():
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
'
--B
'
,
type
=
int
,
default
=
8
,
help
=
'
Batch size
'
)
parser
.
add_argument
(
'
--S
'
,
type
=
int
,
default
=
4096
,
help
=
'
Seq len
'
)
parser
.
add_argument
(
'
--H
'
,
type
=
int
,
default
=
32
,
help
=
'
Num heads
'
)
parser
.
add_argument
(
'
--D
'
,
type
=
int
,
default
=
128
,
help
=
'
Head dim
'
)
parser
.
add_argument
(
"
--B
"
,
type
=
int
,
default
=
8
,
help
=
"
Batch size
"
)
parser
.
add_argument
(
"
--S
"
,
type
=
int
,
default
=
4096
,
help
=
"
Seq len
"
)
parser
.
add_argument
(
"
--H
"
,
type
=
int
,
default
=
32
,
help
=
"
Num heads
"
)
parser
.
add_argument
(
"
--D
"
,
type
=
int
,
default
=
128
,
help
=
"
Head dim
"
)
args
=
parser
.
parse_args
()
B
,
S
,
H
,
D
=
args
.
B
,
args
.
S
,
args
.
H
,
args
.
D
total_flops
=
2.0
*
B
*
S
*
S
*
H
*
D
# causal
q
=
torch
.
randn
((
B
,
S
,
H
,
D
),
device
=
'
cuda
'
,
dtype
=
torch
.
float16
)
k
=
torch
.
randn
((
B
,
S
,
H
,
D
),
device
=
'
cuda
'
,
dtype
=
torch
.
float16
)
v
=
torch
.
randn
((
B
,
S
,
H
,
D
),
device
=
'
cuda
'
,
dtype
=
torch
.
float16
)
q
=
torch
.
randn
((
B
,
S
,
H
,
D
),
device
=
"
cuda
"
,
dtype
=
torch
.
float16
)
k
=
torch
.
randn
((
B
,
S
,
H
,
D
),
device
=
"
cuda
"
,
dtype
=
torch
.
float16
)
v
=
torch
.
randn
((
B
,
S
,
H
,
D
),
device
=
"
cuda
"
,
dtype
=
torch
.
float16
)
kernel
=
chunk_retention_fwd_kernel
(
B
,
S
,
H
,
D
,
D
)
t
=
do_bench
(
lambda
:
postprocess
(
kernel
(
q
,
k
,
v
)),
warmup
=
25
,
rep
=
100
)
print
(
f
'
Tilelang latency:
{
t
:.
3
f
}
ms
'
)
print
(
f
'
Tilelang TFLOPs:
{
total_flops
/
t
*
1e-9
}
'
)
print
(
f
"
Tilelang latency:
{
t
:.
3
f
}
ms
"
)
print
(
f
"
Tilelang TFLOPs:
{
total_flops
/
t
*
1e-9
}
"
)
if
__name__
==
'
__main__
'
:
if
__name__
==
"
__main__
"
:
main
()
examples/minference/example_vertical_slash_sparse_attn.py
View file @
29051439
...
...
@@ -15,12 +15,11 @@ from tilelang.profiler import do_bench
@
tilelang
.
jit
(
out_idx
=
[
3
])
def
_tl_vs_sparse_flashattn
(
batch
,
heads
,
seq_len
,
dim
,
vertical_size
,
slash_size
):
block_M
=
64
block_N
=
64
num_stages
=
2
threads
=
128
scale
=
(
1.0
/
dim
)
**
0.5
*
1.44269504
scale
=
(
1.0
/
dim
)
**
0.5
*
1.44269504
shape
=
[
batch
,
heads
,
seq_len
,
dim
]
seq_blocks
=
(
seq_len
+
block_M
-
1
)
//
block_M
...
...
@@ -30,15 +29,13 @@ def _tl_vs_sparse_flashattn(batch, heads, seq_len, dim, vertical_size, slash_siz
offset_shape
=
count_shape
+
[
slash_size
]
index_shape
=
count_shape
+
[
vertical_size
]
vertical_size_round
,
slash_size_round
=
tilelang
.
next_power_of_2
(
vertical_size
),
tilelang
.
next_power_of_2
(
slash_size
)
vertical_size_round
,
slash_size_round
=
tilelang
.
next_power_of_2
(
vertical_size
),
tilelang
.
next_power_of_2
(
slash_size
)
dtype
=
"float16"
accum_dtype
=
"float"
int_dtype
=
"int32"
def
kernel_func
(
block_M
,
block_N
,
num_stages
,
threads
):
@
T
.
macro
def
Prefetch
(
K
:
T
.
Tensor
(
shape
,
dtype
),
...
...
@@ -53,13 +50,11 @@ def _tl_vs_sparse_flashattn(batch, heads, seq_len, dim, vertical_size, slash_siz
):
with
T
.
attr
(
"default"
,
"async_scope"
,
1
):
for
i
,
j
in
T
.
Parallel
(
block_N
,
dim
):
K_shared
[
i
,
j
]
=
T
.
if_then_else
(
k
+
i
<
column_count
,
K
[
bz
,
by
,
column_index
[
k
+
i
],
j
],
0
)
K_shared
[
i
,
j
]
=
T
.
if_then_else
(
k
+
i
<
column_count
,
K
[
bz
,
by
,
column_index
[
k
+
i
],
j
],
0
)
with
T
.
attr
(
"default"
,
"async_scope"
,
1
):
for
i
,
j
in
T
.
Parallel
(
block_N
,
dim
):
V_shared
[
i
,
j
]
=
T
.
if_then_else
(
k
+
i
<
column_count
,
V
[
bz
,
by
,
column_index
[
k
+
i
],
j
],
0
)
V_shared
[
i
,
j
]
=
T
.
if_then_else
(
k
+
i
<
column_count
,
V
[
bz
,
by
,
column_index
[
k
+
i
],
j
],
0
)
T
.
ptx_commit_group
()
...
...
@@ -118,7 +113,6 @@ def _tl_vs_sparse_flashattn(batch, heads, seq_len, dim, vertical_size, slash_siz
ColumnIndex
:
T
.
Tensor
(
index_shape
,
int_dtype
),
):
with
T
.
Kernel
(
T
.
ceildiv
(
seq_len
,
block_M
),
heads
,
batch
,
threads
=
256
)
as
(
bc
,
by
,
bz
):
bx
=
T
.
ceildiv
(
seq_len
,
block_M
)
-
1
-
bc
Q_shared
=
T
.
alloc_shared
([
block_M
,
dim
],
dtype
)
K_shared
=
T
.
alloc_shared
([
2
,
block_N
,
dim
],
dtype
)
...
...
@@ -143,9 +137,11 @@ def _tl_vs_sparse_flashattn(batch, heads, seq_len, dim, vertical_size, slash_siz
T
.
create_list_of_mbarrier
([
128
]
*
9
)
T
.
annotate_layout
({
T
.
annotate_layout
(
{
O_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
O_shared
),
})
}
)
block_count
[
0
]
=
BlockCount
[
bz
,
by
,
bx
]
column_count
[
0
]
=
ColumnCount
[
bz
,
by
,
bx
]
...
...
@@ -162,15 +158,15 @@ def _tl_vs_sparse_flashattn(batch, heads, seq_len, dim, vertical_size, slash_siz
if
tid
>=
128
:
T
.
annotate_producer_reg_dealloc
()
T
.
copy
(
Q
[
bz
,
by
,
bx
*
block_M
:
(
bx
+
1
)
*
block_M
,
:],
Q_shared
)
T
.
copy
(
Q
[
bz
,
by
,
bx
*
block_M
:
(
bx
+
1
)
*
block_M
,
:],
Q_shared
)
T
.
mbarrier_arrive
(
mbarrier
=
8
)
for
bi
in
T
.
serial
(
block_count
[
0
]):
k
=
block_offset
[
bi
]
T
.
mbarrier_wait_parity
(
mbarrier
=
bi
%
2
+
4
,
parity
=
(((
bi
&
3
)
>>
1
)
^
1
))
T
.
copy
(
K
[
bz
,
by
,
k
:
k
+
block_N
,
:],
K_shared
[
bi
%
2
,
:,
:])
T
.
copy
(
K
[
bz
,
by
,
k
:
k
+
block_N
,
:],
K_shared
[
bi
%
2
,
:,
:])
T
.
mbarrier_arrive
(
mbarrier
=
bi
%
2
)
T
.
mbarrier_wait_parity
(
mbarrier
=
bi
%
2
+
6
,
parity
=
(((
bi
&
3
)
>>
1
)
^
1
))
T
.
copy
(
V
[
bz
,
by
,
k
:
k
+
block_N
,
:],
V_shared
[
bi
%
2
,
:,
:])
T
.
copy
(
V
[
bz
,
by
,
k
:
k
+
block_N
,
:],
V_shared
[
bi
%
2
,
:,
:])
T
.
mbarrier_arrive
(
mbarrier
=
bi
%
2
+
2
)
else
:
T
.
annotate_consumer_reg_alloc
()
...
...
@@ -181,16 +177,10 @@ def _tl_vs_sparse_flashattn(batch, heads, seq_len, dim, vertical_size, slash_siz
for
bi
in
T
.
serial
(
block_count
[
0
]):
k
=
block_offset
[
bi
]
for
i
,
j
in
T
.
Parallel
(
block_M
,
block_N
):
acc_s
[
i
,
j
]
=
T
.
if_then_else
(
bx
*
block_M
+
i
>=
k
+
j
,
0
,
-
T
.
infinity
(
acc_s
.
dtype
))
acc_s
[
i
,
j
]
=
T
.
if_then_else
(
bx
*
block_M
+
i
>=
k
+
j
,
0
,
-
T
.
infinity
(
acc_s
.
dtype
))
T
.
mbarrier_wait_parity
(
mbarrier
=
bi
%
2
,
parity
=
((
bi
&
3
)
>>
1
))
T
.
gemm
(
Q_shared
,
K_shared
[
bi
%
2
,
:,
:],
acc_s
,
transpose_B
=
True
,
policy
=
T
.
GemmWarpPolicy
.
FullRow
)
T
.
gemm
(
Q_shared
,
K_shared
[
bi
%
2
,
:,
:],
acc_s
,
transpose_B
=
True
,
policy
=
T
.
GemmWarpPolicy
.
FullRow
)
T
.
mbarrier_arrive
(
mbarrier
=
bi
%
2
+
4
)
T
.
copy
(
scores_max
,
scores_max_prev
)
...
...
@@ -200,20 +190,15 @@ def _tl_vs_sparse_flashattn(batch, heads, seq_len, dim, vertical_size, slash_siz
scores_max
[
i
]
=
T
.
max
(
scores_max
[
i
],
scores_max_prev
[
i
])
for
i
in
T
.
Parallel
(
block_M
):
scores_scale
[
i
]
=
T
.
exp2
(
scores_max_prev
[
i
]
*
scale
-
scores_max
[
i
]
*
scale
)
scores_scale
[
i
]
=
T
.
exp2
(
scores_max_prev
[
i
]
*
scale
-
scores_max
[
i
]
*
scale
)
for
i
,
j
in
T
.
Parallel
(
block_M
,
block_N
):
acc_s
[
i
,
j
]
=
T
.
exp2
(
acc_s
[
i
,
j
]
*
scale
-
scores_max
[
i
]
*
scale
)
for
i
,
j
in
T
.
Parallel
(
block_M
,
dim
):
acc_o
[
i
,
j
]
=
acc_o
[
i
,
j
]
*
scores_scale
[
i
]
T
.
copy
(
acc_s
,
acc_s_cast
)
T
.
mbarrier_wait_parity
(
mbarrier
=
bi
%
2
+
2
,
parity
=
(((
bi
&
3
)
>>
1
)))
T
.
gemm
(
acc_s_cast
,
V_shared
[
bi
%
2
,
:,
:],
acc_o
,
policy
=
T
.
GemmWarpPolicy
.
FullRow
)
T
.
mbarrier_wait_parity
(
mbarrier
=
bi
%
2
+
2
,
parity
=
((
bi
&
3
)
>>
1
))
T
.
gemm
(
acc_s_cast
,
V_shared
[
bi
%
2
,
:,
:],
acc_o
,
policy
=
T
.
GemmWarpPolicy
.
FullRow
)
T
.
mbarrier_arrive
(
mbarrier
=
bi
%
2
+
6
)
...
...
@@ -223,38 +208,85 @@ def _tl_vs_sparse_flashattn(batch, heads, seq_len, dim, vertical_size, slash_siz
logsum
[
i
]
=
logsum
[
i
]
*
scores_scale
[
i
]
+
scores_sum
[
i
]
if
column_count
[
0
]
!=
0
:
Prefetch
(
K
,
V
,
K_shared_1
,
V_shared_1
,
column_index
,
column_count
[
0
],
0
,
bz
,
by
)
Prefetch
(
K
,
V
,
K_shared_1
,
V_shared_1
,
column_index
,
column_count
[
0
],
0
,
bz
,
by
)
for
bi
in
T
.
serial
(
T
.
ceildiv
(
column_count
[
0
],
block_N
)
-
1
):
k
=
bi
*
block_N
if
bi
%
2
==
0
:
Prefetch
(
K
,
V
,
K_shared_2
,
V_shared_2
,
column_index
,
column_count
[
0
],
k
+
block_N
,
bz
,
by
)
Prefetch
(
K
,
V
,
K_shared_2
,
V_shared_2
,
column_index
,
column_count
[
0
],
k
+
block_N
,
bz
,
by
)
Compute
(
acc_s
,
acc_s_cast
,
acc_o
,
scores_max
,
scores_max_prev
,
k
,
column_count
[
0
],
Q_shared
,
K_shared_1
,
V_shared_1
,
scores_scale
,
scores_sum
,
logsum
,
1
)
Compute
(
acc_s
,
acc_s_cast
,
acc_o
,
scores_max
,
scores_max_prev
,
k
,
column_count
[
0
],
Q_shared
,
K_shared_1
,
V_shared_1
,
scores_scale
,
scores_sum
,
logsum
,
1
,
)
else
:
Prefetch
(
K
,
V
,
K_shared_1
,
V_shared_1
,
column_index
,
column_count
[
0
],
k
+
block_N
,
bz
,
by
)
Prefetch
(
K
,
V
,
K_shared_1
,
V_shared_1
,
column_index
,
column_count
[
0
],
k
+
block_N
,
bz
,
by
)
Compute
(
acc_s
,
acc_s_cast
,
acc_o
,
scores_max
,
scores_max_prev
,
k
,
column_count
[
0
],
Q_shared
,
K_shared_2
,
V_shared_2
,
scores_scale
,
scores_sum
,
logsum
,
1
)
Compute
(
acc_s
,
acc_s_cast
,
acc_o
,
scores_max
,
scores_max_prev
,
k
,
column_count
[
0
],
Q_shared
,
K_shared_2
,
V_shared_2
,
scores_scale
,
scores_sum
,
logsum
,
1
,
)
if
T
.
ceildiv
(
column_count
[
0
],
block_N
)
%
2
==
0
:
Compute
(
acc_s
,
acc_s_cast
,
acc_o
,
scores_max
,
scores_max_prev
,
Compute
(
acc_s
,
acc_s_cast
,
acc_o
,
scores_max
,
scores_max_prev
,
T
.
ceildiv
(
column_count
[
0
],
block_N
)
*
block_N
-
block_N
,
column_count
[
0
],
Q_shared
,
K_shared_2
,
V_shared_2
,
scores_scale
,
scores_sum
,
logsum
,
0
)
column_count
[
0
],
Q_shared
,
K_shared_2
,
V_shared_2
,
scores_scale
,
scores_sum
,
logsum
,
0
,
)
else
:
Compute
(
acc_s
,
acc_s_cast
,
acc_o
,
scores_max
,
scores_max_prev
,
Compute
(
acc_s
,
acc_s_cast
,
acc_o
,
scores_max
,
scores_max_prev
,
T
.
ceildiv
(
column_count
[
0
],
block_N
)
*
block_N
-
block_N
,
column_count
[
0
],
Q_shared
,
K_shared_1
,
V_shared_1
,
scores_scale
,
scores_sum
,
logsum
,
0
)
column_count
[
0
],
Q_shared
,
K_shared_1
,
V_shared_1
,
scores_scale
,
scores_sum
,
logsum
,
0
,
)
for
i
,
j
in
T
.
Parallel
(
block_M
,
dim
):
acc_o
[
i
,
j
]
/=
logsum
[
i
]
T
.
copy
(
acc_o
,
O_shared
)
T
.
copy
(
O_shared
,
Output
[
bz
,
by
,
bx
*
block_M
:
(
bx
+
1
)
*
block_M
,
:])
T
.
copy
(
O_shared
,
Output
[
bz
,
by
,
bx
*
block_M
:
(
bx
+
1
)
*
block_M
,
:])
return
vs_sparse_flashattn_ws
...
...
@@ -470,11 +502,8 @@ def vertical_slash_sparse_attention(
import
os
current_dir
=
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
))
sources
=
[
os
.
path
.
join
(
current_dir
,
'ops'
,
'kernels.cpp'
),
os
.
path
.
join
(
current_dir
,
'ops'
,
'vertical_slash_index.cu'
)
]
ops
=
load
(
name
=
'convert'
,
sources
=
sources
,
verbose
=
False
)
sources
=
[
os
.
path
.
join
(
current_dir
,
"ops"
,
"kernels.cpp"
),
os
.
path
.
join
(
current_dir
,
"ops"
,
"vertical_slash_index.cu"
)]
ops
=
load
(
name
=
"convert"
,
sources
=
sources
,
verbose
=
False
)
convert_vertical_slash_indexes
=
ops
.
convert_vertical_slash_indexes
batch_size
,
num_heads
,
context_size
,
head_dim
=
query
.
shape
pad
=
(
block_size_M
-
context_size
)
&
(
block_size_M
-
1
)
...
...
@@ -485,15 +514,13 @@ def vertical_slash_sparse_attention(
value
=
torch
.
nn
.
functional
.
pad
(
value
,
[
0
,
0
,
0
,
pad
,
0
,
0
,
0
,
0
])
if
head_dim
not
in
[
16
,
32
,
64
,
128
,
256
,
512
]:
target_dim
=
2
**
math
.
ceil
(
math
.
log2
(
head_dim
))
-
head_dim
target_dim
=
2
**
math
.
ceil
(
math
.
log2
(
head_dim
))
-
head_dim
query
=
torch
.
nn
.
functional
.
pad
(
query
,
[
0
,
target_dim
,
0
,
0
,
0
,
0
,
0
,
0
])
key
=
torch
.
nn
.
functional
.
pad
(
key
,
[
0
,
target_dim
,
0
,
0
,
0
,
0
,
0
,
0
])
value
=
torch
.
nn
.
functional
.
pad
(
value
,
[
0
,
target_dim
,
0
,
0
,
0
,
0
,
0
,
0
])
v_idx
=
v_idx
.
to
(
torch
.
int32
).
reshape
((
batch_size
,
num_heads
,
-
1
)).
sort
(
dim
=-
1
,
descending
=
False
)[
0
]
s_idx
=
s_idx
.
to
(
torch
.
int32
).
reshape
((
batch_size
,
num_heads
,
-
1
)).
sort
(
dim
=-
1
,
descending
=
True
)[
0
]
v_idx
=
v_idx
.
to
(
torch
.
int32
).
reshape
((
batch_size
,
num_heads
,
-
1
)).
sort
(
dim
=-
1
,
descending
=
False
)[
0
]
s_idx
=
s_idx
.
to
(
torch
.
int32
).
reshape
((
batch_size
,
num_heads
,
-
1
)).
sort
(
dim
=-
1
,
descending
=
True
)[
0
]
seqlens
=
torch
.
tensor
([
context_size
]
*
query
.
shape
[
0
],
dtype
=
torch
.
int32
,
device
=
query
.
device
)
sm_scale
=
head_dim
**-
0.5
...
...
@@ -506,8 +533,7 @@ def vertical_slash_sparse_attention(
block_size_N
,
)
tl_kernel
=
_tl_vs_sparse_flashattn
(
batch_size
,
num_heads
,
context_size
,
head_dim
,
v_idx
.
shape
[
2
],
s_idx
.
shape
[
2
])
tl_kernel
=
_tl_vs_sparse_flashattn
(
batch_size
,
num_heads
,
context_size
,
head_dim
,
v_idx
.
shape
[
2
],
s_idx
.
shape
[
2
])
def
run
(
is_triton
:
bool
=
True
):
if
is_triton
:
...
...
@@ -525,8 +551,7 @@ def vertical_slash_sparse_attention(
block_size_N
,
)
else
:
out
=
tl_kernel
(
query
,
key
,
value
,
block_count
,
block_offset
,
column_count
,
column_index
)
out
=
tl_kernel
(
query
,
key
,
value
,
block_count
,
block_offset
,
column_count
,
column_index
)
return
out
[...,
:
context_size
,
:
head_dim
]
return
run
...
...
@@ -536,8 +561,7 @@ def sum_all_diagonal_matrix(mat: torch.tensor):
b
,
h
,
n
,
m
=
mat
.
shape
zero_mat
=
torch
.
zeros
((
b
,
h
,
n
,
n
)).
to
(
mat
.
device
)
# Zero matrix used for padding
mat_padded
=
torch
.
cat
((
zero_mat
,
mat
,
zero_mat
),
-
1
)
# pads the matrix on left and right
mat_strided
=
mat_padded
.
as_strided
(
(
1
,
1
,
n
,
n
+
m
),
(
1
,
n
*
(
2
*
n
+
m
),
2
*
n
+
m
+
1
,
1
))
# Change the strides
mat_strided
=
mat_padded
.
as_strided
((
1
,
1
,
n
,
n
+
m
),
(
1
,
n
*
(
2
*
n
+
m
),
2
*
n
+
m
+
1
,
1
))
# Change the strides
sum_diags
=
torch
.
sum
(
mat_strided
,
2
)
# Sums the resulting matrix's columns
return
sum_diags
[:,
:,
1
:]
...
...
@@ -559,24 +583,23 @@ def main(argv=None):
vertical_size
,
slash_size
=
args
.
vertical_size
,
args
.
slash_size
torch
.
manual_seed
(
0
)
q
=
torch
.
randn
(
BATCH
,
N_HEADS
,
SEQ_LEN
,
D_HEAD
,
device
=
'
cuda
'
,
dtype
=
torch
.
float16
)
k
=
torch
.
randn
(
BATCH
,
N_HEADS
,
SEQ_LEN
,
D_HEAD
,
device
=
'
cuda
'
,
dtype
=
torch
.
float16
)
v
=
torch
.
randn
(
BATCH
,
N_HEADS
,
SEQ_LEN
,
D_HEAD
,
device
=
'
cuda
'
,
dtype
=
torch
.
float16
)
q
=
torch
.
randn
(
BATCH
,
N_HEADS
,
SEQ_LEN
,
D_HEAD
,
device
=
"
cuda
"
,
dtype
=
torch
.
float16
)
k
=
torch
.
randn
(
BATCH
,
N_HEADS
,
SEQ_LEN
,
D_HEAD
,
device
=
"
cuda
"
,
dtype
=
torch
.
float16
)
v
=
torch
.
randn
(
BATCH
,
N_HEADS
,
SEQ_LEN
,
D_HEAD
,
device
=
"
cuda
"
,
dtype
=
torch
.
float16
)
q_len
=
SEQ_LEN
vertical_size
,
slash_size
=
min
(
q_len
,
vertical_size
),
min
(
q_len
,
slash_size
)
last_q
=
64
qk
=
torch
.
einsum
(
'
bhmk, bhnk -> bhmn
'
,
q
[:,
:,
-
last_q
:,
:],
k
)
qk
=
torch
.
einsum
(
"
bhmk, bhnk -> bhmn
"
,
q
[:,
:,
-
last_q
:,
:],
k
)
arange
=
torch
.
arange
(
last_q
,
device
=
"cuda"
)
qk
[:,
:,
:,
-
last_q
:]
=
torch
.
where
(
arange
[
None
,
None
,
:,
None
]
>=
arange
[
None
,
None
,
None
,
:],
qk
[:,
:,
:,
-
last_q
:],
-
torch
.
inf
)
qk
[:,
:,
:,
-
last_q
:]
=
torch
.
where
(
arange
[
None
,
None
,
:,
None
]
>=
arange
[
None
,
None
,
None
,
:],
qk
[:,
:,
:,
-
last_q
:],
-
torch
.
inf
)
qk
=
torch
.
nn
.
functional
.
softmax
(
qk
,
dim
=-
1
,
dtype
=
torch
.
float32
)
vertical
=
qk
.
sum
(
-
2
,
keepdim
=
True
)
vertical
[...,
:
30
]
=
torch
.
inf
vertical_topk
=
torch
.
topk
(
vertical
,
vertical_size
,
-
1
).
indices
slash
=
sum_all_diagonal_matrix
(
qk
)[...,
:
-
last_q
+
1
]
slash
=
sum_all_diagonal_matrix
(
qk
)[...,
:
-
last_q
+
1
]
slash
[...,
-
30
:]
=
torch
.
inf
slash
=
(
q_len
-
1
)
-
torch
.
topk
(
slash
,
slash_size
,
-
1
).
indices
...
...
examples/norm/rms_norm.py
View file @
29051439
...
...
@@ -45,7 +45,7 @@ def rms_norm(M, N, blk_m):
A_local
=
T
.
alloc_fragment
((
blk_m
,
N
),
dtype
)
A_powsum
=
T
.
alloc_fragment
((
blk_m
,),
dtype
)
T
.
copy
(
A
[
bx
*
blk_m
:
(
bx
+
1
)
*
blk_m
,
:],
A_shared
)
T
.
copy
(
A
[
bx
*
blk_m
:
(
bx
+
1
)
*
blk_m
,
:],
A_shared
)
T
.
copy
(
A_shared
,
A_local
)
for
i
,
j
in
T
.
Parallel
(
blk_m
,
N
):
A_pow_local
[
i
,
j
]
=
A_local
[
i
,
j
]
*
A_local
[
i
,
j
]
...
...
@@ -54,7 +54,7 @@ def rms_norm(M, N, blk_m):
A_powsum
[
i
]
=
T
.
rsqrt
(
A_powsum
[
i
]
/
N
+
1e-12
)
for
i
,
j
in
T
.
Parallel
(
blk_m
,
N
):
A_local
[
i
,
j
]
*=
A_powsum
[
i
]
T
.
copy
(
A_local
,
B
[
bx
*
blk_m
:
(
bx
+
1
)
*
blk_m
,
:])
T
.
copy
(
A_local
,
B
[
bx
*
blk_m
:
(
bx
+
1
)
*
blk_m
,
:])
return
main
...
...
examples/norm/test_rms_norm.py
View file @
29051439
...
...
@@ -45,7 +45,7 @@ def rms_norm(M, N, blk_m):
A_local
=
T
.
alloc_fragment
((
blk_m
,
N
),
dtype
)
A_powsum
=
T
.
alloc_fragment
((
blk_m
,),
dtype
)
T
.
copy
(
A
[
bx
*
blk_m
:
(
bx
+
1
)
*
blk_m
,
:],
A_shared
)
T
.
copy
(
A
[
bx
*
blk_m
:
(
bx
+
1
)
*
blk_m
,
:],
A_shared
)
T
.
copy
(
A_shared
,
A_local
)
for
i
,
j
in
T
.
Parallel
(
blk_m
,
N
):
A_pow_local
[
i
,
j
]
=
A_local
[
i
,
j
]
*
A_local
[
i
,
j
]
...
...
@@ -54,7 +54,7 @@ def rms_norm(M, N, blk_m):
A_powsum
[
i
]
=
T
.
rsqrt
(
A_powsum
[
i
]
/
N
+
1e-12
)
for
i
,
j
in
T
.
Parallel
(
blk_m
,
N
):
A_local
[
i
,
j
]
*=
A_powsum
[
i
]
T
.
copy
(
A_local
,
B
[
bx
*
blk_m
:
(
bx
+
1
)
*
blk_m
,
:])
T
.
copy
(
A_local
,
B
[
bx
*
blk_m
:
(
bx
+
1
)
*
blk_m
,
:])
return
main
...
...
examples/online_softmax/online_softmax.py
View file @
29051439
...
...
@@ -33,7 +33,7 @@ def softmax_kernel(
T
.
fill
(
lse
,
-
T
.
infinity
(
accum_dtype
))
for
i_n
in
T
.
Pipelined
(
0
,
NN
):
T
.
copy
(
X
[
i_m
,
i_n
*
BN
:
(
i_n
+
1
)
*
BN
],
x
)
T
.
copy
(
X
[
i_m
,
i_n
*
BN
:
(
i_n
+
1
)
*
BN
],
x
)
T
.
reduce_max
(
x
,
max_x
,
dim
=
0
,
clear
=
True
)
...
...
@@ -45,12 +45,12 @@ def softmax_kernel(
lse
[
0
]
=
max_x
[
0
]
*
scale
+
T
.
log2
(
T
.
exp2
(
lse
[
0
]
-
max_x
[
0
]
*
scale
)
+
sum_exp_x
[
0
])
for
i_n
in
T
.
Pipelined
(
0
,
NN
):
T
.
copy
(
X
[
i_m
,
i_n
*
BN
:
(
i_n
+
1
)
*
BN
],
x
)
T
.
copy
(
X
[
i_m
,
i_n
*
BN
:
(
i_n
+
1
)
*
BN
],
x
)
for
j
in
T
.
Parallel
(
BN
):
y
[
j
]
=
T
.
exp2
(
x
[
j
]
*
scale
-
lse
[
0
])
T
.
copy
(
y
,
Y
[
i_m
,
i_n
*
BN
:
(
i_n
+
1
)
*
BN
])
T
.
copy
(
y
,
Y
[
i_m
,
i_n
*
BN
:
(
i_n
+
1
)
*
BN
])
return
main
...
...
@@ -69,4 +69,4 @@ t1 = do_bench(lambda: X.softmax(dim=1), warmup=25, rep=100)
t2
=
do_bench
(
lambda
:
kernel
(
X
),
warmup
=
25
,
rep
=
100
)
print
(
f
"torch latency:
{
t1
:.
3
f
}
ms"
)
print
(
f
"TileLang latency:
{
t2
:.
3
f
}
ms"
)
print
(
f
"Speedup:
{
t1
/
t2
:.
3
f
}
x"
)
print
(
f
"Speedup:
{
t1
/
t2
:.
3
f
}
x"
)
examples/plot_layout/fragment_mfma_load_a.py
View file @
29051439
...
...
@@ -11,10 +11,9 @@ from tilelang.intrinsics.mfma_layout import (
)
def
make_mfma_load_base_layout
(
dtype
:
str
=
"float16"
,
matrix
:
Literal
[
"A"
,
"B"
]
=
"A"
,
k_dim
:
int
=
16
,
transposed
:
bool
=
False
)
->
T
.
Fragment
:
def
make_mfma_load_base_layout
(
dtype
:
str
=
"float16"
,
matrix
:
Literal
[
"A"
,
"B"
]
=
"A"
,
k_dim
:
int
=
16
,
transposed
:
bool
=
False
)
->
T
.
Fragment
:
"""
Create a layout function for storing MFMA results into a fragment buffer.
This layout is used in conjunction with `inverse_mfma_store_layout` to
...
...
@@ -72,12 +71,10 @@ def make_mfma_load_base_layout(dtype: str = "float16",
# so the b matrix expected a transposed basic layout
transform_func
:
Callable
=
None
if
matrix
==
"A"
:
transform_func
=
transform_func_sr_a
if
is_sr_axis_order
else
lambda
i
,
j
:
transform_func_sr_a
(
j
,
i
)
transform_func
=
transform_func_sr_a
if
is_sr_axis_order
else
lambda
i
,
j
:
transform_func_sr_a
(
j
,
i
)
micro_size_s
,
micro_size_r
=
micro_size_x
,
micro_size_k
elif
matrix
==
"B"
:
transform_func
=
transform_func_sr_b
if
is_sr_axis_order
else
lambda
i
,
j
:
transform_func_sr_b
(
j
,
i
)
transform_func
=
transform_func_sr_b
if
is_sr_axis_order
else
lambda
i
,
j
:
transform_func_sr_b
(
j
,
i
)
micro_size_s
,
micro_size_r
=
micro_size_k
,
micro_size_y
else
:
raise
ValueError
(
f
"Unsupported matrix
{
matrix
}
"
)
...
...
@@ -120,14 +117,11 @@ print(base_layout)
plot_layout
(
base_layout
,
name
=
"base_layout"
)
# warp layout 32x32
warp_layout
=
base_layout
.
repeat
([
warp_rows
,
warp_cols
],
repeat_on_thread
=
False
,
lower_dim_first
=
False
)
warp_layout
=
base_layout
.
repeat
([
warp_rows
,
warp_cols
],
repeat_on_thread
=
False
,
lower_dim_first
=
False
)
print
(
warp_layout
)
plot_layout
(
warp_layout
,
name
=
"warp_layout"
)
# block layout 64x32
block_layout
=
warp_layout
.
repeat
([
block_rows
,
1
],
repeat_on_thread
=
True
,
lower_dim_first
=
True
).
replicate
(
block_cols
)
block_layout
=
warp_layout
.
repeat
([
block_rows
,
1
],
repeat_on_thread
=
True
,
lower_dim_first
=
True
).
replicate
(
block_cols
)
print
(
block_layout
)
plot_layout
(
block_layout
,
name
=
"block_layout"
)
examples/plot_layout/fragment_mma_load_a.py
View file @
29051439
...
...
@@ -5,9 +5,7 @@ from tvm.tir import IndexMap
from
tilelang.intrinsics.utils
import
get_mma_micro_size
def
make_mma_load_base_layout
(
dtype
:
str
=
"float16"
,
matrix
:
Literal
[
"A"
,
"B"
]
=
"A"
,
transposed
:
bool
=
False
)
->
T
.
Fragment
:
def
make_mma_load_base_layout
(
dtype
:
str
=
"float16"
,
matrix
:
Literal
[
"A"
,
"B"
]
=
"A"
,
transposed
:
bool
=
False
)
->
T
.
Fragment
:
"""
Create a layout function for storing MMA results into a fragment buffer.
This layout is used in conjunction with `inverse_mma_store_layout` to
...
...
@@ -36,6 +34,7 @@ def make_mma_load_base_layout(dtype: str = "float16",
shared_16x16_to_mma_32x8_layout_sr_b
,
shared_16x32_to_mma_32x16_layout_sr_b
,
)
assert
matrix
in
[
"A"
,
"B"
],
"matrix should be either A or B"
dtype_bits
=
DataType
(
dtype
).
bits
# s represents spatial axis
...
...
@@ -67,12 +66,10 @@ def make_mma_load_base_layout(dtype: str = "float16",
# so the b matrix expected a transposed basic layout
transform_func
:
Callable
=
None
if
matrix
==
"A"
:
transform_func
=
transform_func_sr_a
if
is_sr_axis_order
else
lambda
i
,
j
:
transform_func_sr_a
(
j
,
i
)
transform_func
=
transform_func_sr_a
if
is_sr_axis_order
else
lambda
i
,
j
:
transform_func_sr_a
(
j
,
i
)
micro_size_s
,
micro_size_r
=
micro_size_x
,
micro_size_k
elif
matrix
==
"B"
:
transform_func
=
transform_func_sr_b
if
is_sr_axis_order
else
lambda
i
,
j
:
transform_func_sr_b
(
j
,
i
)
transform_func
=
transform_func_sr_b
if
is_sr_axis_order
else
lambda
i
,
j
:
transform_func_sr_b
(
j
,
i
)
micro_size_s
,
micro_size_r
=
micro_size_k
,
micro_size_y
else
:
raise
ValueError
(
f
"Unsupported matrix
{
matrix
}
"
)
...
...
examples/quickstart.py
View file @
29051439
...
...
@@ -7,7 +7,6 @@ import tilelang.language as T
# if not specified, it will be inferred from the input tensors during compile time
@
tilelang
.
jit
def
matmul
(
M
,
N
,
K
,
block_M
,
block_N
,
block_K
,
dtype
=
"float16"
,
accum_dtype
=
"float"
):
@
T
.
prim_func
def
matmul_relu_kernel
(
A
:
T
.
Tensor
((
M
,
K
),
dtype
),
...
...
examples/seer_attention/block_sparse_attn_tilelang.py
View file @
29051439
...
...
@@ -10,10 +10,7 @@ def get_sparse_attn_mask_from_topk(x, topk, use_dense_for_last_block=False):
bsz
,
num_head
,
downsample_len
,
_
=
x
.
shape
# N_CTX = downsample_len * BLOCK
sparse_index
=
torch
.
topk
(
x
,
topk
,
dim
=-
1
).
indices
dense_mask
=
torch
.
full
([
bsz
,
num_head
,
downsample_len
,
downsample_len
],
False
,
dtype
=
torch
.
bool
,
device
=
x
.
device
)
dense_mask
=
torch
.
full
([
bsz
,
num_head
,
downsample_len
,
downsample_len
],
False
,
dtype
=
torch
.
bool
,
device
=
x
.
device
)
dense_mask
.
scatter_
(
-
1
,
sparse_index
,
True
)
if
use_dense_for_last_block
:
dense_mask
[:,
:,
-
2
:,
:]
=
True
...
...
@@ -30,15 +27,17 @@ def get_sparse_attn_mask_from_threshold(x, threshold, use_dense_for_last_block=F
@
tilelang
.
jit
(
out_idx
=
[
4
],
pass_configs
=
{
out_idx
=
[
4
],
pass_configs
=
{
tilelang
.
PassConfigKey
.
TL_ENABLE_FAST_MATH
:
True
,
})
},
)
def
blocksparse_flashattn
(
batch
,
heads
,
seq_q
,
seq_kv
,
dim
,
downsample_len
,
is_causal
):
block_M
=
64
block_N
=
64
num_stages
=
0
threads
=
128
scale
=
(
1.0
/
dim
)
**
0.5
*
1.44269504
# log2(e)
scale
=
(
1.0
/
dim
)
**
0.5
*
1.44269504
# log2(e)
q_shape
=
[
batch
,
heads
,
seq_q
,
dim
]
kv_shape
=
[
batch
,
heads
,
seq_kv
,
dim
]
block_mask_shape
=
[
batch
,
heads
,
downsample_len
,
downsample_len
]
...
...
@@ -48,7 +47,6 @@ def blocksparse_flashattn(batch, heads, seq_q, seq_kv, dim, downsample_len, is_c
block_mask_dtype
=
"int8"
def
kernel_func
(
block_M
,
block_N
,
num_stages
,
threads
):
@
T
.
macro
def
Softmax
(
acc_s
:
T
.
FragmentBuffer
([
block_M
,
block_N
],
accum_dtype
),
...
...
@@ -112,7 +110,7 @@ def blocksparse_flashattn(batch, heads, seq_q, seq_kv, dim, downsample_len, is_c
logsum
=
T
.
alloc_fragment
([
block_M
],
accum_dtype
)
block_mask
=
T
.
alloc_local
([
downsample_len
],
block_mask_dtype
)
T
.
copy
(
Q
[
bz
,
by
,
bx
*
block_M
:
(
bx
+
1
)
*
block_M
,
:],
Q_shared
)
T
.
copy
(
Q
[
bz
,
by
,
bx
*
block_M
:
(
bx
+
1
)
*
block_M
,
:],
Q_shared
)
T
.
fill
(
acc_o
,
0
)
T
.
fill
(
logsum
,
0
)
T
.
fill
(
scores_max
,
-
T
.
infinity
(
accum_dtype
))
...
...
@@ -124,33 +122,25 @@ def blocksparse_flashattn(batch, heads, seq_q, seq_kv, dim, downsample_len, is_c
for
k
in
T
.
Pipelined
(
loop_range
,
num_stages
=
num_stages
):
if
block_mask
[
k
]
!=
0
:
T
.
copy
(
K
[
bz
,
by
,
k
*
block_N
:
(
k
+
1
)
*
block_N
,
:],
K_shared
)
T
.
copy
(
K
[
bz
,
by
,
k
*
block_N
:
(
k
+
1
)
*
block_N
,
:],
K_shared
)
if
is_causal
:
past_len
=
seq_kv
-
seq_q
for
i
,
j
in
T
.
Parallel
(
block_M
,
block_N
):
acc_s
[
i
,
j
]
=
T
.
if_then_else
(
bx
*
block_M
+
i
+
past_len
>=
k
*
block_N
+
j
,
0
,
-
T
.
infinity
(
acc_s
.
dtype
))
acc_s
[
i
,
j
]
=
T
.
if_then_else
(
bx
*
block_M
+
i
+
past_len
>=
k
*
block_N
+
j
,
0
,
-
T
.
infinity
(
acc_s
.
dtype
))
else
:
T
.
clear
(
acc_s
)
T
.
gemm
(
Q_shared
,
K_shared
,
acc_s
,
transpose_B
=
True
,
policy
=
T
.
GemmWarpPolicy
.
FullRow
)
Softmax
(
acc_s
,
acc_s_cast
,
scores_max
,
scores_max_prev
,
scores_scale
,
scores_sum
,
logsum
)
T
.
gemm
(
Q_shared
,
K_shared
,
acc_s
,
transpose_B
=
True
,
policy
=
T
.
GemmWarpPolicy
.
FullRow
)
Softmax
(
acc_s
,
acc_s_cast
,
scores_max
,
scores_max_prev
,
scores_scale
,
scores_sum
,
logsum
)
Rescale
(
acc_o
,
scores_scale
)
T
.
copy
(
V
[
bz
,
by
,
k
*
block_N
:
(
k
+
1
)
*
block_N
,
:],
V_shared
)
T
.
copy
(
V
[
bz
,
by
,
k
*
block_N
:
(
k
+
1
)
*
block_N
,
:],
V_shared
)
T
.
gemm
(
acc_s_cast
,
V_shared
,
acc_o
,
policy
=
T
.
GemmWarpPolicy
.
FullRow
)
for
i
,
j
in
T
.
Parallel
(
block_M
,
dim
):
acc_o
[
i
,
j
]
/=
logsum
[
i
]
T
.
copy
(
acc_o
,
O_shared
)
T
.
copy
(
O_shared
,
Output
[
bz
,
by
,
bx
*
block_M
:
(
bx
+
1
)
*
block_M
,
:])
T
.
copy
(
O_shared
,
Output
[
bz
,
by
,
bx
*
block_M
:
(
bx
+
1
)
*
block_M
,
:])
return
main
...
...
@@ -165,44 +155,40 @@ def test_topk_sparse_attention():
torch
.
manual_seed
(
0
)
# Create inputs
q
=
torch
.
randn
(
BATCH
,
N_HEADS
,
SEQ_LEN
,
D_HEAD
,
device
=
'
cuda
'
,
dtype
=
torch
.
float16
)
k
=
torch
.
randn
(
BATCH
,
N_HEADS
,
SEQ_LEN
,
D_HEAD
,
device
=
'
cuda
'
,
dtype
=
torch
.
float16
)
v
=
torch
.
randn
(
BATCH
,
N_HEADS
,
SEQ_LEN
,
D_HEAD
,
device
=
'
cuda
'
,
dtype
=
torch
.
float16
)
q
=
torch
.
randn
(
BATCH
,
N_HEADS
,
SEQ_LEN
,
D_HEAD
,
device
=
"
cuda
"
,
dtype
=
torch
.
float16
)
k
=
torch
.
randn
(
BATCH
,
N_HEADS
,
SEQ_LEN
,
D_HEAD
,
device
=
"
cuda
"
,
dtype
=
torch
.
float16
)
v
=
torch
.
randn
(
BATCH
,
N_HEADS
,
SEQ_LEN
,
D_HEAD
,
device
=
"
cuda
"
,
dtype
=
torch
.
float16
)
sm_scale
=
1.0
/
(
D_HEAD
**
0.5
)
# Create sparse mask (downsampled to block level)
downsample_factor
=
BLOCK
downsample_len
=
math
.
ceil
(
SEQ_LEN
/
downsample_factor
)
x_ds
=
torch
.
randn
([
BATCH
,
N_HEADS
,
downsample_len
,
downsample_len
],
device
=
'cuda'
,
dtype
=
torch
.
float16
)
x_ds
=
torch
.
randn
([
BATCH
,
N_HEADS
,
downsample_len
,
downsample_len
],
device
=
"cuda"
,
dtype
=
torch
.
float16
)
x_ds
[:,
:,
:,
0
]
=
100
block_mask
=
get_sparse_attn_mask_from_topk
(
x_ds
,
topk
=
TOPK
)
# Run tilelang kernel
kernel
=
blocksparse_flashattn
(
BATCH
,
N_HEADS
,
SEQ_LEN
,
SEQ_LEN
,
D_HEAD
,
downsample_len
,
is_causal
=
True
)
kernel
=
blocksparse_flashattn
(
BATCH
,
N_HEADS
,
SEQ_LEN
,
SEQ_LEN
,
D_HEAD
,
downsample_len
,
is_causal
=
True
)
tilelang_output
=
kernel
(
q
,
k
,
v
,
block_mask
.
to
(
torch
.
int8
))
# Compute reference
# Expand block mask to full attention matrix
full_mask
=
torch
.
kron
(
block_mask
.
float
(),
torch
.
ones
(
BLOCK
,
BLOCK
,
device
=
'
cuda
'
))
full_mask
=
torch
.
kron
(
block_mask
.
float
(),
torch
.
ones
(
BLOCK
,
BLOCK
,
device
=
"
cuda
"
))
full_mask
=
full_mask
[...,
:
SEQ_LEN
,
:
SEQ_LEN
].
bool
()
full_mask
=
full_mask
&
torch
.
tril
(
torch
.
ones_like
(
full_mask
))
# Apply causal
# PyTorch reference implementation
attn
=
torch
.
einsum
(
'
bhsd,bhtd->bhst
'
,
q
,
k
)
*
sm_scale
attn
=
attn
.
masked_fill
(
~
full_mask
,
float
(
'
-inf
'
))
attn
=
torch
.
einsum
(
"
bhsd,bhtd->bhst
"
,
q
,
k
)
*
sm_scale
attn
=
attn
.
masked_fill
(
~
full_mask
,
float
(
"
-inf
"
))
attn
=
F
.
softmax
(
attn
,
dim
=-
1
)
ref_output
=
torch
.
einsum
(
'
bhst,bhtd->bhsd
'
,
attn
,
v
)
ref_output
=
torch
.
einsum
(
"
bhst,bhtd->bhsd
"
,
attn
,
v
)
print
(
"ref_output"
,
ref_output
)
print
(
"tilelang_output"
,
tilelang_output
)
# Verify accuracy
assert
torch
.
allclose
(
tilelang_output
,
ref_output
,
atol
=
1e-2
,
rtol
=
1e-2
),
\
"TileLang output doesn't match reference"
assert
torch
.
allclose
(
tilelang_output
,
ref_output
,
atol
=
1e-2
,
rtol
=
1e-2
),
"TileLang output doesn't match reference"
print
(
"Pass topk sparse attention test with qlen == klen"
)
...
...
@@ -215,42 +201,40 @@ def test_topk_sparse_attention_qlen_lt_klen():
torch
.
manual_seed
(
0
)
# Create inputs.
q
=
torch
.
randn
(
BATCH
,
N_HEADS
,
Q_LEN
,
D_HEAD
,
device
=
'
cuda
'
,
dtype
=
torch
.
float16
)
k
=
torch
.
randn
(
BATCH
,
N_HEADS
,
K_LEN
,
D_HEAD
,
device
=
'
cuda
'
,
dtype
=
torch
.
float16
)
v
=
torch
.
randn
(
BATCH
,
N_HEADS
,
K_LEN
,
D_HEAD
,
device
=
'
cuda
'
,
dtype
=
torch
.
float16
)
q
=
torch
.
randn
(
BATCH
,
N_HEADS
,
Q_LEN
,
D_HEAD
,
device
=
"
cuda
"
,
dtype
=
torch
.
float16
)
k
=
torch
.
randn
(
BATCH
,
N_HEADS
,
K_LEN
,
D_HEAD
,
device
=
"
cuda
"
,
dtype
=
torch
.
float16
)
v
=
torch
.
randn
(
BATCH
,
N_HEADS
,
K_LEN
,
D_HEAD
,
device
=
"
cuda
"
,
dtype
=
torch
.
float16
)
sm_scale
=
1.0
/
(
D_HEAD
**
0.5
)
downsample_factor
=
BLOCK
downsample_len
=
math
.
ceil
(
K_LEN
/
downsample_factor
)
# number of blocks along one dimension
x_ds
=
torch
.
randn
(
BATCH
,
N_HEADS
,
downsample_len
,
downsample_len
,
device
=
'cuda'
,
dtype
=
torch
.
float16
)
x_ds
=
torch
.
randn
(
BATCH
,
N_HEADS
,
downsample_len
,
downsample_len
,
device
=
"cuda"
,
dtype
=
torch
.
float16
)
# Force the first column to be high so that the first block is always selected.
x_ds
[:,
:,
:,
0
]
=
100
block_mask
=
get_sparse_attn_mask_from_topk
(
x_ds
,
topk
=
TOPK
)
kernel
=
blocksparse_flashattn
(
BATCH
,
N_HEADS
,
Q_LEN
,
K_LEN
,
D_HEAD
,
downsample_len
,
is_causal
=
True
)
kernel
=
blocksparse_flashattn
(
BATCH
,
N_HEADS
,
Q_LEN
,
K_LEN
,
D_HEAD
,
downsample_len
,
is_causal
=
True
)
print
(
kernel
.
get_kernel_source
())
tilelang_output
=
kernel
(
q
,
k
,
v
,
block_mask
.
to
(
torch
.
int8
))
past_len
=
K_LEN
-
Q_LEN
attn
=
torch
.
einsum
(
'
bhsd,bhtd->bhst
'
,
q
,
k
)
*
sm_scale
attn
=
torch
.
einsum
(
"
bhsd,bhtd->bhst
"
,
q
,
k
)
*
sm_scale
full_mask_full
=
torch
.
kron
(
block_mask
.
float
(),
torch
.
ones
(
BLOCK
,
BLOCK
,
device
=
'
cuda
'
)).
bool
()
full_mask_full
=
torch
.
kron
(
block_mask
.
float
(),
torch
.
ones
(
BLOCK
,
BLOCK
,
device
=
"
cuda
"
)).
bool
()
full_mask_full
=
full_mask_full
[...,
:
K_LEN
,
:
K_LEN
]
effective_mask
=
full_mask_full
[...,
past_len
:
K_LEN
,
:]
# shape: (B, H, Q_LEN, K_LEN)
i_global
=
torch
.
arange
(
past_len
,
K_LEN
,
device
=
k
.
device
).
unsqueeze
(
1
)
# shape: (Q_LEN, 1)
j_global
=
torch
.
arange
(
K_LEN
,
device
=
k
.
device
).
unsqueeze
(
0
)
# shape: (1, K_LEN)
causal_mask
=
(
j_global
<=
i_global
)
# shape: (Q_LEN, K_LEN)
causal_mask
=
j_global
<=
i_global
# shape: (Q_LEN, K_LEN)
final_mask
=
effective_mask
&
causal_mask
# shape: (B, H, Q_LEN, K_LEN)
attn
=
attn
.
masked_fill
(
~
final_mask
,
float
(
'
-inf
'
))
attn
=
attn
.
masked_fill
(
~
final_mask
,
float
(
"
-inf
"
))
attn
=
F
.
softmax
(
attn
,
dim
=-
1
)
ref_output
=
torch
.
einsum
(
'
bhst,bhtd->bhsd
'
,
attn
,
v
)
ref_output
=
torch
.
einsum
(
"
bhst,bhtd->bhsd
"
,
attn
,
v
)
print
(
"ref_output"
,
ref_output
)
print
(
"tilelang_output"
,
tilelang_output
)
...
...
examples/seer_attention/block_sparse_attn_triton.py
View file @
29051439
...
...
@@ -15,10 +15,7 @@ def get_sparse_attn_mask_from_topk(x, topk, use_dense_for_last_block=False):
bsz
,
num_head
,
downsample_len
,
_
=
x
.
shape
# N_CTX = downsample_len * BLOCK
sparse_index
=
torch
.
topk
(
x
,
topk
,
dim
=-
1
).
indices
dense_mask
=
torch
.
full
([
bsz
,
num_head
,
downsample_len
,
downsample_len
],
False
,
dtype
=
torch
.
bool
,
device
=
x
.
device
)
dense_mask
=
torch
.
full
([
bsz
,
num_head
,
downsample_len
,
downsample_len
],
False
,
dtype
=
torch
.
bool
,
device
=
x
.
device
)
dense_mask
.
scatter_
(
-
1
,
sparse_index
,
True
)
if
use_dense_for_last_block
:
dense_mask
[:,
:,
-
2
:,
:]
=
True
...
...
@@ -54,7 +51,6 @@ def _fwd_kernel_inner(
BLOCK_M
:
tl
.
constexpr
,
BLOCK_N
:
tl
.
constexpr
,
):
mask_val
=
tl
.
load
(
block_mask_ptr
+
k_block_col_idx
*
stride_bmask_n
)
if
mask_val
==
True
:
...
...
@@ -69,7 +65,7 @@ def _fwd_kernel_inner(
qk
*=
sm_scale
# the following is needed only when LAST_K_BLOCK or BLOCK_M < BLOCK_N
qk
+=
tl
.
where
(
offs_m
[:,
None
]
+
past_len
>=
(
start_n
+
offs_n
[
None
,
:]),
0
,
float
(
'
-inf
'
))
qk
+=
tl
.
where
(
offs_m
[:,
None
]
+
past_len
>=
(
start_n
+
offs_n
[
None
,
:]),
0
,
float
(
"
-inf
"
))
m_ij
=
tl
.
maximum
(
m_i
,
tl
.
max
(
qk
,
1
))
qk
-=
m_ij
[:,
None
]
...
...
@@ -149,7 +145,7 @@ def _fwd_kernel(
v_ptrs
=
V
+
off_v
mask_ptrs
=
block_mask_ptr
+
start_m
*
stride_bmm
m_i
=
tl
.
zeros
([
BLOCK_M
],
dtype
=
tl
.
float32
)
-
float
(
'
inf
'
)
m_i
=
tl
.
zeros
([
BLOCK_M
],
dtype
=
tl
.
float32
)
-
float
(
"
inf
"
)
l_i
=
tl
.
zeros
([
BLOCK_M
],
dtype
=
tl
.
float32
)
acc
=
tl
.
zeros
([
BLOCK_M
,
BLOCK_DMODEL
],
dtype
=
tl
.
float32
)
...
...
@@ -185,24 +181,12 @@ def _fwd_kernel(
acc
=
acc
*
l_recip
acc
=
acc
.
to
(
Out
.
dtype
.
element_ty
)
off_o
=
off_z
*
stride_oz
+
off_h
*
stride_oh
+
offs_m
[:,
None
]
*
stride_om
+
offs_d
[
None
,
:]
*
stride_od
off_o
=
off_z
*
stride_oz
+
off_h
*
stride_oh
+
offs_m
[:,
None
]
*
stride_om
+
offs_d
[
None
,
:]
*
stride_od
out_ptrs
=
Out
+
off_o
tl
.
store
(
out_ptrs
,
acc
,
mask
=
offs_m
[:,
None
]
<
N_CTX
)
def
_forward
(
ctx
,
q
,
k
,
v
,
block_sparse_mask
,
sm_scale
,
BLOCK_M
=
64
,
BLOCK_N
=
64
,
num_warps
=
None
,
num_stages
=
1
,
out
=
None
):
def
_forward
(
ctx
,
q
,
k
,
v
,
block_sparse_mask
,
sm_scale
,
BLOCK_M
=
64
,
BLOCK_N
=
64
,
num_warps
=
None
,
num_stages
=
1
,
out
=
None
):
assert
q
.
shape
[
-
1
]
==
k
.
shape
[
-
1
]
==
v
.
shape
[
-
1
]
assert
k
.
shape
[
2
]
==
v
.
shape
[
2
]
o
=
out
if
out
is
not
None
else
torch
.
empty_like
(
q
).
contiguous
()
...
...
@@ -247,7 +231,6 @@ def _forward(ctx,
class
_sparse_attention
(
torch
.
autograd
.
Function
):
@
staticmethod
def
forward
(
ctx
,
q
,
k
,
v
,
block_sparse_dense
,
sm_scale
):
# shape constraints
...
...
@@ -271,9 +254,9 @@ def test_topk_sparse_attention():
torch
.
manual_seed
(
0
)
# Create inputs
q
=
torch
.
randn
(
BATCH
,
N_HEADS
,
SEQ_LEN
,
D_HEAD
,
device
=
'
cuda
'
,
dtype
=
torch
.
bfloat16
)
k
=
torch
.
randn
(
BATCH
,
N_HEADS
,
SEQ_LEN
,
D_HEAD
,
device
=
'
cuda
'
,
dtype
=
torch
.
bfloat16
)
v
=
torch
.
randn
(
BATCH
,
N_HEADS
,
SEQ_LEN
,
D_HEAD
,
device
=
'
cuda
'
,
dtype
=
torch
.
bfloat16
)
q
=
torch
.
randn
(
BATCH
,
N_HEADS
,
SEQ_LEN
,
D_HEAD
,
device
=
"
cuda
"
,
dtype
=
torch
.
bfloat16
)
k
=
torch
.
randn
(
BATCH
,
N_HEADS
,
SEQ_LEN
,
D_HEAD
,
device
=
"
cuda
"
,
dtype
=
torch
.
bfloat16
)
v
=
torch
.
randn
(
BATCH
,
N_HEADS
,
SEQ_LEN
,
D_HEAD
,
device
=
"
cuda
"
,
dtype
=
torch
.
bfloat16
)
sm_scale
=
1.0
/
(
D_HEAD
**
0.5
)
# Create sparse mask (downsampled to block level)
...
...
@@ -281,9 +264,7 @@ def test_topk_sparse_attention():
downsample_len
=
math
.
ceil
(
SEQ_LEN
/
downsample_factor
)
print
(
"downsample_len"
,
downsample_len
)
x_ds
=
torch
.
randn
([
BATCH
,
N_HEADS
,
downsample_len
,
downsample_len
],
device
=
'cuda'
,
dtype
=
torch
.
bfloat16
)
x_ds
=
torch
.
randn
([
BATCH
,
N_HEADS
,
downsample_len
,
downsample_len
],
device
=
"cuda"
,
dtype
=
torch
.
bfloat16
)
x_ds
[:,
:,
:,
0
]
=
100
print
(
"x_ds.shape"
,
x_ds
.
shape
)
block_mask
=
get_sparse_attn_mask_from_topk
(
x_ds
,
topk
=
TOPK
)
...
...
@@ -295,22 +276,21 @@ def test_topk_sparse_attention():
# Compute reference
# Expand block mask to full attention matrix
full_mask
=
torch
.
kron
(
block_mask
.
float
(),
torch
.
ones
(
BLOCK
,
BLOCK
,
device
=
'
cuda
'
))
full_mask
=
torch
.
kron
(
block_mask
.
float
(),
torch
.
ones
(
BLOCK
,
BLOCK
,
device
=
"
cuda
"
))
full_mask
=
full_mask
[...,
:
SEQ_LEN
,
:
SEQ_LEN
].
bool
()
full_mask
=
full_mask
&
torch
.
tril
(
torch
.
ones_like
(
full_mask
))
# Apply causal
# PyTorch reference implementation
attn
=
torch
.
einsum
(
'
bhsd,bhtd->bhst
'
,
q
,
k
)
*
sm_scale
attn
=
attn
.
masked_fill
(
~
full_mask
,
float
(
'
-inf
'
))
attn
=
torch
.
einsum
(
"
bhsd,bhtd->bhst
"
,
q
,
k
)
*
sm_scale
attn
=
attn
.
masked_fill
(
~
full_mask
,
float
(
"
-inf
"
))
attn
=
F
.
softmax
(
attn
,
dim
=-
1
)
ref_output
=
torch
.
einsum
(
'
bhst,bhtd->bhsd
'
,
attn
,
v
)
ref_output
=
torch
.
einsum
(
"
bhst,bhtd->bhsd
"
,
attn
,
v
)
# print("ref_output", ref_output)
# print("triton_output", triton_output)
# Verify accuracy
assert
torch
.
allclose
(
triton_output
,
ref_output
,
atol
=
1e-2
,
rtol
=
1e-2
),
\
"Triton output doesn't match reference"
assert
torch
.
allclose
(
triton_output
,
ref_output
,
atol
=
1e-2
,
rtol
=
1e-2
),
"Triton output doesn't match reference"
print
(
"Pass topk sparse attention test with qlen == klen"
)
...
...
@@ -322,16 +302,15 @@ def test_topk_sparse_attention_qlt_kl():
torch
.
manual_seed
(
0
)
# Create inputs.
q
=
torch
.
randn
(
BATCH
,
N_HEADS
,
Q_LEN
,
D_HEAD
,
device
=
'
cuda
'
,
dtype
=
torch
.
bfloat16
)
k
=
torch
.
randn
(
BATCH
,
N_HEADS
,
K_LEN
,
D_HEAD
,
device
=
'
cuda
'
,
dtype
=
torch
.
bfloat16
)
v
=
torch
.
randn
(
BATCH
,
N_HEADS
,
K_LEN
,
D_HEAD
,
device
=
'
cuda
'
,
dtype
=
torch
.
bfloat16
)
q
=
torch
.
randn
(
BATCH
,
N_HEADS
,
Q_LEN
,
D_HEAD
,
device
=
"
cuda
"
,
dtype
=
torch
.
bfloat16
)
k
=
torch
.
randn
(
BATCH
,
N_HEADS
,
K_LEN
,
D_HEAD
,
device
=
"
cuda
"
,
dtype
=
torch
.
bfloat16
)
v
=
torch
.
randn
(
BATCH
,
N_HEADS
,
K_LEN
,
D_HEAD
,
device
=
"
cuda
"
,
dtype
=
torch
.
bfloat16
)
# softmax scale
sm_scale
=
1.0
/
(
D_HEAD
**
0.5
)
downsample_factor
=
BLOCK
downsample_len
=
math
.
ceil
(
K_LEN
/
downsample_factor
)
# number of blocks along one dimension
x_ds
=
torch
.
randn
(
BATCH
,
N_HEADS
,
downsample_len
,
downsample_len
,
device
=
'cuda'
,
dtype
=
torch
.
bfloat16
)
x_ds
=
torch
.
randn
(
BATCH
,
N_HEADS
,
downsample_len
,
downsample_len
,
device
=
"cuda"
,
dtype
=
torch
.
bfloat16
)
# Force the first column to be high so that the first block is always selected.
x_ds
[:,
:,
:,
0
]
=
100
block_mask
=
get_sparse_attn_mask_from_topk
(
x_ds
,
topk
=
TOPK
)
...
...
@@ -340,26 +319,25 @@ def test_topk_sparse_attention_qlt_kl():
past_len
=
K_LEN
-
Q_LEN
attn
=
torch
.
einsum
(
'
bhsd,bhtd->bhst
'
,
q
,
k
)
*
sm_scale
attn
=
torch
.
einsum
(
"
bhsd,bhtd->bhst
"
,
q
,
k
)
*
sm_scale
full_mask_full
=
torch
.
kron
(
block_mask
.
float
(),
torch
.
ones
(
BLOCK
,
BLOCK
,
device
=
'
cuda
'
)).
bool
()
full_mask_full
=
torch
.
kron
(
block_mask
.
float
(),
torch
.
ones
(
BLOCK
,
BLOCK
,
device
=
"
cuda
"
)).
bool
()
full_mask_full
=
full_mask_full
[...,
:
K_LEN
,
:
K_LEN
]
effective_mask
=
full_mask_full
[...,
past_len
:
K_LEN
,
:]
# shape: (B, H, Q_LEN, K_LEN)
i_global
=
torch
.
arange
(
past_len
,
K_LEN
,
device
=
k
.
device
).
unsqueeze
(
1
)
# shape: (Q_LEN, 1)
j_global
=
torch
.
arange
(
K_LEN
,
device
=
k
.
device
).
unsqueeze
(
0
)
# shape: (1, K_LEN)
causal_mask
=
(
j_global
<=
i_global
)
# shape: (Q_LEN, K_LEN)
causal_mask
=
j_global
<=
i_global
# shape: (Q_LEN, K_LEN)
final_mask
=
effective_mask
&
causal_mask
# shape: (B, H, Q_LEN, K_LEN)
attn
=
attn
.
masked_fill
(
~
final_mask
,
float
(
'
-inf
'
))
attn
=
attn
.
masked_fill
(
~
final_mask
,
float
(
"
-inf
"
))
attn
=
F
.
softmax
(
attn
,
dim
=-
1
)
ref_output
=
torch
.
einsum
(
'
bhst,bhtd->bhsd
'
,
attn
,
v
)
ref_output
=
torch
.
einsum
(
"
bhst,bhtd->bhsd
"
,
attn
,
v
)
# Verify accuracy.
assert
torch
.
allclose
(
triton_output
,
ref_output
,
atol
=
1e-2
,
rtol
=
1e-2
),
\
"Triton output doesn't match reference when qlen < klen"
assert
torch
.
allclose
(
triton_output
,
ref_output
,
atol
=
1e-2
,
rtol
=
1e-2
),
"Triton output doesn't match reference when qlen < klen"
print
(
"Pass topk sparse attention test with qlen < klen"
)
...
...
examples/sparse_tensorcore/tilelang_example_sparse_tensorcore.py
View file @
29051439
...
...
@@ -29,23 +29,21 @@ def matmul_sp(
@
T
.
prim_func
def
main
(
A_sparse
:
T
.
Tensor
(
A_sparse_shape
,
in_dtype
),
E
:
T
.
Tensor
((
M
,
K
//
8
),
'
uint8
'
),
E
:
T
.
Tensor
((
M
,
K
//
8
),
"
uint8
"
),
B
:
T
.
Tensor
(
B_shape
,
in_dtype
),
C
:
T
.
Tensor
((
M
,
N
),
out_dtype
),
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
threads
)
as
(
bx
,
by
):
A_shared
=
T
.
alloc_shared
(
A_shared_shape
,
in_dtype
)
B_shared
=
T
.
alloc_shared
(
B_shared_shape
,
in_dtype
)
E_shared
=
T
.
alloc_shared
((
block_M
,
block_K
//
8
),
'
uint8
'
)
E_shared
=
T
.
alloc_shared
((
block_M
,
block_K
//
8
),
"
uint8
"
)
C_local
=
T
.
alloc_fragment
((
block_M
,
block_N
),
accum_dtype
)
T
.
annotate_layout
({
E
:
make_cutlass_metadata_layout
(
E
,
mma_dtype
=
"float16"
,
arch
=
"9.0"
,
block_k
=
block_K
),
E_shared
:
make_cutlass_metadata_layout
(
E_shared
,
mma_dtype
=
"float16"
,
arch
=
"9.0"
,
block_k
=
block_K
),
})
T
.
annotate_layout
(
{
E
:
make_cutlass_metadata_layout
(
E
,
mma_dtype
=
"float16"
,
arch
=
"9.0"
,
block_k
=
block_K
),
E_shared
:
make_cutlass_metadata_layout
(
E_shared
,
mma_dtype
=
"float16"
,
arch
=
"9.0"
,
block_k
=
block_K
),
}
)
T
.
clear
(
C_local
)
for
k
in
T
.
Pipelined
(
T
.
ceildiv
(
K
,
block_K
),
num_stages
=
num_stages
):
T
.
copy
(
E
[
by
*
block_M
,
k
*
block_K
//
8
],
E_shared
)
...
...
@@ -57,7 +55,7 @@ def matmul_sp(
return
main
def
generate_2_to_4_sparse_tensor
(
shape
,
dtype
=
torch
.
float32
,
device
=
'
cpu
'
):
def
generate_2_to_4_sparse_tensor
(
shape
,
dtype
=
torch
.
float32
,
device
=
"
cpu
"
):
if
shape
[
-
1
]
%
4
!=
0
:
raise
ValueError
(
"Last dimension must be divisible by 4 for 2:4 sparsity."
)
...
...
@@ -102,9 +100,9 @@ def run_gemm_sp(
num_threads
,
)
A
=
generate_2_to_4_sparse_tensor
((
M
,
K
),
dtype
=
torch
.
float16
,
device
=
'
cuda
'
)
A
=
generate_2_to_4_sparse_tensor
((
M
,
K
),
dtype
=
torch
.
float16
,
device
=
"
cuda
"
)
A_sparse
,
E
=
compress_sm90
(
A
,
block_k
=
block_K
,
transposed
=
False
)
B
=
torch
.
randn
((
K
,
N
),
device
=
'
cuda
'
,
dtype
=
torch
.
float16
)
B
=
torch
.
randn
((
K
,
N
),
device
=
"
cuda
"
,
dtype
=
torch
.
float16
)
C_sp
=
kernel
(
A_sparse
,
E
,
B
).
half
()
C
=
torch
.
matmul
(
A
,
B
)
...
...
examples/topk/example_topk.py
View file @
29051439
...
...
@@ -43,15 +43,12 @@ def tl_topk(
T
.
reduce_max
(
logits_frag
,
max_val
,
dim
=
1
,
clear
=
True
)
for
i
,
j
in
T
.
Parallel
(
blk_m
,
N
):
expand_max_idx
[
i
,
j
]
=
T
.
if_then_else
(
max_val
[
i
]
==
logits_frag
[
i
,
j
],
j
,
expand_max_idx
[
i
,
j
])
expand_max_idx
[
i
,
j
]
=
T
.
if_then_else
(
max_val
[
i
]
==
logits_frag
[
i
,
j
],
j
,
expand_max_idx
[
i
,
j
])
T
.
reduce_max
(
expand_max_idx
,
max_idx
,
dim
=
1
,
clear
=
True
)
for
i
,
j
in
T
.
Parallel
(
blk_m
,
N
):
logits_frag
[
i
,
j
]
=
T
.
if_then_else
(
max_val
[
i
]
==
logits_frag
[
i
,
j
],
-
10000.0
,
logits_frag
[
i
,
j
])
logits_frag
[
i
,
j
]
=
T
.
if_then_else
(
max_val
[
i
]
==
logits_frag
[
i
,
j
],
-
10000.0
,
logits_frag
[
i
,
j
])
for
i
in
T
.
Parallel
(
blk_m
):
topk_gates
[
bx
*
blk_m
+
i
,
k
]
=
max_val
[
i
]
...
...
@@ -61,7 +58,6 @@ def tl_topk(
def
ref_program
(
logits
,
top_k
):
top_k_gates
,
top_k_indices
=
logits
.
topk
(
top_k
,
dim
=
1
)
return
top_k_gates
,
top_k_indices
.
to
(
torch
.
int32
)
...
...
examples/visual_layout_inference/visual_layout_inference.py
View file @
29051439
...
...
@@ -7,10 +7,10 @@ import tilelang.language as T
out_idx
=
[
-
1
],
pass_configs
=
{
tilelang
.
PassConfigKey
.
TL_LAYOUT_VISUALIZATION_ENABLE
:
True
,
tilelang
.
PassConfigKey
.
TL_LAYOUT_VISUALIZATION_FORMATS
:
"svg"
})
tilelang
.
PassConfigKey
.
TL_LAYOUT_VISUALIZATION_FORMATS
:
"svg"
,
},
)
def
matmul
(
M
,
N
,
K
,
block_M
,
block_N
,
block_K
,
dtype
=
"float16"
,
accum_dtype
=
"float"
):
@
T
.
prim_func
def
gemm
(
A
:
T
.
Tensor
((
M
,
K
),
dtype
),
...
...
@@ -49,12 +49,12 @@ def main():
print
(
"All check passed."
)
# print the layout visualization result and save figures to ./tmp.
'''
"""
C_local inferenced layout:
Shape: [32, 32] -> [8]
Thread: _j // 16 * 64 + _i // 16 * 32 + _i % 8 * 4 + _j % 8 // 2
Index: [_j % 16 // 8 * 4 + _i % 16 // 8 * 2 + _j % 2]
'''
"""
if
__name__
==
"__main__"
:
...
...
Prev
1
…
5
6
7
8
9
10
11
12
13
…
24
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment