OpenDAS / tilelang

Commit 29051439 (unverified)
Authored Dec 12, 2025 by Lei Wang; committed by GitHub on Dec 12, 2025
Parent: e84b24bc

[Lint] Phaseout Yapf format and embrace ruff format (#1417)

Changes: 467
Showing 20 changed files with 627 additions and 647 deletions (+627 -647)
examples/hadamard_transform/example_hadamard.py  +15 -20
examples/lazy_jit/lazyjit.en.ipynb  +41 -31
examples/lazy_jit/lazyjit.zh.ipynb  +41 -31
examples/linear_attention/example_linear_attn_bwd.py  +55 -74
examples/linear_attention/example_linear_attn_fwd.py  +40 -46
examples/linear_attention/example_mamba_chunk_scan.py  +110 -79
examples/linear_attention/example_mamba_chunk_state.py  +57 -62
examples/linear_attention/example_retention_fwd.py  +22 -27
examples/minference/example_vertical_slash_sparse_attn.py  +123 -100
examples/norm/rms_norm.py  +2 -2
examples/norm/test_rms_norm.py  +2 -2
examples/online_softmax/online_softmax.py  +6 -6
examples/plot_layout/fragment_mfma_load_a.py  +7 -13
examples/plot_layout/fragment_mma_load_a.py  +4 -7
examples/quickstart.py  +3 -4
examples/seer_attention/block_sparse_attn_tilelang.py  +48 -64
examples/seer_attention/block_sparse_attn_triton.py  +24 -46
examples/sparse_tensorcore/tilelang_example_sparse_tensorcore.py  +14 -16
examples/topk/example_topk.py  +5 -9
examples/visual_layout_inference/visual_layout_inference.py  +8 -8
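The commit is mechanical: yapf is dropped and the files are rewritten by ruff's formatter, so the hunks below repeat the same few style changes (double quotes, spaces around ":" in non-trivial slices and around "**", trailing commas and exploded wrapping of long calls, two blank lines around top-level definitions, "if (not x):" becoming "if not x:"). The following snippet is not from the repository; it is only an illustration of those patterns, runnable on its own, with comments noting the pre-ruff forms.

    # Illustrative sketch of the formatting patterns seen in the hunks below.
    def describe(
        name: str = "tilelang",   # was 'tilelang' under the old style
        base: int = 2,
    ) -> str:
        scale = (1.0 / base) ** 0.5                   # was (1.0 / base)**0.5
        items = list(range(16))
        block = 4
        chunk = items[1 * block : (1 + 1) * block]    # was items[1 * block:(1 + 1) * block]
        return f"{name}: scale={scale:.3f}, chunk={chunk}"

    print(describe())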
examples/hadamard_transform/example_hadamard.py  (+15 -20)  View file @ 29051439

@@ -17,7 +17,7 @@ def is_pow_of_2(n):
 def hadamard(b, n, dtype):
     assert is_pow_of_2(n), "n must be a power of 2"
     assert 2 <= n <= 32768, "n must be in [2, 32768]"
-    elem_size = {'float32': 4, 'float16': 2, 'bfloat16': 2}[dtype]
+    elem_size = {"float32": 4, "float16": 2, "bfloat16": 2}[dtype]
     logN = int(math.log2(n))
     threads = [0, 1, 1, 1, 2, 4, 8, 16, 32, 32, 128, 256, 256, 256, 256, 256][logN]

@@ -40,23 +40,21 @@ def hadamard(b, n, dtype):
     # print(f'{exchange_round=}')

     @T.macro
     def warp_shfl(local: T.Tensor((thread_elem,), dtype), buf: T.Tensor((thread_elem,), dtype), round: int):
         tx = T.get_thread_binding(0)
         for i in T.serial(round):
             tx_stride = 1 << i
             another_tx = tx ^ tx_stride
             sign = (tx >> i) & 1  # get i-th lowest bit of tx, which determines the operation type for shared[tx, :]
             for j in T.Pipelined(thread_elem, num_stages=1):
                 buf[j] = T.tvm_warp_shuffle(
-                    0xffffffff,  # mask of all threads
+                    0xFFFFFFFF,  # mask of all threads
                     local[j],
                     another_tx % warp_size,
                     warp_size,
-                    warp_size)
+                    warp_size,
+                )
                 local[j] = T.if_then_else(sign == 0, local[j] + buf[j], buf[j] - local[j])

     @T.prim_func

@@ -78,10 +76,8 @@ def hadamard(b, n, dtype):
             for j in T.serial(chunknum):
                 chunkbase = j * chunksize
                 for k in T.serial(chunksize // 2):
                     local[chunkbase + k] = local[chunkbase + k] + local[chunkbase + k + chunksize // 2]
                     local[chunkbase + k + chunksize // 2] = local[chunkbase + k] - 2 * local[chunkbase + k + chunksize // 2]
             # 3. Hadamard inside warp, n<=512
             # In warp level, we rely on warp shuffle to exchange data inside each warp, without using shared memory

@@ -131,28 +127,27 @@ def ref_program(x: torch.Tensor):
     assert x.ndim == 2
     dim = x.shape[-1]
     assert is_pow_of_2(dim)
     return F.linear(x, torch.tensor(scipy.linalg.hadamard(dim, dtype=float), dtype=x.dtype, device=x.device))


 def main():
     parser = argparse.ArgumentParser()
-    parser.add_argument('--batch', type=int, default=64, help='Batch size')
-    parser.add_argument('--dim', type=int, default=32768, help='Dimension')
+    parser.add_argument("--batch", type=int, default=64, help="Batch size")
+    parser.add_argument("--dim", type=int, default=32768, help="Dimension")
     args = parser.parse_args()
     B, D = args.batch, args.dim

-    x = torch.randn((B, D), device='cuda')
-    kernel = hadamard(B, D, 'float32')
+    x = torch.randn((B, D), device="cuda")
+    kernel = hadamard(B, D, "float32")
     y = kernel(x)
     y_ref = ref_program(x)
     torch.testing.assert_close(y, y_ref, atol=1e-2, rtol=1e-2)
-    print('All tests passed.')
+    print("All tests passed.")

     profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Auto)
     latency = profiler.do_bench(warmup=100)
     print("Tile-lang: {:.2f} ms".format(latency))


-if __name__ == '__main__':
+if __name__ == "__main__":
     main()
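The warp_shfl hunk above is formatting-only, but the butterfly it performs is worth spelling out: in round i a thread exchanges values with the lane whose index differs in bit i, and sign = (tx >> i) & 1 decides whether the thread keeps the sum or the difference. The sketch below simulates that same butterfly on a plain array; it is illustrative only, not part of the commit, and assumes NumPy and SciPy are available (the example's ref_program already uses scipy.linalg.hadamard).

    import numpy as np
    import scipy.linalg


    def fwht_butterfly(x):
        # Same add/sub pairing as the kernel: the "low" element of each pair keeps
        # a + b (sign == 0), the "high" element keeps a - b (sign == 1).
        x = np.asarray(x, dtype=np.float64).copy()
        n = len(x)
        step = 1
        while step < n:
            for base in range(0, n, 2 * step):
                for k in range(base, base + step):
                    a, b = x[k], x[k + step]
                    x[k], x[k + step] = a + b, a - b
            step *= 2
        return x


    x = np.random.randn(16)
    assert np.allclose(fwht_butterfly(x), scipy.linalg.hadamard(16) @ x)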
examples/lazy_jit/lazyjit.en.ipynb  (+41 -31)  View file @ 29051439

@@ -9,6 +9,7 @@
 "import sys\n",
 "from pathlib import Path\n",
+"\n",
 "sys.path.insert(0, str(Path.cwd().parent.parent.absolute()))\n",
 "import tilelang\n",
 "import torch\n",

@@ -61,7 +62,7 @@
 "    out_dtype: T.dtype = T.float32,\n",
 "    block_M: int = 128,\n",
 "    block_N: int = 128,\n",
-"    block_K: int = 32\n",
+"    block_K: int = 32,\n",
 "):\n",
 "    M, K = A.shape\n",
 "    K, N = B.shape\n",

@@ -94,8 +95,8 @@, @@ -118,8 +119,8 @@, @@ -265,8 +266,8 @@, @@ -318,7 +318,7 @@, @@ -370,8 +370,8 @@, @@ -416,8 +416,8 @@, @@ -636,7 +642,7 @@, @@ -694,7 +701,7 @@
In each of these cells the torch.randn / torch.tensor lines keep their shapes and dtypes and only switch the quoting of the device string, for example:
-"A = torch.randn(1024, 512, dtype=torch.float16, device='cuda')\n",
+"A = torch.randn(1024, 512, dtype=torch.float16, device=\"cuda\")\n",

@@ -218,8 +219,8 @@
 "@tilelang.lazy_jit\n",
 "def gemm_dyn_K(\n",
-"    A: T.Tensor[[int, T.dyn['K']], T.float16],  # noqa: F821\n",
-"    B: T.Tensor[[T.dyn['K'], int], T.float16],  # noqa: F821\n",
+"    A: T.Tensor[[int, T.dyn[\"K\"]], T.float16],  # noqa: F821\n",
+"    B: T.Tensor[[T.dyn[\"K\"], int], T.float16],  # noqa: F821\n",
 "):\n",

@@ -295,18 +296,17 @@
 "from typing import Any\n",
 "\n",
+"\n",
 "@tilelang.lazy_jit\n",
-"def as_contingious(\n",
-"    A: T.StridedTensor[[T.dyn, T.dyn], [T.dyn, T.dyn], Any]\n",
-"):\n",
+"def as_contingious(A: T.StridedTensor[[T.dyn, T.dyn], [T.dyn, T.dyn], Any]):\n",
 "    M, N = A.shape\n",
 "    B = T.empty((M, N), A.dtype)\n",
 "    block_M = 128\n",
 "    block_N = 128\n",
 "    with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), threads=128) as (bx, by):\n",
 "        T.copy(\n",
-"            A[bx * block_M: (bx + 1) * block_M, by * block_N: (by + 1) * block_N],\n",
-"            B[bx * block_M: (bx + 1) * block_M, by * block_N: (by + 1) * block_N]\n",
+"            A[bx * block_M : (bx + 1) * block_M, by * block_N : (by + 1) * block_N],\n",
+"            B[bx * block_M : (bx + 1) * block_M, by * block_N : (by + 1) * block_N],\n",
 "        )\n",
 "    return B"

@@ -496,18 +496,20 @@
 "from itertools import product\n",
 "\n",
+"\n",
 "def get_configs():\n",
 "    return [\n",
 "        {\n",
-"            'A': T.Tensor((1024, 1024), T.float32),\n",
-"            'B': T.Tensor((1024, 1024), T.float32),\n",
-"            'block_M': block_M,\n",
-"            'block_N': block_N,\n",
-"            'block_K': block_K,\n",
+"            \"A\": T.Tensor((1024, 1024), T.float32),\n",
+"            \"B\": T.Tensor((1024, 1024), T.float32),\n",
+"            \"block_M\": block_M,\n",
+"            \"block_N\": block_N,\n",
+"            \"block_K\": block_K,\n",
 "        }\n",
 "        for block_M, block_N, block_K in product([32, 64], repeat=3)\n",
 "    ]\n",
 "\n",
+"\n",
 "gemm.par_compile(get_configs())"

@@ -581,6 +583,7 @@
 "def macro_with_ref(x: T.Ref):\n",
 "    x = 1  # noqa: F841\n",
 "\n",
+"\n",
 "@T.prim_func\n",
 "def foo(x: T.Tensor((2,))):\n",
 "    with T.Kernel(1) as _:\n",

@@ -591,6 +594,7 @@
 "        idx = T.alloc_var(T.int32, 0)\n",
 "        macro_with_ref(x[idx])\n",
 "\n",
+"\n",
 "foo"

@@ -616,7 +620,7 @@
 "    A: T.Tensor[[T.dyn], Any],\n",
 "    fn,\n",
 "):\n",
-"    N, = A.shape\n",
+"    (N,) = A.shape\n",
 "    B = T.empty((N,), dtype=A.dtype)\n",
 "    block_N = 128\n",
 "    with T.Kernel(T.ceildiv(N, block_N), threads=128) as bx:\n",

@@ -624,6 +628,8 @@
 "            idx = bx * block_N + i\n",
 "            B[idx] = fn(A[idx])\n",
 "    return B\n",
+"\n",
+"\n",
 "@T.macro\n",
 "def add_one(x):\n",
 "    return x + 1"

@@ -670,10 +676,11 @@
 "        var = var * 3 + 1\n",
 "        n31(x * 3 + 1, var)\n",
 "\n",
+"\n",
 "@tilelang.lazy_jit\n",
 "def foo(A: T.Tensor[[1], T.int32], n: int):\n",
 "    with T.Kernel(1) as _:\n",
-"        n31(n, A[0])\n"
+"        n31(n, A[0])"

@@ -745,12 +752,15 @@
 "def sincos(x):\n",
 "    return T.sin(x), T.cos(x)\n",
 "\n",
+"\n",
 "@T.prim_func\n",
 "def foo():\n",
 "    with T.Kernel(32) as x:\n",
 "        s, c = sincos(x)\n",
 "        a = s + c  # noqa: F841\n",
 "        b = s - c  # noqa: F841\n",
+"\n",
+"\n",
 "foo"
examples/lazy_jit/lazyjit.zh.ipynb  (+41 -31)  View file @ 29051439

The code cells of the Chinese notebook receive the same changes, hunk for hunk, as lazyjit.en.ipynb above: an added blank line after the pathlib import, a trailing comma after block_K: int = 32, escaped double quotes replacing single quotes in device strings and in the T.dyn / config dictionary keys, the as_contingious signature collapsed onto one line with spaces added around the slice colons, (N,) = A.shape replacing N, = A.shape, blank lines added around top-level definitions, and the trailing "\n" removed from the final n31(n, A[0]) line.
examples/linear_attention/example_linear_attn_bwd.py  (+55 -74)  View file @ 29051439

@@ -13,20 +13,20 @@ from typing import Optional, Tuple
     pass_configs={
         tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
         tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
-    })
+    }
+)
 def tl_fused_chunk_bwd_kernel(
     B,
     S,
     H,
     DK,
     DV,
-    dtype: str = 'float16',
+    dtype: str = "float16",
     scale: float = None,
 ) -> torch.Tensor:
     if scale is None:
         scale = DK**-0.5
-    accum_dtype = 'float'
+    accum_dtype = "float"
     chunk_size = 64
     BK = BV = 64  # Set to 128 can be faster, but has some numerical differences with FLA

@@ -66,11 +66,13 @@ def tl_fused_chunk_bwd_kernel(
         dh = T.alloc_fragment([BK, BV], accum_dtype)
         dh_shared = T.alloc_shared([BK, BV], dtype)
-        T.annotate_layout({
-            dq_shared: tilelang.layout.make_swizzled_layout(dq_shared),
-            dk_shared: tilelang.layout.make_swizzled_layout(dk_shared),
-            dv_shared: tilelang.layout.make_swizzled_layout(dv_shared)
-        })
+        T.annotate_layout(
+            {
+                dq_shared: tilelang.layout.make_swizzled_layout(dq_shared),
+                dk_shared: tilelang.layout.make_swizzled_layout(dk_shared),
+                dv_shared: tilelang.layout.make_swizzled_layout(dv_shared),
+            }
+        )
         T.use_swizzle(10)
         T.clear(h)

@@ -78,10 +80,9 @@ def tl_fused_chunk_bwd_kernel(
         # Calculate dQ
         for i in T.Pipelined(0, NT):
-            T.copy(K[i_b, i * chunk_size:(i + 1) * chunk_size, i_h, i_k * BK:(i_k + 1) * BK], k)
-            T.copy(V[i_b, i * chunk_size:(i + 1) * chunk_size, i_h, i_v * BV:(i_v + 1) * BV], v)
-            T.copy(dO[i_b, i * chunk_size:(i + 1) * chunk_size, i_h, i_v * BV:(i_v + 1) * BV], do)
+            T.copy(K[i_b, i * chunk_size : (i + 1) * chunk_size, i_h, i_k * BK : (i_k + 1) * BK], k)
+            T.copy(V[i_b, i * chunk_size : (i + 1) * chunk_size, i_h, i_v * BV : (i_v + 1) * BV], v)
+            T.copy(dO[i_b, i * chunk_size : (i + 1) * chunk_size, i_h, i_v * BV : (i_v + 1) * BV], do)
             T.gemm(do, v, ds, transpose_B=True, clear_accum=True)
             for row, col in T.Parallel(chunk_size, chunk_size):

@@ -94,29 +95,19 @@ def tl_fused_chunk_bwd_kernel(
             for row, col in T.Parallel(chunk_size, BK):
                 dq[row, col] *= scale
             T.copy(dq, dq_shared)
-            T.atomic_add(dQ[i_b, i * chunk_size:(i + 1) * chunk_size, i_h, i_k * BK:(i_k + 1) * BK], dq_shared)
+            T.atomic_add(dQ[i_b, i * chunk_size : (i + 1) * chunk_size, i_h, i_k * BK : (i_k + 1) * BK], dq_shared)

         # Calculate dK, dV (reversely)
         for i in T.Pipelined(1, NT + 1):
             start = NT - i
             for row, col in T.Parallel(chunk_size, BK):
                 q[row, col] = Q[i_b, start * chunk_size + row, i_h, i_k * BK + col] * scale
-            T.copy(K[i_b, start * chunk_size:(start + 1) * chunk_size, i_h, i_k * BK:(i_k + 1) * BK], k)
-            T.copy(V[i_b, start * chunk_size:(start + 1) * chunk_size, i_h, i_v * BV:(i_v + 1) * BV], v)
-            T.copy(dO[i_b, start * chunk_size:(start + 1) * chunk_size, i_h, i_v * BV:(i_v + 1) * BV], do)
+            T.copy(K[i_b, start * chunk_size : (start + 1) * chunk_size, i_h, i_k * BK : (i_k + 1) * BK], k)
+            T.copy(V[i_b, start * chunk_size : (start + 1) * chunk_size, i_h, i_v * BV : (i_v + 1) * BV], v)
+            T.copy(dO[i_b, start * chunk_size : (start + 1) * chunk_size, i_h, i_v * BV : (i_v + 1) * BV], do)
             # Calculate dk
             T.gemm(v, do, ds, transpose_B=True, clear_accum=True)  # ds here actually means `s`, but we simply reuse the buffer `ds`
             for row, col in T.Parallel(chunk_size, chunk_size):
                 ds_shared[row, col] = T.if_then_else(row <= col, ds[row, col], 0)
             T.gemm(ds_shared, q, dk, clear_accum=True)

@@ -134,13 +125,9 @@ def tl_fused_chunk_bwd_kernel(
             T.gemm(q, do, dh, transpose_A=True)
             T.copy(dk, dk_shared)
-            T.atomic_add(dK[i_b, start * chunk_size:(start + 1) * chunk_size, i_h, i_k * BK:(i_k + 1) * BK], dk_shared)
+            T.atomic_add(dK[i_b, start * chunk_size : (start + 1) * chunk_size, i_h, i_k * BK : (i_k + 1) * BK], dk_shared)
             T.copy(dv, dv_shared)
-            T.atomic_add(dV[i_b, start * chunk_size:(start + 1) * chunk_size, i_h, i_v * BV:(i_v + 1) * BV], dv_shared)
+            T.atomic_add(dV[i_b, start * chunk_size : (start + 1) * chunk_size, i_h, i_v * BV : (i_v + 1) * BV], dv_shared)

     return fused_chunk_linear_attn_bwd

@@ -155,34 +142,31 @@ def tl_fused_chunk_bwd(Q, K, V, dO):
     return dQ.to(torch.float16), dK.to(torch.float16), dV.to(torch.float16)


 def ref_program(
     q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, scale: Optional[float] = None
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     q, k, v = q.float(), k.float(), v.float()
     if scale is None:
-        scale = q.shape[-1]**-0.5
+        scale = q.shape[-1] ** -0.5
     chunk_size = 64
-    q = rearrange(q, 'b (n c) h d -> b h n c d', c=chunk_size) * scale
-    k = rearrange(k, 'b (n c) h d -> b h n c d', c=chunk_size)
-    v = rearrange(v, 'b (n c) h d -> b h n c d', c=chunk_size)
+    q = rearrange(q, "b (n c) h d -> b h n c d", c=chunk_size) * scale
+    k = rearrange(k, "b (n c) h d -> b h n c d", c=chunk_size)
+    v = rearrange(v, "b (n c) h d -> b h n c d", c=chunk_size)
     kv = k.transpose(-1, -2) @ v
     kv = kv.cumsum(2)
     h = kv[:, :, -1, :, :]
     kv = torch.cat([torch.zeros_like(kv[:, :, :1]), kv[:, :, :-1]], dim=2)
     inter = q @ kv
     intra = (
         (q @ k.transpose(-1, -2)).masked_fill_(
             torch.triu(torch.ones(chunk_size, chunk_size, dtype=bool, device=q.device), diagonal=1), 0
         )
     ) @ v
     o = inter + intra
-    return rearrange(o, 'b h n c d -> b (n c) h d'), h
+    return rearrange(o, "b h n c d -> b (n c) h d"), h


 def main(B=1, S=1024, H=16, D=128):
-    q = torch.randn((B, S, H, D), device='cuda', dtype=torch.float16, requires_grad=True)
-    k = torch.randn((B, S, H, D), device='cuda', dtype=torch.float16, requires_grad=True)
-    v = torch.randn((B, S, H, D), device='cuda', dtype=torch.float16, requires_grad=True)
-    do = torch.randn((B, S, H, D), device='cuda', dtype=torch.float16)
+    q = torch.randn((B, S, H, D), device="cuda", dtype=torch.float16, requires_grad=True)
+    k = torch.randn((B, S, H, D), device="cuda", dtype=torch.float16, requires_grad=True)
+    v = torch.randn((B, S, H, D), device="cuda", dtype=torch.float16, requires_grad=True)
+    do = torch.randn((B, S, H, D), device="cuda", dtype=torch.float16)
     # qk norm is necessary for linear attn
     q = l2norm_fwd(q)[0].requires_grad_(True)

@@ -193,30 +177,27 @@ def main(B=1, S=1024, H=16, D=128):
     o_ref, _ = ref_program(q, k, v)
     o_ref.backward(do, retain_graph=True)
-    assert torch.allclose(dq, q.grad, atol=1e-2, rtol=1e-2), f'dq max err: {(dq - q.grad).abs().max()}'
-    assert torch.allclose(dk, k.grad, atol=1e-2, rtol=1e-2), f'dk max err: {(dk - k.grad).abs().max()}'
-    assert torch.allclose(dv, v.grad, atol=1e-2, rtol=1e-2), f'dv max err: {(dv - v.grad).abs().max()}'
-    print('Passed all tests!✅')
+    assert torch.allclose(dq, q.grad, atol=1e-2, rtol=1e-2), f"dq max err: {(dq - q.grad).abs().max()}"
+    assert torch.allclose(dk, k.grad, atol=1e-2, rtol=1e-2), f"dk max err: {(dk - k.grad).abs().max()}"
+    assert torch.allclose(dv, v.grad, atol=1e-2, rtol=1e-2), f"dv max err: {(dv - v.grad).abs().max()}"
+    print("Passed all tests!✅")

     # Benchmark
     q.grad = k.grad = v.grad = None
     o_ref, _ = fused_chunk_linear_attn(q, k, v, output_final_state=True, normalize=False)
-    t1 = do_bench(lambda: o_ref.backward(do, retain_graph=True), backend='cupti')
-    t2 = do_bench(lambda: tl_fused_chunk_bwd(q, k, v, do), backend='cupti')
-    print(f'Triton latency: {t1:.3f} ms')
-    print(f'TileLang latency: {t2:.3f} ms')
-    print(f'Speedup: {t1 / t2:.3f}x')
+    t1 = do_bench(lambda: o_ref.backward(do, retain_graph=True), backend="cupti")
+    t2 = do_bench(lambda: tl_fused_chunk_bwd(q, k, v, do), backend="cupti")
+    print(f"Triton latency: {t1:.3f} ms")
+    print(f"TileLang latency: {t2:.3f} ms")
+    print(f"Speedup: {t1 / t2:.3f}x")


-if __name__ == '__main__':
+if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument('--B', type=int, default=8, help='Batch size')
-    parser.add_argument('--S', type=int, default=1024, help='Seq len')
-    parser.add_argument('--H', type=int, default=32, help='Num heads')
-    parser.add_argument('--D', type=int, default=128, help='Head dim')
+    parser.add_argument("--B", type=int, default=8, help="Batch size")
+    parser.add_argument("--S", type=int, default=1024, help="Seq len")
+    parser.add_argument("--H", type=int, default=32, help="Num heads")
+    parser.add_argument("--D", type=int, default=128, help="Head dim")
     args = parser.parse_args()
     main(args.B, args.S, args.H, args.D)
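ref_program in this file (and the forward example's) evaluates linear attention chunk by chunk: an inter-chunk term against the running cumulative k-transpose-v state plus an intra-chunk term through a causally masked q k-transpose, which is the same split the TileLang kernel implements. The following self-contained check of that identity is illustrative only, not part of the commit; scaling is omitted for brevity.

    import torch


    def causal_linear_attn(q, k, v):
        # o_t = q_t @ sum_{s <= t} k_s v_s^T  (no softmax, no scaling)
        kv = torch.zeros(q.shape[-1], v.shape[-1], dtype=q.dtype)
        out = []
        for t in range(q.shape[0]):
            kv = kv + torch.outer(k[t], v[t])
            out.append(q[t] @ kv)
        return torch.stack(out)


    S, D, C = 128, 16, 32                     # sequence length, head dim, chunk size
    q, k, v = (torch.randn(S, D, dtype=torch.float64) for _ in range(3))

    # Chunked form: inter-chunk via the prefix state, intra-chunk via a masked q @ k^T.
    qc, kc, vc = (x.reshape(S // C, C, D) for x in (q, k, v))
    state = torch.zeros(D, D, dtype=torch.float64)
    mask = torch.tril(torch.ones(C, C, dtype=torch.bool))
    outs = []
    for n in range(S // C):
        inter = qc[n] @ state
        intra = ((qc[n] @ kc[n].T) * mask) @ vc[n]
        outs.append(inter + intra)
        state = state + kc[n].T @ vc[n]

    assert torch.allclose(torch.cat(outs), causal_linear_attn(q, k, v))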
examples/linear_attention/example_linear_attn_fwd.py  (+40 -46)  View file @ 29051439

@@ -14,20 +14,20 @@ from typing import Optional, Tuple
     pass_configs={
         tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
         tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
-    })
+    },
+)
 def tl_fused_chunk_fwd_kernel(
     B,
     S,
     H,
     DK,
     DV,
-    dtype: str = 'float16',
+    dtype: str = "float16",
     scale: float = None,
 ) -> torch.Tensor:
     if scale is None:
         scale = DK**-0.5
-    accum_dtype = 'float'
+    accum_dtype = "float"
     chunk_size = 64
     BK = BV = 64  # Set to 128 can be faster, but has some numerical differences with FLA

@@ -42,7 +42,8 @@ def tl_fused_chunk_fwd_kernel(
         K: T.Tensor([B, S, H, DK], dtype),  # type: ignore
         V: T.Tensor([B, S, H, DV], dtype),  # type: ignore
         O: T.Tensor([B, S, H, DV], accum_dtype),  # type: ignore
-        final_state: T.Tensor([B, H, DK, DV], accum_dtype)):  # type: ignore
+        final_state: T.Tensor([B, H, DK, DV], accum_dtype),
+    ):  # type: ignore
         with T.Kernel(NV, NK, B * H) as (i_v, i_k, i_bh):
             i_b = i_bh // H
             i_h = i_bh % H

@@ -65,8 +66,8 @@ def tl_fused_chunk_fwd_kernel(
             for i in T.Pipelined(0, NT):
                 for row, col in T.Parallel(chunk_size, BK):
                     q[row, col] = Q[i_b, i * chunk_size + row, i_h, i_k * BK + col] * scale
-                T.copy(K[i_b, i * chunk_size:(i + 1) * chunk_size, i_h, i_k * BK:(i_k + 1) * BK], k)
-                T.copy(V[i_b, i * chunk_size:(i + 1) * chunk_size, i_h, i_v * BV:(i_v + 1) * BV], v)
+                T.copy(K[i_b, i * chunk_size : (i + 1) * chunk_size, i_h, i_k * BK : (i_k + 1) * BK], k)
+                T.copy(V[i_b, i * chunk_size : (i + 1) * chunk_size, i_h, i_v * BV : (i_v + 1) * BV], v)
                 T.gemm(q, k, s, clear_accum=True, transpose_B=True)
                 for row, col in T.Parallel(chunk_size, chunk_size):

@@ -77,12 +78,10 @@ def tl_fused_chunk_fwd_kernel(
                 T.gemm(k, v, h, transpose_A=True)
                 T.gemm(q, h_shared, o)
                 T.copy(o, o_shared)
-                T.atomic_add(O[i_b, i * chunk_size:(i + 1) * chunk_size, i_h, i_v * BV:(i_v + 1) * BV], o_shared)
+                T.atomic_add(O[i_b, i * chunk_size : (i + 1) * chunk_size, i_h, i_v * BV : (i_v + 1) * BV], o_shared)

             # Output final state
             T.copy(h, final_state[i_b, i_h, i_k * BK : (i_k + 1) * BK, i_v * BV : (i_v + 1) * BV])

     return fused_chunk_linear_attn_fwd

@@ -91,38 +90,35 @@ def tl_fused_chunk_fwd(q, k, v):
     B, S, H, D = q.shape
     kernel = tl_fused_chunk_fwd_kernel(B, S, H, D, D)
     print(kernel.get_kernel_source())
-    o = torch.zeros((B, S, H, D), device='cuda', dtype=torch.float32)
+    o = torch.zeros((B, S, H, D), device="cuda", dtype=torch.float32)
     h = kernel(q, k, v, o)
     return o, h


 def ref_program(
     q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, scale: Optional[float] = None
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     q, k, v = q.float(), k.float(), v.float()
     if scale is None:
-        scale = q.shape[-1]**-0.5
+        scale = q.shape[-1] ** -0.5
     chunk_size = 64
-    q = rearrange(q, 'b (n c) h d -> b h n c d', c=chunk_size) * scale
-    k = rearrange(k, 'b (n c) h d -> b h n c d', c=chunk_size)
-    v = rearrange(v, 'b (n c) h d -> b h n c d', c=chunk_size)
+    q = rearrange(q, "b (n c) h d -> b h n c d", c=chunk_size) * scale
+    k = rearrange(k, "b (n c) h d -> b h n c d", c=chunk_size)
+    v = rearrange(v, "b (n c) h d -> b h n c d", c=chunk_size)
     kv = k.transpose(-1, -2) @ v
     kv = kv.cumsum(2)
     h = kv[:, :, -1, :, :]
     kv = torch.cat([torch.zeros_like(kv[:, :, :1]), kv[:, :, :-1]], dim=2)
     inter = q @ kv
     intra = (
         (q @ k.transpose(-1, -2)).masked_fill_(
             torch.triu(torch.ones(chunk_size, chunk_size, dtype=bool, device=q.device), diagonal=1), 0
         )
     ) @ v
     o = inter + intra
-    return rearrange(o, 'b h n c d -> b (n c) h d'), h
+    return rearrange(o, "b h n c d -> b (n c) h d"), h


 def main(B=1, S=512, H=16, D=128):
-    q = torch.randn((B, S, H, D), device='cuda', dtype=torch.float16)
-    k = torch.randn((B, S, H, D), device='cuda', dtype=torch.float16)
-    v = torch.randn((B, S, H, D), device='cuda', dtype=torch.float16)
+    q = torch.randn((B, S, H, D), device="cuda", dtype=torch.float16)
+    k = torch.randn((B, S, H, D), device="cuda", dtype=torch.float16)
+    v = torch.randn((B, S, H, D), device="cuda", dtype=torch.float16)
     # qk norm is necessary for linear attn
     q, _ = l2norm_fwd(q)

@@ -131,25 +127,23 @@ def main(B=1, S=512, H=16, D=128):
     o, h = tl_fused_chunk_fwd(q, k, v)
     o_ref, h_ref = ref_program(q, k, v)
-    assert torch.allclose(o, o_ref, atol=1e-2, rtol=1e-2), f'o max err: {(o - o_ref).abs().max()}'
-    assert torch.allclose(h, h_ref, atol=1e-2, rtol=1e-2), f'h max err: {(h - h_ref).abs().max()}'
-    print('Passed all tests!✅')
+    assert torch.allclose(o, o_ref, atol=1e-2, rtol=1e-2), f"o max err: {(o - o_ref).abs().max()}"
+    assert torch.allclose(h, h_ref, atol=1e-2, rtol=1e-2), f"h max err: {(h - h_ref).abs().max()}"
+    print("Passed all tests!✅")

-    t1 = do_bench(lambda: fused_chunk_linear_attn(q, k, v, output_final_state=True, normalize=False), backend='cupti')
-    t2 = do_bench(lambda: tl_fused_chunk_fwd(q, k, v), backend='cupti')
-    print(f'Triton latency: {t1:.3f} ms')
-    print(f'TileLang latency: {t2:.3f} ms')
-    print(f'Speedup: {t1 / t2:.3f}x')
+    t1 = do_bench(lambda: fused_chunk_linear_attn(q, k, v, output_final_state=True, normalize=False), backend="cupti")
+    t2 = do_bench(lambda: tl_fused_chunk_fwd(q, k, v), backend="cupti")
+    print(f"Triton latency: {t1:.3f} ms")
+    print(f"TileLang latency: {t2:.3f} ms")
+    print(f"Speedup: {t1 / t2:.3f}x")


-if __name__ == '__main__':
+if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument('--B', type=int, default=8, help='Batch size')
-    parser.add_argument('--S', type=int, default=1024, help='Seq len')
-    parser.add_argument('--H', type=int, default=32, help='Num heads')
-    parser.add_argument('--D', type=int, default=128, help='Head dim')
+    parser.add_argument("--B", type=int, default=8, help="Batch size")
+    parser.add_argument("--S", type=int, default=1024, help="Seq len")
+    parser.add_argument("--H", type=int, default=32, help="Num heads")
+    parser.add_argument("--D", type=int, default=128, help="Head dim")
     args = parser.parse_args()
     main(args.B, args.S, args.H, args.D)
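In this forward example the host wrapper zero-initializes O and the kernel writes its output tile with T.atomic_add, because the grid is also split over the DK dimension (NK programs per output tile) and each program contributes only a partial product. The toy sketch below shows the same accumulation pattern; it is illustrative only, not part of the commit.

    import torch

    D, BK = 8, 4                      # head dim, per-program K block
    q, k = torch.randn(D), torch.randn(D)
    v = torch.randn(3)

    o = torch.zeros(3)                # like the zero-initialized O buffer
    for i_k in range(D // BK):        # one "program" per K block
        sl = slice(i_k * BK, (i_k + 1) * BK)
        o += (q[sl] @ k[sl]) * v      # partial contribution, accumulated like T.atomic_add
    assert torch.allclose(o, (q @ k) * v)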
examples/linear_attention/example_mamba_chunk_scan.py  (+110 -79)  View file @ 29051439

@@ -9,6 +9,7 @@ import itertools
 def chunk_scan_triton(cb, x, dt, dA_cumsum, C, states, D):
     from mamba_ssm.ops.triton.ssd_chunk_scan import _chunk_scan_fwd
     out, _ = _chunk_scan_fwd(cb, x, dt, dA_cumsum, C, states, D)
     return out

@@ -43,14 +44,15 @@ def ref_program(cb, x, dt, dA_cumsum, C, prev_states, D):
     dt_segment_sum = dA_cumsum[:, :, :, :, None] - dA_cumsum[:, :, :, None, :]
     decay = torch.exp(dt_segment_sum)
     scores_decay = cb * rearrange(decay, "b h c l s -> b c h l s")
     causal_mask = torch.tril(torch.ones(chunk_size, chunk_size, device=x.device, dtype=bool), diagonal=0)
     scores_decay = scores_decay.masked_fill(~causal_mask, 0)
-    out = torch.einsum('bchls,bhcs,bcshp->bclhp', scores_decay.to(x.dtype), dt.to(x.dtype), rearrange(x, "b (c s) h p -> b c s h p", c=nchunks))
+    out = torch.einsum(
+        "bchls,bhcs,bcshp->bclhp", scores_decay.to(x.dtype), dt.to(x.dtype), rearrange(x, "b (c s) h p -> b c s h p", c=nchunks)
+    )
     state_decay_out = torch.exp(rearrange(dA_cumsum, "b h c l -> b c l h 1"))
-    out_prev = torch.einsum('bclhn,bchpn->bclhp', rearrange(C, "b (c l) h n -> b c l h n", c=nchunks), prev_states.to(C.dtype)) * state_decay_out
+    out_prev = (
+        torch.einsum("bclhn,bchpn->bclhp", rearrange(C, "b (c l) h n -> b c l h n", c=nchunks), prev_states.to(C.dtype))
+        * state_decay_out
+    )
     out = out + out_prev
     out = rearrange(out, "b c l h p -> b (c l) h p")
     if D is not None:

@@ -61,12 +63,7 @@ def ref_program(cb, x, dt, dA_cumsum, C, prev_states, D):
 def get_configs():
     iter_params = dict(block_M=[64, 128, 256], block_N=[32, 64], block_K=[64, 128, 256], block_Dstate=[128], num_stages=[1, 2, 3, 4, 5])
     return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())]

@@ -77,7 +74,8 @@ def get_configs():
     tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
 },
 )
-def chunk_scan_fwd(batch,
+def chunk_scan_fwd(
+    batch,
     seqlen,
     chunk_size,
     ngroups,

@@ -89,7 +87,8 @@ def chunk_scan_fwd(
     block_K=64,
     block_Dstate=128,
     num_stages=2,
-    threads=128):
+    threads=128,
+):
     dtype = "float16"
     accum_dtype = "float"
     nchunks = T.ceildiv(seqlen, chunk_size)

@@ -104,13 +103,13 @@ def chunk_scan_fwd(
         C: T.Tensor((batch, seqlen, ngroups, dstate), dtype),  # type: ignore
         prev_states: T.Tensor((batch, nchunks, nheads, headdim, dstate), dtype),  # type: ignore
         D: T.Tensor((nheads), dtype),  # type: ignore
-        Output: T.Tensor((batch, seqlen, nheads, headdim), dtype)  # type: ignore
+        Output: T.Tensor((batch, seqlen, nheads, headdim), dtype),  # type: ignore
     ):
-        with T.Kernel(nheads, T.ceildiv(chunk_size, block_M) * T.ceildiv(headdim, block_N), batch * nchunks, threads=threads) as (bz, bx, by,):
+        with T.Kernel(nheads, T.ceildiv(chunk_size, block_M) * T.ceildiv(headdim, block_N), batch * nchunks, threads=threads) as (bz, bx, by):
             acc_o = T.alloc_fragment((block_M, block_N), accum_dtype)
             acc_o_shared = T.alloc_shared((block_M, block_N), dtype)
             cb_shared = T.alloc_shared((block_M, block_K), dtype, scope="shared.dyn")

@@ -136,27 +135,32 @@ def chunk_scan_fwd(
             m_idx = bx // T.ceildiv(headdim, block_N)
             n_idx = bx % T.ceildiv(headdim, block_N)
-            T.annotate_layout({
-                acc_o_shared: tilelang.layout.make_swizzled_layout(acc_o_shared),
-                cb_shared: tilelang.layout.make_swizzled_layout(cb_shared),
-                x_residual_shared: tilelang.layout.make_swizzled_layout(x_residual_shared)
-            })
+            T.annotate_layout(
+                {
+                    acc_o_shared: tilelang.layout.make_swizzled_layout(acc_o_shared),
+                    cb_shared: tilelang.layout.make_swizzled_layout(cb_shared),
+                    x_residual_shared: tilelang.layout.make_swizzled_layout(x_residual_shared),
+                }
+            )
             T.no_set_max_nreg()
-            T.copy(dA_cumsum[batch_idx, bz, chunk_idx, m_idx * block_M:(m_idx + 1) * block_M], dA_cs_m_shared)
+            T.copy(dA_cumsum[batch_idx, bz, chunk_idx, m_idx * block_M : (m_idx + 1) * block_M], dA_cs_m_shared)
             T.copy(dA_cs_m_shared, dA_cs_m_local)
             T.clear(acc_o)
             for i in T.Parallel(block_M):
                 scale_m_local[i] = T.exp2(dA_cs_m_local[i] * p)
-            T.copy(C[batch_idx, chunk_idx * chunk_size + m_idx * block_M:chunk_idx * chunk_size + (m_idx + 1) * block_M, bz // (nheads // ngroups), 0:block_Dstate], C_shared)
-            T.copy(prev_states[batch_idx, chunk_idx, bz, n_idx * block_N:(n_idx + 1) * block_N, 0:block_Dstate], prev_state_shared)
+            T.copy(
+                C[
+                    batch_idx,
+                    chunk_idx * chunk_size + m_idx * block_M : chunk_idx * chunk_size + (m_idx + 1) * block_M,
+                    bz // (nheads // ngroups),
+                    0:block_Dstate,
+                ],
+                C_shared,
+            )
+            T.copy(prev_states[batch_idx, chunk_idx, bz, n_idx * block_N : (n_idx + 1) * block_N, 0:block_Dstate], prev_state_shared)
             T.gemm(C_shared, prev_state_shared, acc_o, transpose_B=True)
             for i, j in T.Parallel(block_M, block_N):
                 acc_o[i, j] *= scale_m_local[i]

@@ -165,34 +169,47 @@ def chunk_scan_fwd(
             for k in T.Pipelined(loop_range, num_stages=num_stages):
-                T.copy(cb[batch_idx, chunk_idx, bz // (nheads // ngroups), m_idx * block_M:(m_idx + 1) * block_M, k * block_K:(k + 1) * block_K], cb_shared)
+                T.copy(
+                    cb[batch_idx, chunk_idx, bz // (nheads // ngroups),
+                       m_idx * block_M : (m_idx + 1) * block_M, k * block_K : (k + 1) * block_K],
+                    cb_shared,
+                )
                 T.copy(cb_shared, cb_local)
-                T.copy(dA_cumsum[batch_idx, bz, chunk_idx, k * block_K:(k + 1) * block_K], dA_cs_k_shared)
+                T.copy(dA_cumsum[batch_idx, bz, chunk_idx, k * block_K : (k + 1) * block_K], dA_cs_k_shared)
                 T.copy(dA_cs_k_shared, dA_cs_k_local)
                 for i, j in T.Parallel(block_M, block_K):
                     cb_local[i, j] = cb_local[i, j] * T.exp2(dA_cs_m_local[i] * p - dA_cs_k_local[j] * p)
-                T.copy(dt[batch_idx, bz, chunk_idx, k * block_K:(k + 1) * block_K], dt_shared)
+                T.copy(dt[batch_idx, bz, chunk_idx, k * block_K : (k + 1) * block_K], dt_shared)
                 T.copy(dt_shared, dt_local)
                 for i, j in T.Parallel(block_M, block_K):
                     cb_local[i, j] *= dt_local[j]
                 for i, j in T.Parallel(block_M, block_K):
                     cb_local[i, j] = T.if_then_else(m_idx * block_M + i >= k * block_K + j, cb_local[i, j], 0)
-                T.copy(x[batch_idx, chunk_idx * chunk_size + k * block_K:chunk_idx * chunk_size + (k + 1) * block_K, bz, n_idx * block_N:(n_idx + 1) * block_N], x_shared)
+                T.copy(
+                    x[batch_idx, chunk_idx * chunk_size + k * block_K : chunk_idx * chunk_size + (k + 1) * block_K,
+                      bz, n_idx * block_N : (n_idx + 1) * block_N],
+                    x_shared,
+                )
                 T.gemm(cb_local, x_shared, acc_o)

             D_local[0] = D[bz]
-            T.copy(x[batch_idx, chunk_idx * chunk_size + m_idx * block_M:chunk_idx * chunk_size + (m_idx + 1) * block_M, bz, n_idx * block_N:(n_idx + 1) * block_N], x_residual_shared)
+            T.copy(
+                x[batch_idx, chunk_idx * chunk_size + m_idx * block_M : chunk_idx * chunk_size + (m_idx + 1) * block_M,
+                  bz, n_idx * block_N : (n_idx + 1) * block_N],
+                x_residual_shared,
+            )
             T.copy(x_residual_shared, x_residual_local)
             for i, j in T.Parallel(block_M, block_N):
                 acc_o[i, j] += x_residual_local[i, j] * D_local[0]

@@ -200,27 +217,40 @@ def chunk_scan_fwd(
             T.copy(acc_o, acc_o_shared)
-            T.copy(acc_o_shared, Output[batch_idx, chunk_idx * chunk_size + m_idx * block_M:chunk_idx * chunk_size + (m_idx + 1) * block_M, bz, n_idx * block_N:(n_idx + 1) * block_N])
+            T.copy(
+                acc_o_shared,
+                Output[batch_idx, chunk_idx * chunk_size + m_idx * block_M : chunk_idx * chunk_size + (m_idx + 1) * block_M,
+                       bz, n_idx * block_N : (n_idx + 1) * block_N],
+            )
     return main


 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument('--batch', type=int, default=8, help='batch size')
-    parser.add_argument('--heads', type=int, default=80, help='heads')
-    parser.add_argument('--groups', type=int, default=1, help='groups')
-    parser.add_argument('--seq_len', type=int, default=4096, help='sequence length')
-    parser.add_argument('--chunk_size', type=int, default=256, help='chunk size')
-    parser.add_argument('--dim', type=int, default=64, help='dim')
-    parser.add_argument('--dstate', type=int, default=128, help='dstate')
-    parser.add_argument('--tune', action='store_true', help='tune configs')
+    parser.add_argument("--batch", type=int, default=8, help="batch size")
+    parser.add_argument("--heads", type=int, default=80, help="heads")
+    parser.add_argument("--groups", type=int, default=1, help="groups")
+    parser.add_argument("--seq_len", type=int, default=4096, help="sequence length")
+    parser.add_argument("--chunk_size", type=int, default=256, help="chunk size")
+    parser.add_argument("--dim", type=int, default=64, help="dim")
+    parser.add_argument("--dstate", type=int, default=128, help="dstate")
+    parser.add_argument("--tune", action="store_true", help="tune configs")
     args = parser.parse_args()
-    batch, heads, groups, seq_len, chunk_size, dim, dstate = args.batch, args.heads, args.groups, args.seq_len, args.chunk_size, args.dim, args.dstate
+    batch, heads, groups, seq_len, chunk_size, dim, dstate = (
+        args.batch,
+        args.heads,
+        args.groups,
+        args.seq_len,
+        args.chunk_size,
+        args.dim,
+        args.dstate,
+    )
     total_flops = 2 * batch * seq_len * chunk_size * heads * dim * 0.5 + 2 * batch * seq_len * heads * dim * dstate

-    if (not args.tune):
+    if not args.tune:
         kernel = chunk_scan_fwd(
             batch,
             seq_len,

@@ -234,7 +264,8 @@ if __name__ == "__main__":
             block_K=64,
             block_Dstate=128,
             num_stages=2,
-            threads=128)
+            threads=128,
+        )
         profiler = kernel.get_profiler(tilelang.TensorSupplyType.Normal)
         profiler.assert_allclose(ref_program, rtol=0.01, atol=0.01)
         print("All checks pass.")
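Both Mamba kernels in this commit keep the constant p = 1.44269504 and evaluate decays with T.exp2(x * p) instead of exp(x); 1.44269504 is log2(e), so the two expressions are the same quantity while exp2 maps to the cheaper hardware instruction. A quick check of the identity (illustrative only, not part of the commit):

    import math

    import torch

    p = 1.44269504                      # log2(e), as used by the kernels above
    x = torch.linspace(-4.0, 4.0, steps=9, dtype=torch.float64)
    assert math.isclose(p, math.log2(math.e), rel_tol=1e-8)
    assert torch.allclose(torch.exp(x), torch.exp2(x * p))   # exp(x) == 2 ** (x * log2 e)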
examples/linear_attention/example_mamba_chunk_state.py  (+57 -62)  View file @ 29051439

@@ -10,6 +10,7 @@ import itertools
 def chunk_state_triton(B, x, dt, dA_cumsum):
     from mamba_ssm.ops.triton.ssd_chunk_state import _chunk_state_fwd
     return _chunk_state_fwd(B, x, dt, dA_cumsum, states_in_fp32=False)

@@ -41,46 +42,33 @@ def ref_program(B, x, dt, dA_cumsum):
     x = rearrange(x, "b (c l) h p -> b c l h p", l=chunk_size)
     B = rearrange(B, "b (c l) ... -> b c l ...", l=chunk_size)
     decay_states = torch.exp((dA_cumsum[:, :, :, -1:] - dA_cumsum))
     return torch.einsum("bclhn,bhcl,bhcl,bclhp->bchpn", B.to(x.dtype), decay_states.to(x.dtype), dt.to(x.dtype), x)


 def get_configs():
     iter_params = dict(block_M=[64, 128], block_N=[32, 64, 128], block_K=[32, 64], num_stages=[1, 2, 3, 4, 5])
     return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())]


 @autotune(configs=get_configs(), warmup=10, rep=10)
 @tilelang.jit(out_idx=[4])
 def chunk_state_fwd(
     batch, seqlen, chunk_size, ngroups, nheads, headdim, dstate,
     block_M=64, block_N=64, block_K=64, num_stages=2, threads=128,
 ):
     dtype = "float16"
     accum_dtype = "float"
     nchunks = T.ceildiv(seqlen, chunk_size)
     p = 1.44269504

     @T.prim_func
     def main(
         B: T.Tensor((batch, seqlen, ngroups, dstate), dtype),
         x: T.Tensor((batch, seqlen, nheads, headdim), dtype),
         dt: T.Tensor((batch, nheads, nchunks, chunk_size), dtype),
         dA_cumsum: T.Tensor((batch, nheads, nchunks, chunk_size), dtype),
-        Output: T.Tensor((batch, nchunks, nheads, headdim, dstate), dtype)):
+        Output: T.Tensor((batch, nchunks, nheads, headdim, dstate), dtype),
+    ):
         with T.Kernel(nheads, T.ceildiv(headdim, block_M) * T.ceildiv(dstate, block_N), batch * nchunks, threads=threads) as (bz, bx, by):
             x_shared = T.alloc_shared((block_K, block_M), dtype)
             x_local = T.alloc_fragment((block_K, block_M), dtype)
             xt_local = T.alloc_fragment((block_M, block_K), dtype)

@@ -101,20 +89,24 @@ def chunk_state_fwd(
             m_idx = bx // T.ceildiv(dstate, block_N)
             n_idx = bx % T.ceildiv(dstate, block_N)
-            T.annotate_layout({
-                x_shared: tilelang.layout.make_swizzled_layout(x_shared),
-                acc_o_shared: tilelang.layout.make_swizzled_layout(acc_o_shared)
-            })
+            T.annotate_layout(
+                {x_shared: tilelang.layout.make_swizzled_layout(x_shared),
+                 acc_o_shared: tilelang.layout.make_swizzled_layout(acc_o_shared)}
+            )
             dA_cs_last[0] = dA_cumsum[batch_idx, bz, chunk_idx, chunk_size - 1]
             T.clear(acc_o)
             for k in T.Pipelined(loop_range, num_stages=num_stages):
-                T.copy(x[batch_idx, chunk_idx * chunk_size + k * block_K:chunk_idx * chunk_size + (k + 1) * block_K, bz, m_idx * block_M:(m_idx + 1) * block_M], x_shared)
-                T.copy(dA_cumsum[batch_idx, bz, chunk_idx, k * block_K:(k + 1) * block_K], dA_cumsum_shared)
-                T.copy(dt[batch_idx, bz, chunk_idx, k * block_K:(k + 1) * block_K], dt_shared)
+                T.copy(
+                    x[batch_idx, chunk_idx * chunk_size + k * block_K : chunk_idx * chunk_size + (k + 1) * block_K,
+                      bz, m_idx * block_M : (m_idx + 1) * block_M],
+                    x_shared,
+                )
+                T.copy(dA_cumsum[batch_idx, bz, chunk_idx, k * block_K : (k + 1) * block_K], dA_cumsum_shared)
+                T.copy(dt[batch_idx, bz, chunk_idx, k * block_K : (k + 1) * block_K], dt_shared)
                 T.copy(dA_cumsum_shared, dA_cumsum_local)
                 T.copy(dt_shared, dt_local)
                 for i in T.Parallel(block_K):

@@ -123,47 +115,50 @@ def chunk_state_fwd(
                 for i, j in T.Parallel(block_M, block_K):
                     xt_local[i, j] = x_local[j, i] * scale[j]
-                T.copy(B[batch_idx, chunk_idx * chunk_size + k * block_K:chunk_idx * chunk_size + (k + 1) * block_K, bz // (nheads // ngroups), n_idx * block_N:(n_idx + 1) * block_N], B_shared)
+                T.copy(
+                    B[batch_idx, chunk_idx * chunk_size + k * block_K : chunk_idx * chunk_size + (k + 1) * block_K,
+                      bz // (nheads // ngroups), n_idx * block_N : (n_idx + 1) * block_N],
+                    B_shared,
+                )
                 T.gemm(xt_local, B_shared, acc_o)
             T.copy(acc_o, acc_o_shared)
-            T.copy(acc_o_shared, Output[batch_idx, chunk_idx, bz, m_idx * block_M:(m_idx + 1) * block_M, n_idx * block_N:(n_idx + 1) * block_N])
+            T.copy(
+                acc_o_shared,
+                Output[batch_idx, chunk_idx, bz, m_idx * block_M : (m_idx + 1) * block_M, n_idx * block_N : (n_idx + 1) * block_N],
+            )
     return main


 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument('--batch', type=int, default=8, help='batch size')
-    parser.add_argument('--heads', type=int, default=80, help='heads')
-    parser.add_argument('--groups', type=int, default=1, help='groups')
-    parser.add_argument('--seq_len', type=int, default=4096, help='sequence length')
-    parser.add_argument('--chunk_size', type=int, default=256, help='chunk size')
-    parser.add_argument('--dim', type=int, default=64, help='dim')
-    parser.add_argument('--dstate', type=int, default=128, help='dstate')
-    parser.add_argument('--tune', action='store_true', help='tune configs')
+    parser.add_argument("--batch", type=int, default=8, help="batch size")
+    parser.add_argument("--heads", type=int, default=80, help="heads")
+    parser.add_argument("--groups", type=int, default=1, help="groups")
+    parser.add_argument("--seq_len", type=int, default=4096, help="sequence length")
+    parser.add_argument("--chunk_size", type=int, default=256, help="chunk size")
+    parser.add_argument("--dim", type=int, default=64, help="dim")
+    parser.add_argument("--dstate", type=int, default=128, help="dstate")
+    parser.add_argument("--tune", action="store_true", help="tune configs")
     args = parser.parse_args()
-    batch, heads, groups, seq_len, chunk_size, dim, dstate = args.batch, args.heads, args.groups, args.seq_len, args.chunk_size, args.dim, args.dstate
+    batch, heads, groups, seq_len, chunk_size, dim, dstate = (
+        args.batch,
+        args.heads,
+        args.groups,
+        args.seq_len,
+        args.chunk_size,
+        args.dim,
+        args.dstate,
+    )
     total_flops = 2 * batch * seq_len * heads * dim * dstate

-    if (not args.tune):
+    if not args.tune:
         kernel = chunk_state_fwd(
-            batch, seq_len, chunk_size, groups, heads, dim, dstate,
-            block_M=64, block_N=128, block_K=64, num_stages=4, threads=128)
+            batch, seq_len, chunk_size, groups, heads, dim, dstate, block_M=64, block_N=128, block_K=64, num_stages=4, threads=128
+        )
         profiler = kernel.get_profiler(tilelang.TensorSupplyType.Normal)
         profiler.assert_allclose(ref_program, rtol=0.01, atol=0.01)
         print("All checks pass.")
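The chunk-state kernel builds xt_local as x transposed and pre-scaled by the per-position decay and dt, then a single T.gemm against B accumulates the chunk state, which is the einsum "bclhn,bhcl,bhcl,bclhp->bchpn" from ref_program restricted to one (batch, chunk, head). The equivalence check below is illustrative only, not part of the commit; the two per-position weights are folded into a single factor w.

    import torch

    L, N, P = 8, 5, 3                  # chunk length, dstate, headdim
    B_blk = torch.randn(L, N)          # B for one (batch, chunk, head group)
    x_blk = torch.randn(L, P)          # x for one (batch, chunk, head)
    w = torch.randn(L)                 # decay_states * dt, folded into one weight

    ref = torch.einsum("ln,l,lp->pn", B_blk, w, x_blk)

    acc = torch.zeros(P, N)
    for l in range(L):                 # what the gemm accumulates over the K dimension
        acc += w[l] * torch.outer(x_blk[l], B_blk[l])
    assert torch.allclose(ref, acc)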
examples/linear_attention/example_retention_fwd.py View file @ 29051439
...
@@ -13,13 +13,12 @@ def chunk_retention_fwd_kernel(
    H,
    DK,
    DV,
    dtype: str = 'float16',
    dtype: str = "float16",
    scale: float = None,
) -> torch.Tensor:
    if scale is None:
        scale = DK**-0.5
    accum_dtype = 'float'
    accum_dtype = "float"
    chunk_size = 64
    BK = BV = 64  # Set to 128 can be faster, but has some numerical differences with FLA
...
@@ -38,8 +37,8 @@ def chunk_retention_fwd_kernel(
        with T.Kernel(NV, NK, B * H) as (i_v, i_k, i_bh):
            i_b = i_bh // H
            i_h = i_bh % H
            log_decay = T.alloc_var('float32')
            log_decay = T.log2(1 - T.exp2(-5. - 1. * i_h))  # Head-specific log decay
            log_decay = T.alloc_var("float32")
            log_decay = T.log2(1 - T.exp2(-5.0 - 1.0 * i_h))  # Head-specific log decay
            q = T.alloc_shared([chunk_size, BK], dtype)
            k = T.alloc_shared([chunk_size, BK], dtype)
...
@@ -56,14 +55,12 @@ def chunk_retention_fwd_kernel(
            for i in T.Pipelined(0, NT):
                for row, col in T.Parallel(chunk_size, BK):
                    q[row, col] = Q[i_b, i * chunk_size + row, i_h, i_k * BK + col] * scale
                T.copy(K[i_b, i * chunk_size:(i + 1) * chunk_size, i_h, i_k * BK:(i_k + 1) * BK], k)
                T.copy(V[i_b, i * chunk_size:(i + 1) * chunk_size, i_h, i_v * BV:(i_v + 1) * BV], v)
                T.copy(K[i_b, i * chunk_size : (i + 1) * chunk_size, i_h, i_k * BK : (i_k + 1) * BK], k)
                T.copy(V[i_b, i * chunk_size : (i + 1) * chunk_size, i_h, i_v * BV : (i_v + 1) * BV], v)
                T.gemm(q, k, s, clear_accum=True, transpose_B=True)
                for row, col in T.Parallel(chunk_size, chunk_size):
                    s_shared[row, col] = T.if_then_else(row >= col,
                                                        s[row, col] * T.exp2((row - col) * log_decay), 0)
                    s_shared[row, col] = T.if_then_else(row >= col, s[row, col] * T.exp2((row - col) * log_decay), 0)
                T.copy(h, h_shared)
                T.gemm(q, h_shared, o, clear_accum=True)
...
@@ -75,9 +72,7 @@ def chunk_retention_fwd_kernel(
                    v[row, col] = v[row, col] * T.exp2((chunk_size - row - 1) * log_decay)
                for row, col in T.Parallel(BK, BV):
                    h[row, col] = T.exp2(chunk_size * log_decay) * h[row, col]
                T.copy(o, O[i_k, i_b, i * chunk_size:(i + 1) * chunk_size, i_h, i_v * BV:(i_v + 1) * BV])
                T.copy(o, O[i_k, i_b, i * chunk_size : (i + 1) * chunk_size, i_h, i_v * BV : (i_v + 1) * BV])
                T.gemm(k, v, h, transpose_A=True)
    return chunk_retention_fwd
...
@@ -89,24 +84,24 @@ def postprocess(o):
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--B', type=int, default=8, help='Batch size')
    parser.add_argument('--S', type=int, default=4096, help='Seq len')
    parser.add_argument('--H', type=int, default=32, help='Num heads')
    parser.add_argument('--D', type=int, default=128, help='Head dim')
    parser.add_argument("--B", type=int, default=8, help="Batch size")
    parser.add_argument("--S", type=int, default=4096, help="Seq len")
    parser.add_argument("--H", type=int, default=32, help="Num heads")
    parser.add_argument("--D", type=int, default=128, help="Head dim")
    args = parser.parse_args()
    B, S, H, D = args.B, args.S, args.H, args.D
    total_flops = 2.0 * B * S * S * H * D  # causal
    q = torch.randn((B, S, H, D), device='cuda', dtype=torch.float16)
    k = torch.randn((B, S, H, D), device='cuda', dtype=torch.float16)
    v = torch.randn((B, S, H, D), device='cuda', dtype=torch.float16)
    q = torch.randn((B, S, H, D), device="cuda", dtype=torch.float16)
    k = torch.randn((B, S, H, D), device="cuda", dtype=torch.float16)
    v = torch.randn((B, S, H, D), device="cuda", dtype=torch.float16)
    kernel = chunk_retention_fwd_kernel(B, S, H, D, D)
    t = do_bench(lambda: postprocess(kernel(q, k, v)), warmup=25, rep=100)
    print(f'Tilelang latency: {t:.3f} ms')
    print(f'Tilelang TFLOPs: {total_flops / t * 1e-9}')
    print(f"Tilelang latency: {t:.3f} ms")
    print(f"Tilelang TFLOPs: {total_flops / t * 1e-9}")
if __name__ == '__main__':
if __name__ == "__main__":
    main()
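The head-specific constant in the hunks above, log_decay = T.log2(1 - T.exp2(-5.0 - 1.0 * i_h)), is the usual retention decay schedule gamma_h = 1 - 2^(-5 - h). A short PyTorch sketch (illustrative only, not code from the example file) of what each head's decay factor and its base-2 log evaluate to:

import torch

H = 32  # number of heads, matching the example's default
h = torch.arange(H, dtype=torch.float32)
gamma = 1 - torch.exp2(-5.0 - h)        # per-head decay factor in (0, 1)
log_decay = torch.log2(gamma)           # what the kernel stores per head
print(gamma[:4])       # tensor([0.9688, 0.9844, 0.9922, 0.9961])
print(log_decay[:4])   # small negative values, closer to 0 for later heads

Storing the log2 value lets the kernel apply gamma^(row - col) with a single T.exp2((row - col) * log_decay), which is exactly what the masked s_shared update above does.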
examples/minference/example_vertical_slash_sparse_attn.py View file @ 29051439
...
@@ -15,12 +15,11 @@ from tilelang.profiler import do_bench
@tilelang.jit(out_idx=[3])
def _tl_vs_sparse_flashattn(batch, heads, seq_len, dim, vertical_size, slash_size):
    block_M = 64
    block_N = 64
    num_stages = 2
    threads = 128
    scale = (1.0 / dim)**0.5 * 1.44269504
    scale = (1.0 / dim) ** 0.5 * 1.44269504
    shape = [batch, heads, seq_len, dim]
    seq_blocks = (seq_len + block_M - 1) // block_M
...
@@ -30,15 +29,13 @@ def _tl_vs_sparse_flashattn(batch, heads, seq_len, dim, vertical_size, slash_siz
    offset_shape = count_shape + [slash_size]
    index_shape = count_shape + [vertical_size]
    vertical_size_round, slash_size_round = tilelang.next_power_of_2(
        vertical_size), tilelang.next_power_of_2(slash_size)
    vertical_size_round, slash_size_round = tilelang.next_power_of_2(vertical_size), tilelang.next_power_of_2(slash_size)
    dtype = "float16"
    accum_dtype = "float"
    int_dtype = "int32"
    def kernel_func(block_M, block_N, num_stages, threads):
        @T.macro
        def Prefetch(
            K: T.Tensor(shape, dtype),
...
@@ -53,13 +50,11 @@ def _tl_vs_sparse_flashattn(batch, heads, seq_len, dim, vertical_size, slash_siz
        ):
            with T.attr("default", "async_scope", 1):
                for i, j in T.Parallel(block_N, dim):
                    K_shared[i, j] = T.if_then_else(k + i < column_count,
                                                    K[bz, by, column_index[k + i], j], 0)
                    K_shared[i, j] = T.if_then_else(k + i < column_count, K[bz, by, column_index[k + i], j], 0)
            with T.attr("default", "async_scope", 1):
                for i, j in T.Parallel(block_N, dim):
                    V_shared[i, j] = T.if_then_else(k + i < column_count,
                                                    V[bz, by, column_index[k + i], j], 0)
                    V_shared[i, j] = T.if_then_else(k + i < column_count, V[bz, by, column_index[k + i], j], 0)
            T.ptx_commit_group()
...
@@ -118,7 +113,6 @@ def _tl_vs_sparse_flashattn(batch, heads, seq_len, dim, vertical_size, slash_siz
            ColumnIndex: T.Tensor(index_shape, int_dtype),
        ):
            with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=256) as (bc, by, bz):
                bx = T.ceildiv(seq_len, block_M) - 1 - bc
                Q_shared = T.alloc_shared([block_M, dim], dtype)
                K_shared = T.alloc_shared([2, block_N, dim], dtype)
...
@@ -143,9 +137,11 @@ def _tl_vs_sparse_flashattn(batch, heads, seq_len, dim, vertical_size, slash_siz
                T.create_list_of_mbarrier([128] * 9)
                T.annotate_layout({
                T.annotate_layout(
                    {
                        O_shared: tilelang.layout.make_swizzled_layout(O_shared),
                })
                    }
                )
                block_count[0] = BlockCount[bz, by, bx]
                column_count[0] = ColumnCount[bz, by, bx]
...
@@ -162,15 +158,15 @@ def _tl_vs_sparse_flashattn(batch, heads, seq_len, dim, vertical_size, slash_siz
                if tid >= 128:
                    T.annotate_producer_reg_dealloc()
                    T.copy(Q[bz, by, bx * block_M:(bx + 1) * block_M, :], Q_shared)
                    T.copy(Q[bz, by, bx * block_M : (bx + 1) * block_M, :], Q_shared)
                    T.mbarrier_arrive(mbarrier=8)
                    for bi in T.serial(block_count[0]):
                        k = block_offset[bi]
                        T.mbarrier_wait_parity(mbarrier=bi % 2 + 4, parity=(((bi & 3) >> 1) ^ 1))
                        T.copy(K[bz, by, k:k + block_N, :], K_shared[bi % 2, :, :])
                        T.copy(K[bz, by, k : k + block_N, :], K_shared[bi % 2, :, :])
                        T.mbarrier_arrive(mbarrier=bi % 2)
                        T.mbarrier_wait_parity(mbarrier=bi % 2 + 6, parity=(((bi & 3) >> 1) ^ 1))
                        T.copy(V[bz, by, k:k + block_N, :], V_shared[bi % 2, :, :])
                        T.copy(V[bz, by, k : k + block_N, :], V_shared[bi % 2, :, :])
                        T.mbarrier_arrive(mbarrier=bi % 2 + 2)
                else:
                    T.annotate_consumer_reg_alloc()
...
@@ -181,16 +177,10 @@ def _tl_vs_sparse_flashattn(batch, heads, seq_len, dim, vertical_size, slash_siz
                    for bi in T.serial(block_count[0]):
                        k = block_offset[bi]
                        for i, j in T.Parallel(block_M, block_N):
                            acc_s[i, j] = T.if_then_else(bx * block_M + i >= k + j, 0,
                                                         -T.infinity(acc_s.dtype))
                            acc_s[i, j] = T.if_then_else(bx * block_M + i >= k + j, 0, -T.infinity(acc_s.dtype))
                        T.mbarrier_wait_parity(mbarrier=bi % 2, parity=((bi & 3) >> 1))
                        T.gemm(Q_shared, K_shared[bi % 2, :, :], acc_s, transpose_B=True,
                               policy=T.GemmWarpPolicy.FullRow)
                        T.gemm(Q_shared, K_shared[bi % 2, :, :], acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
                        T.mbarrier_arrive(mbarrier=bi % 2 + 4)
                        T.copy(scores_max, scores_max_prev)
...
@@ -200,20 +190,15 @@ def _tl_vs_sparse_flashattn(batch, heads, seq_len, dim, vertical_size, slash_siz
                            scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
                        for i in T.Parallel(block_M):
                            scores_scale[i] = T.exp2(scores_max_prev[i] * scale -
                                                     scores_max[i] * scale)
                            scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
                        for i, j in T.Parallel(block_M, block_N):
                            acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
                        for i, j in T.Parallel(block_M, dim):
                            acc_o[i, j] = acc_o[i, j] * scores_scale[i]
                        T.copy(acc_s, acc_s_cast)
                        T.mbarrier_wait_parity(mbarrier=bi % 2 + 2, parity=(((bi & 3) >> 1)))
                        T.gemm(acc_s_cast, V_shared[bi % 2, :, :], acc_o,
                               policy=T.GemmWarpPolicy.FullRow)
                        T.mbarrier_wait_parity(mbarrier=bi % 2 + 2, parity=((bi & 3) >> 1))
                        T.gemm(acc_s_cast, V_shared[bi % 2, :, :], acc_o, policy=T.GemmWarpPolicy.FullRow)
                        T.mbarrier_arrive(mbarrier=bi % 2 + 6)
...
@@ -223,38 +208,85 @@ def _tl_vs_sparse_flashattn(batch, heads, seq_len, dim, vertical_size, slash_siz
                        logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
                    if column_count[0] != 0:
                        Prefetch(K, V, K_shared_1, V_shared_1, column_index, column_count[0], 0,
                                 bz, by)
                        Prefetch(K, V, K_shared_1, V_shared_1, column_index, column_count[0], 0, bz, by)
                    for bi in T.serial(T.ceildiv(column_count[0], block_N) - 1):
                        k = bi * block_N
                        if bi % 2 == 0:
                            Prefetch(K, V, K_shared_2, V_shared_2, column_index, column_count[0],
                                     k + block_N, bz, by)
                            Prefetch(K, V, K_shared_2, V_shared_2, column_index, column_count[0], k + block_N, bz, by)
                            Compute(acc_s, acc_s_cast, acc_o, scores_max, scores_max_prev, k,
                                    column_count[0], Q_shared, K_shared_1, V_shared_1,
                                    scores_scale, scores_sum, logsum, 1)
                            Compute(
                                acc_s, acc_s_cast, acc_o, scores_max, scores_max_prev, k, column_count[0],
                                Q_shared, K_shared_1, V_shared_1, scores_scale, scores_sum, logsum, 1,
                            )
                        else:
                            Prefetch(K, V, K_shared_1, V_shared_1, column_index, column_count[0],
                                     k + block_N, bz, by)
                            Prefetch(K, V, K_shared_1, V_shared_1, column_index, column_count[0], k + block_N, bz, by)
                            Compute(acc_s, acc_s_cast, acc_o, scores_max, scores_max_prev, k,
                                    column_count[0], Q_shared, K_shared_2, V_shared_2,
                                    scores_scale, scores_sum, logsum, 1)
                            Compute(
                                acc_s, acc_s_cast, acc_o, scores_max, scores_max_prev, k, column_count[0],
                                Q_shared, K_shared_2, V_shared_2, scores_scale, scores_sum, logsum, 1,
                            )
                    if T.ceildiv(column_count[0], block_N) % 2 == 0:
                        Compute(acc_s, acc_s_cast, acc_o, scores_max, scores_max_prev,
                        Compute(
                            acc_s, acc_s_cast, acc_o, scores_max, scores_max_prev,
                            T.ceildiv(column_count[0], block_N) * block_N - block_N,
                                column_count[0], Q_shared, K_shared_2, V_shared_2, scores_scale,
                                scores_sum, logsum, 0)
                            column_count[0], Q_shared, K_shared_2, V_shared_2, scores_scale, scores_sum, logsum, 0,
                        )
                    else:
                        Compute(acc_s, acc_s_cast, acc_o, scores_max, scores_max_prev,
                        Compute(
                            acc_s, acc_s_cast, acc_o, scores_max, scores_max_prev,
                            T.ceildiv(column_count[0], block_N) * block_N - block_N,
                                column_count[0], Q_shared, K_shared_1, V_shared_1, scores_scale,
                                scores_sum, logsum, 0)
                            column_count[0], Q_shared, K_shared_1, V_shared_1, scores_scale, scores_sum, logsum, 0,
                        )
                    for i, j in T.Parallel(block_M, dim):
                        acc_o[i, j] /= logsum[i]
                    T.copy(acc_o, O_shared)
                    T.copy(O_shared, Output[bz, by, bx * block_M:(bx + 1) * block_M, :])
                    T.copy(O_shared, Output[bz, by, bx * block_M : (bx + 1) * block_M, :])
    return vs_sparse_flashattn_ws
...
@@ -470,11 +502,8 @@ def vertical_slash_sparse_attention(
    import os
    current_dir = os.path.dirname(os.path.abspath(__file__))
    sources = [
        os.path.join(current_dir, 'ops', 'kernels.cpp'),
        os.path.join(current_dir, 'ops', 'vertical_slash_index.cu')
    ]
    ops = load(name='convert', sources=sources, verbose=False)
    sources = [os.path.join(current_dir, "ops", "kernels.cpp"), os.path.join(current_dir, "ops", "vertical_slash_index.cu")]
    ops = load(name="convert", sources=sources, verbose=False)
    convert_vertical_slash_indexes = ops.convert_vertical_slash_indexes
    batch_size, num_heads, context_size, head_dim = query.shape
    pad = (block_size_M - context_size) & (block_size_M - 1)
...
@@ -485,15 +514,13 @@ def vertical_slash_sparse_attention(
        value = torch.nn.functional.pad(value, [0, 0, 0, pad, 0, 0, 0, 0])
    if head_dim not in [16, 32, 64, 128, 256, 512]:
        target_dim = 2**math.ceil(math.log2(head_dim)) - head_dim
        target_dim = 2 ** math.ceil(math.log2(head_dim)) - head_dim
        query = torch.nn.functional.pad(query, [0, target_dim, 0, 0, 0, 0, 0, 0])
        key = torch.nn.functional.pad(key, [0, target_dim, 0, 0, 0, 0, 0, 0])
        value = torch.nn.functional.pad(value, [0, target_dim, 0, 0, 0, 0, 0, 0])
    v_idx = v_idx.to(torch.int32).reshape((batch_size, num_heads,
                                           -1)).sort(dim=-1, descending=False)[0]
    s_idx = s_idx.to(torch.int32).reshape((batch_size, num_heads,
                                           -1)).sort(dim=-1, descending=True)[0]
    v_idx = v_idx.to(torch.int32).reshape((batch_size, num_heads, -1)).sort(dim=-1, descending=False)[0]
    s_idx = s_idx.to(torch.int32).reshape((batch_size, num_heads, -1)).sort(dim=-1, descending=True)[0]
    seqlens = torch.tensor([context_size] * query.shape[0], dtype=torch.int32, device=query.device)
    sm_scale = head_dim**-0.5
...
@@ -506,8 +533,7 @@ def vertical_slash_sparse_attention(
        block_size_N,
    )
    tl_kernel = _tl_vs_sparse_flashattn(batch_size, num_heads, context_size, head_dim,
                                        v_idx.shape[2], s_idx.shape[2])
    tl_kernel = _tl_vs_sparse_flashattn(batch_size, num_heads, context_size, head_dim, v_idx.shape[2], s_idx.shape[2])
    def run(is_triton: bool = True):
        if is_triton:
...
@@ -525,8 +551,7 @@ def vertical_slash_sparse_attention(
                block_size_N,
            )
        else:
            out = tl_kernel(query, key, value, block_count, block_offset, column_count,
                            column_index)
            out = tl_kernel(query, key, value, block_count, block_offset, column_count, column_index)
        return out[..., :context_size, :head_dim]
    return run
...
@@ -536,8 +561,7 @@ def sum_all_diagonal_matrix(mat: torch.tensor):
    b, h, n, m = mat.shape
    zero_mat = torch.zeros((b, h, n, n)).to(mat.device)  # Zero matrix used for padding
    mat_padded = torch.cat((zero_mat, mat, zero_mat), -1)  # pads the matrix on left and right
    mat_strided = mat_padded.as_strided(
        (1, 1, n, n + m), (1, n * (2 * n + m), 2 * n + m + 1, 1))  # Change the strides
    mat_strided = mat_padded.as_strided((1, 1, n, n + m), (1, n * (2 * n + m), 2 * n + m + 1, 1))  # Change the strides
    sum_diags = torch.sum(mat_strided, 2)  # Sums the resulting matrix's columns
    return sum_diags[:, :, 1:]
...
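sum_all_diagonal_matrix in the hunk above sums every diagonal of the last-queries score block by zero-padding the matrix on both sides and then re-striding it so that each column of the strided view walks down one diagonal. A standalone sketch of the same padding-and-striding trick, generalized over batch and heads (the stride arithmetic below is my own reconstruction, not code from the example file):

import torch

def sum_all_diagonals(mat: torch.Tensor) -> torch.Tensor:
    # mat: (b, h, n, m) attention scores for the last n queries
    b, h, n, m = mat.shape
    zeros = torch.zeros((b, h, n, n), device=mat.device, dtype=mat.dtype)
    padded = torch.cat((zeros, mat, zeros), dim=-1).contiguous()  # (b, h, n, 2n + m)
    # View element [.., i, j] aliases padded[.., i, i + j], so column j of the
    # view is diagonal j - n of the original matrix.
    strided = padded.as_strided(
        (b, h, n, n + m),
        (h * n * (2 * n + m), n * (2 * n + m), 2 * n + m + 1, 1),
    )
    return strided.sum(dim=2)[:, :, 1:]  # one summed value per diagonal

scores = torch.randn(1, 2, 4, 6)
print(sum_all_diagonals(scores).shape)  # torch.Size([1, 2, 9])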
@@ -559,24 +583,23 @@ def main(argv=None):
    vertical_size, slash_size = args.vertical_size, args.slash_size
    torch.manual_seed(0)
    q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16)
    k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16)
    v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16)
    q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16)
    k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16)
    v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16)
    q_len = SEQ_LEN
    vertical_size, slash_size = min(q_len, vertical_size), min(q_len, slash_size)
    last_q = 64
    qk = torch.einsum('bhmk, bhnk -> bhmn', q[:, :, -last_q:, :], k)
    qk = torch.einsum("bhmk, bhnk -> bhmn", q[:, :, -last_q:, :], k)
    arange = torch.arange(last_q, device="cuda")
    qk[:, :, :, -last_q:] = torch.where(arange[None, None, :, None] >= arange[None, None, None, :],
                                        qk[:, :, :, -last_q:], -torch.inf)
    qk[:, :, :, -last_q:] = torch.where(arange[None, None, :, None] >= arange[None, None, None, :], qk[:, :, :, -last_q:], -torch.inf)
    qk = torch.nn.functional.softmax(qk, dim=-1, dtype=torch.float32)
    vertical = qk.sum(-2, keepdim=True)
    vertical[..., :30] = torch.inf
    vertical_topk = torch.topk(vertical, vertical_size, -1).indices
    slash = sum_all_diagonal_matrix(qk)[..., :-last_q + 1]
    slash = sum_all_diagonal_matrix(qk)[..., : -last_q + 1]
    slash[..., -30:] = torch.inf
    slash = (q_len - 1) - torch.topk(slash, slash_size, -1).indices
...
examples/norm/rms_norm.py View file @ 29051439
...
@@ -45,7 +45,7 @@ def rms_norm(M, N, blk_m):
            A_local = T.alloc_fragment((blk_m, N), dtype)
            A_powsum = T.alloc_fragment((blk_m,), dtype)
            T.copy(A[bx * blk_m:(bx + 1) * blk_m, :], A_shared)
            T.copy(A[bx * blk_m : (bx + 1) * blk_m, :], A_shared)
            T.copy(A_shared, A_local)
            for i, j in T.Parallel(blk_m, N):
                A_pow_local[i, j] = A_local[i, j] * A_local[i, j]
...
@@ -54,7 +54,7 @@ def rms_norm(M, N, blk_m):
                A_powsum[i] = T.rsqrt(A_powsum[i] / N + 1e-12)
            for i, j in T.Parallel(blk_m, N):
                A_local[i, j] *= A_powsum[i]
            T.copy(A_local, B[bx * blk_m:(bx + 1) * blk_m, :])
            T.copy(A_local, B[bx * blk_m : (bx + 1) * blk_m, :])
    return main
...
examples/norm/test_rms_norm.py View file @ 29051439
...
@@ -45,7 +45,7 @@ def rms_norm(M, N, blk_m):
            A_local = T.alloc_fragment((blk_m, N), dtype)
            A_powsum = T.alloc_fragment((blk_m,), dtype)
            T.copy(A[bx * blk_m:(bx + 1) * blk_m, :], A_shared)
            T.copy(A[bx * blk_m : (bx + 1) * blk_m, :], A_shared)
            T.copy(A_shared, A_local)
            for i, j in T.Parallel(blk_m, N):
                A_pow_local[i, j] = A_local[i, j] * A_local[i, j]
...
@@ -54,7 +54,7 @@ def rms_norm(M, N, blk_m):
                A_powsum[i] = T.rsqrt(A_powsum[i] / N + 1e-12)
            for i, j in T.Parallel(blk_m, N):
                A_local[i, j] *= A_powsum[i]
            T.copy(A_local, B[bx * blk_m:(bx + 1) * blk_m, :])
            T.copy(A_local, B[bx * blk_m : (bx + 1) * blk_m, :])
    return main
...
examples/online_softmax/online_softmax.py View file @ 29051439
...
@@ -33,7 +33,7 @@ def softmax_kernel(
            T.fill(lse, -T.infinity(accum_dtype))
            for i_n in T.Pipelined(0, NN):
                T.copy(X[i_m, i_n * BN:(i_n + 1) * BN], x)
                T.copy(X[i_m, i_n * BN : (i_n + 1) * BN], x)
                T.reduce_max(x, max_x, dim=0, clear=True)
...
@@ -45,12 +45,12 @@ def softmax_kernel(
                lse[0] = max_x[0] * scale + T.log2(T.exp2(lse[0] - max_x[0] * scale) + sum_exp_x[0])
            for i_n in T.Pipelined(0, NN):
                T.copy(X[i_m, i_n * BN:(i_n + 1) * BN], x)
                T.copy(X[i_m, i_n * BN : (i_n + 1) * BN], x)
                for j in T.Parallel(BN):
                    y[j] = T.exp2(x[j] * scale - lse[0])
                T.copy(y, Y[i_m, i_n * BN:(i_n + 1) * BN])
                T.copy(y, Y[i_m, i_n * BN : (i_n + 1) * BN])
    return main
...
@@ -69,4 +69,4 @@ t1 = do_bench(lambda: X.softmax(dim=1), warmup=25, rep=100)
t2 = do_bench(lambda: kernel(X), warmup=25, rep=100)
print(f"torch latency: {t1:.3f} ms")
print(f"TileLang latency: {t2:.3f} ms")
print(f"Speedup: {t1/t2:.3f}x")
print(f"Speedup: {t1 / t2:.3f}x")
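The lse update in the hunk above, lse = max * scale + log2(exp2(lse - max * scale) + sum_exp), is the streaming log-sum-exp recurrence in base 2: a first pass folds each chunk's (max, sum-of-exp2) pair into a running lse, and a second pass normalizes with exp2(x * scale - lse). A small PyTorch sketch of the same two-pass scheme on one row (the chunk size and names here are illustrative, not part of the example file):

import torch

def online_softmax_row(x: torch.Tensor, chunk: int = 1024) -> torch.Tensor:
    scale = 1.4426950408889634  # log2(e), so exp2(v * scale) == exp(v)
    lse = torch.tensor(float("-inf"))
    for i in range(0, x.numel(), chunk):           # pass 1: accumulate lse
        xs = x[i : i + chunk]
        m = xs.max() * scale
        s = torch.exp2(xs * scale - m).sum()
        lse = m + torch.log2(torch.exp2(lse - m) + s)
    out = torch.empty_like(x)
    for i in range(0, x.numel(), chunk):           # pass 2: normalize
        out[i : i + chunk] = torch.exp2(x[i : i + chunk] * scale - lse)
    return out

x = torch.randn(4096)
print(torch.allclose(online_softmax_row(x), torch.softmax(x, dim=0), atol=1e-6))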
examples/plot_layout/fragment_mfma_load_a.py View file @ 29051439
...
@@ -11,10 +11,9 @@ from tilelang.intrinsics.mfma_layout import (
)
def make_mfma_load_base_layout(dtype: str = "float16",
                               matrix: Literal["A", "B"] = "A",
                               k_dim: int = 16,
                               transposed: bool = False) -> T.Fragment:
def make_mfma_load_base_layout(
    dtype: str = "float16", matrix: Literal["A", "B"] = "A", k_dim: int = 16, transposed: bool = False
) -> T.Fragment:
    """
    Create a layout function for storing MFMA results into a fragment buffer.
    This layout is used in conjunction with `inverse_mfma_store_layout` to
...
@@ -72,12 +71,10 @@ def make_mfma_load_base_layout(dtype: str = "float16",
    # so the b matrix expected a transposed basic layout
    transform_func: Callable = None
    if matrix == "A":
        transform_func = transform_func_sr_a if is_sr_axis_order else lambda i, j: transform_func_sr_a(
            j, i)
        transform_func = transform_func_sr_a if is_sr_axis_order else lambda i, j: transform_func_sr_a(j, i)
        micro_size_s, micro_size_r = micro_size_x, micro_size_k
    elif matrix == "B":
        transform_func = transform_func_sr_b if is_sr_axis_order else lambda i, j: transform_func_sr_b(
            j, i)
        transform_func = transform_func_sr_b if is_sr_axis_order else lambda i, j: transform_func_sr_b(j, i)
        micro_size_s, micro_size_r = micro_size_k, micro_size_y
    else:
        raise ValueError(f"Unsupported matrix {matrix}")
...
@@ -120,14 +117,11 @@ print(base_layout)
plot_layout(base_layout, name="base_layout")
# warp layout 32x32
warp_layout = base_layout.repeat([warp_rows, warp_cols],
                                 repeat_on_thread=False, lower_dim_first=False)
warp_layout = base_layout.repeat([warp_rows, warp_cols], repeat_on_thread=False, lower_dim_first=False)
print(warp_layout)
plot_layout(warp_layout, name="warp_layout")
# block layout 64x32
block_layout = warp_layout.repeat([block_rows, 1], repeat_on_thread=True,
                                  lower_dim_first=True).replicate(block_cols)
block_layout = warp_layout.repeat([block_rows, 1], repeat_on_thread=True, lower_dim_first=True).replicate(block_cols)
print(block_layout)
plot_layout(block_layout, name="block_layout")
examples/plot_layout/fragment_mma_load_a.py View file @ 29051439
...
@@ -5,9 +5,7 @@ from tvm.tir import IndexMap
from tilelang.intrinsics.utils import get_mma_micro_size
def make_mma_load_base_layout(dtype: str = "float16",
                              matrix: Literal["A", "B"] = "A",
                              transposed: bool = False) -> T.Fragment:
def make_mma_load_base_layout(
    dtype: str = "float16", matrix: Literal["A", "B"] = "A", transposed: bool = False
) -> T.Fragment:
    """
    Create a layout function for storing MMA results into a fragment buffer.
    This layout is used in conjunction with `inverse_mma_store_layout` to
...
@@ -36,6 +34,7 @@ def make_mma_load_base_layout(dtype: str = "float16",
        shared_16x16_to_mma_32x8_layout_sr_b,
        shared_16x32_to_mma_32x16_layout_sr_b,
    )
    assert matrix in ["A", "B"], "matrix should be either A or B"
    dtype_bits = DataType(dtype).bits
    # s represents spatial axis
...
@@ -67,12 +66,10 @@ def make_mma_load_base_layout(dtype: str = "float16",
    # so the b matrix expected a transposed basic layout
    transform_func: Callable = None
    if matrix == "A":
        transform_func = transform_func_sr_a if is_sr_axis_order else lambda i, j: transform_func_sr_a(
            j, i)
        transform_func = transform_func_sr_a if is_sr_axis_order else lambda i, j: transform_func_sr_a(j, i)
        micro_size_s, micro_size_r = micro_size_x, micro_size_k
    elif matrix == "B":
        transform_func = transform_func_sr_b if is_sr_axis_order else lambda i, j: transform_func_sr_b(
            j, i)
        transform_func = transform_func_sr_b if is_sr_axis_order else lambda i, j: transform_func_sr_b(j, i)
        micro_size_s, micro_size_r = micro_size_k, micro_size_y
    else:
        raise ValueError(f"Unsupported matrix {matrix}")
...
examples/quickstart.py View file @ 29051439
...
@@ -7,7 +7,6 @@ import tilelang.language as T
# if not specified, it will be inferred from the input tensors during compile time
@tilelang.jit
def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
    @T.prim_func
    def matmul_relu_kernel(
        A: T.Tensor((M, K), dtype),
...
examples/seer_attention/block_sparse_attn_tilelang.py View file @ 29051439
...
@@ -10,10 +10,7 @@ def get_sparse_attn_mask_from_topk(x, topk, use_dense_for_last_block=False):
    bsz, num_head, downsample_len, _ = x.shape
    # N_CTX = downsample_len * BLOCK
    sparse_index = torch.topk(x, topk, dim=-1).indices
    dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len],
                            False, dtype=torch.bool, device=x.device)
    dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len], False, dtype=torch.bool, device=x.device)
    dense_mask.scatter_(-1, sparse_index, True)
    if use_dense_for_last_block:
        dense_mask[:, :, -2:, :] = True
...
@@ -30,15 +27,17 @@ def get_sparse_attn_mask_from_threshold(x, threshold, use_dense_for_last_block=F
@tilelang.jit(out_idx=[4], pass_configs={
@tilelang.jit(
    out_idx=[4],
    pass_configs={
        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
    })
    },
)
def blocksparse_flashattn(batch, heads, seq_q, seq_kv, dim, downsample_len, is_causal):
    block_M = 64
    block_N = 64
    num_stages = 0
    threads = 128
    scale = (1.0 / dim)**0.5 * 1.44269504  # log2(e)
    scale = (1.0 / dim) ** 0.5 * 1.44269504  # log2(e)
    q_shape = [batch, heads, seq_q, dim]
    kv_shape = [batch, heads, seq_kv, dim]
    block_mask_shape = [batch, heads, downsample_len, downsample_len]
...
@@ -48,7 +47,6 @@ def blocksparse_flashattn(batch, heads, seq_q, seq_kv, dim, downsample_len, is_c
    block_mask_dtype = "int8"
    def kernel_func(block_M, block_N, num_stages, threads):
        @T.macro
        def Softmax(
            acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype),
...
@@ -112,7 +110,7 @@ def blocksparse_flashattn(batch, heads, seq_q, seq_kv, dim, downsample_len, is_c
                logsum = T.alloc_fragment([block_M], accum_dtype)
                block_mask = T.alloc_local([downsample_len], block_mask_dtype)
                T.copy(Q[bz, by, bx * block_M:(bx + 1) * block_M, :], Q_shared)
                T.copy(Q[bz, by, bx * block_M : (bx + 1) * block_M, :], Q_shared)
                T.fill(acc_o, 0)
                T.fill(logsum, 0)
                T.fill(scores_max, -T.infinity(accum_dtype))
...
@@ -124,33 +122,25 @@ def blocksparse_flashattn(batch, heads, seq_q, seq_kv, dim, downsample_len, is_c
                for k in T.Pipelined(loop_range, num_stages=num_stages):
                    if block_mask[k] != 0:
                        T.copy(K[bz, by, k * block_N:(k + 1) * block_N, :], K_shared)
                        T.copy(K[bz, by, k * block_N : (k + 1) * block_N, :], K_shared)
                        if is_causal:
                            past_len = seq_kv - seq_q
                            for i, j in T.Parallel(block_M, block_N):
                                acc_s[i, j] = T.if_then_else(bx * block_M + i + past_len >= k * block_N + j,
                                                             0, -T.infinity(acc_s.dtype))
                                acc_s[i, j] = T.if_then_else(bx * block_M + i + past_len >= k * block_N + j, 0, -T.infinity(acc_s.dtype))
                        else:
                            T.clear(acc_s)
                        T.gemm(Q_shared, K_shared, acc_s, transpose_B=True,
                               policy=T.GemmWarpPolicy.FullRow)
                        Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev,
                                scores_scale, scores_sum, logsum)
                        T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
                        Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, logsum)
                        Rescale(acc_o, scores_scale)
                        T.copy(V[bz, by, k * block_N:(k + 1) * block_N, :], V_shared)
                        T.copy(V[bz, by, k * block_N : (k + 1) * block_N, :], V_shared)
                        T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
                for i, j in T.Parallel(block_M, dim):
                    acc_o[i, j] /= logsum[i]
                T.copy(acc_o, O_shared)
                T.copy(O_shared, Output[bz, by, bx * block_M:(bx + 1) * block_M, :])
                T.copy(O_shared, Output[bz, by, bx * block_M : (bx + 1) * block_M, :])
    return main
...
@@ -165,44 +155,40 @@ def test_topk_sparse_attention():
    torch.manual_seed(0)
    # Create inputs
    q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16)
    k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16)
    v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16)
    q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16)
    k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16)
    v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16)
    sm_scale = 1.0 / (D_HEAD**0.5)
    # Create sparse mask (downsampled to block level)
    downsample_factor = BLOCK
    downsample_len = math.ceil(SEQ_LEN / downsample_factor)
    x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len],
                       device='cuda', dtype=torch.float16)
    x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len], device="cuda", dtype=torch.float16)
    x_ds[:, :, :, 0] = 100
    block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK)
    # Run tilelang kernel
    kernel = blocksparse_flashattn(BATCH, N_HEADS, SEQ_LEN, SEQ_LEN, D_HEAD, downsample_len,
                                   is_causal=True)
    kernel = blocksparse_flashattn(BATCH, N_HEADS, SEQ_LEN, SEQ_LEN, D_HEAD, downsample_len, is_causal=True)
    tilelang_output = kernel(q, k, v, block_mask.to(torch.int8))
    # Compute reference
    # Expand block mask to full attention matrix
    full_mask = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK, device='cuda'))
    full_mask = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK, device="cuda"))
    full_mask = full_mask[..., :SEQ_LEN, :SEQ_LEN].bool()
    full_mask = full_mask & torch.tril(torch.ones_like(full_mask))  # Apply causal
    # PyTorch reference implementation
    attn = torch.einsum('bhsd,bhtd->bhst', q, k) * sm_scale
    attn = attn.masked_fill(~full_mask, float('-inf'))
    attn = torch.einsum("bhsd,bhtd->bhst", q, k) * sm_scale
    attn = attn.masked_fill(~full_mask, float("-inf"))
    attn = F.softmax(attn, dim=-1)
    ref_output = torch.einsum('bhst,bhtd->bhsd', attn, v)
    ref_output = torch.einsum("bhst,bhtd->bhsd", attn, v)
    print("ref_output", ref_output)
    print("tilelang_output", tilelang_output)
    # Verify accuracy
    assert torch.allclose(tilelang_output, ref_output, atol=1e-2, rtol=1e-2), \
        "TileLang output doesn't match reference"
    assert torch.allclose(tilelang_output, ref_output, atol=1e-2, rtol=1e-2), "TileLang output doesn't match reference"
    print("Pass topk sparse attention test with qlen == klen")
...
@@ -215,42 +201,40 @@ def test_topk_sparse_attention_qlen_lt_klen():
    torch.manual_seed(0)
    # Create inputs.
    q = torch.randn(BATCH, N_HEADS, Q_LEN, D_HEAD, device='cuda', dtype=torch.float16)
    k = torch.randn(BATCH, N_HEADS, K_LEN, D_HEAD, device='cuda', dtype=torch.float16)
    v = torch.randn(BATCH, N_HEADS, K_LEN, D_HEAD, device='cuda', dtype=torch.float16)
    q = torch.randn(BATCH, N_HEADS, Q_LEN, D_HEAD, device="cuda", dtype=torch.float16)
    k = torch.randn(BATCH, N_HEADS, K_LEN, D_HEAD, device="cuda", dtype=torch.float16)
    v = torch.randn(BATCH, N_HEADS, K_LEN, D_HEAD, device="cuda", dtype=torch.float16)
    sm_scale = 1.0 / (D_HEAD**0.5)
    downsample_factor = BLOCK
    downsample_len = math.ceil(K_LEN / downsample_factor)  # number of blocks along one dimension
    x_ds = torch.randn(BATCH, N_HEADS, downsample_len, downsample_len,
                       device='cuda', dtype=torch.float16)
    x_ds = torch.randn(BATCH, N_HEADS, downsample_len, downsample_len, device="cuda", dtype=torch.float16)
    # Force the first column to be high so that the first block is always selected.
    x_ds[:, :, :, 0] = 100
    block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK)
    kernel = blocksparse_flashattn(BATCH, N_HEADS, Q_LEN, K_LEN, D_HEAD, downsample_len,
                                   is_causal=True)
    kernel = blocksparse_flashattn(BATCH, N_HEADS, Q_LEN, K_LEN, D_HEAD, downsample_len, is_causal=True)
    print(kernel.get_kernel_source())
    tilelang_output = kernel(q, k, v, block_mask.to(torch.int8))
    past_len = K_LEN - Q_LEN
    attn = torch.einsum('bhsd,bhtd->bhst', q, k) * sm_scale
    attn = torch.einsum("bhsd,bhtd->bhst", q, k) * sm_scale
    full_mask_full = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK, device='cuda')).bool()
    full_mask_full = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK, device="cuda")).bool()
    full_mask_full = full_mask_full[..., :K_LEN, :K_LEN]
    effective_mask = full_mask_full[..., past_len:K_LEN, :]  # shape: (B, H, Q_LEN, K_LEN)
    i_global = torch.arange(past_len, K_LEN, device=k.device).unsqueeze(1)  # shape: (Q_LEN, 1)
    j_global = torch.arange(K_LEN, device=k.device).unsqueeze(0)  # shape: (1, K_LEN)
    causal_mask = (j_global <= i_global)  # shape: (Q_LEN, K_LEN)
    causal_mask = j_global <= i_global  # shape: (Q_LEN, K_LEN)
    final_mask = effective_mask & causal_mask  # shape: (B, H, Q_LEN, K_LEN)
    attn = attn.masked_fill(~final_mask, float('-inf'))
    attn = attn.masked_fill(~final_mask, float("-inf"))
    attn = F.softmax(attn, dim=-1)
    ref_output = torch.einsum('bhst,bhtd->bhsd', attn, v)
    ref_output = torch.einsum("bhst,bhtd->bhsd", attn, v)
    print("ref_output", ref_output)
    print("tilelang_output", tilelang_output)
...
examples/seer_attention/block_sparse_attn_triton.py View file @ 29051439
...
@@ -15,10 +15,7 @@ def get_sparse_attn_mask_from_topk(x, topk, use_dense_for_last_block=False):
    bsz, num_head, downsample_len, _ = x.shape
    # N_CTX = downsample_len * BLOCK
    sparse_index = torch.topk(x, topk, dim=-1).indices
    dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len],
                            False, dtype=torch.bool, device=x.device)
    dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len], False, dtype=torch.bool, device=x.device)
    dense_mask.scatter_(-1, sparse_index, True)
    if use_dense_for_last_block:
        dense_mask[:, :, -2:, :] = True
...
@@ -54,7 +51,6 @@ def _fwd_kernel_inner(
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    mask_val = tl.load(block_mask_ptr + k_block_col_idx * stride_bmask_n)
    if mask_val == True:
...
@@ -69,7 +65,7 @@ def _fwd_kernel_inner(
        qk *= sm_scale
        # the following is needed only when LAST_K_BLOCK or BLOCK_M < BLOCK_N
        qk += tl.where(offs_m[:, None] + past_len >= (start_n + offs_n[None, :]), 0, float('-inf'))
        qk += tl.where(offs_m[:, None] + past_len >= (start_n + offs_n[None, :]), 0, float("-inf"))
        m_ij = tl.maximum(m_i, tl.max(qk, 1))
        qk -= m_ij[:, None]
...
@@ -149,7 +145,7 @@ def _fwd_kernel(
    v_ptrs = V + off_v
    mask_ptrs = block_mask_ptr + start_m * stride_bmm
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float('inf')
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
    acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
...
@@ -185,24 +181,12 @@ def _fwd_kernel(
    acc = acc * l_recip
    acc = acc.to(Out.dtype.element_ty)
    off_o = off_z * stride_oz + off_h * stride_oh + offs_m[:, None] * stride_om + \
        offs_d[None, :] * stride_od
    off_o = off_z * stride_oz + off_h * stride_oh + offs_m[:, None] * stride_om + offs_d[None, :] * stride_od
    out_ptrs = Out + off_o
    tl.store(out_ptrs, acc, mask=offs_m[:, None] < N_CTX)
def _forward(ctx,
             q,
             k,
             v,
             block_sparse_mask,
             sm_scale,
             BLOCK_M=64,
             BLOCK_N=64,
             num_warps=None,
             num_stages=1,
             out=None):
def _forward(ctx, q, k, v, block_sparse_mask, sm_scale, BLOCK_M=64, BLOCK_N=64, num_warps=None, num_stages=1, out=None):
    assert q.shape[-1] == k.shape[-1] == v.shape[-1]
    assert k.shape[2] == v.shape[2]
    o = out if out is not None else torch.empty_like(q).contiguous()
...
@@ -247,7 +231,6 @@ def _forward(ctx,
class _sparse_attention(torch.autograd.Function):
    @staticmethod
    def forward(ctx, q, k, v, block_sparse_dense, sm_scale):
        # shape constraints
...
@@ -271,9 +254,9 @@ def test_topk_sparse_attention():
    torch.manual_seed(0)
    # Create inputs
    q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.bfloat16)
    k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.bfloat16)
    v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.bfloat16)
    q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.bfloat16)
    k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.bfloat16)
    v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.bfloat16)
    sm_scale = 1.0 / (D_HEAD**0.5)
    # Create sparse mask (downsampled to block level)
...
@@ -281,9 +264,7 @@ def test_topk_sparse_attention():
    downsample_len = math.ceil(SEQ_LEN / downsample_factor)
    print("downsample_len", downsample_len)
    x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len],
                       device='cuda', dtype=torch.bfloat16)
    x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len], device="cuda", dtype=torch.bfloat16)
    x_ds[:, :, :, 0] = 100
    print("x_ds.shape", x_ds.shape)
    block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK)
...
@@ -295,22 +276,21 @@ def test_topk_sparse_attention():
    # Compute reference
    # Expand block mask to full attention matrix
    full_mask = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK, device='cuda'))
    full_mask = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK, device="cuda"))
    full_mask = full_mask[..., :SEQ_LEN, :SEQ_LEN].bool()
    full_mask = full_mask & torch.tril(torch.ones_like(full_mask))  # Apply causal
    # PyTorch reference implementation
    attn = torch.einsum('bhsd,bhtd->bhst', q, k) * sm_scale
    attn = attn.masked_fill(~full_mask, float('-inf'))
    attn = torch.einsum("bhsd,bhtd->bhst", q, k) * sm_scale
    attn = attn.masked_fill(~full_mask, float("-inf"))
    attn = F.softmax(attn, dim=-1)
    ref_output = torch.einsum('bhst,bhtd->bhsd', attn, v)
    ref_output = torch.einsum("bhst,bhtd->bhsd", attn, v)
    # print("ref_output", ref_output)
    # print("triton_output", triton_output)
    # Verify accuracy
    assert torch.allclose(triton_output, ref_output, atol=1e-2, rtol=1e-2), \
        "Triton output doesn't match reference"
    assert torch.allclose(triton_output, ref_output, atol=1e-2, rtol=1e-2), "Triton output doesn't match reference"
    print("Pass topk sparse attention test with qlen == klen")
...
@@ -322,16 +302,15 @@ def test_topk_sparse_attention_qlt_kl():
    torch.manual_seed(0)
    # Create inputs.
    q = torch.randn(BATCH, N_HEADS, Q_LEN, D_HEAD, device='cuda', dtype=torch.bfloat16)
    k = torch.randn(BATCH, N_HEADS, K_LEN, D_HEAD, device='cuda', dtype=torch.bfloat16)
    v = torch.randn(BATCH, N_HEADS, K_LEN, D_HEAD, device='cuda', dtype=torch.bfloat16)
    q = torch.randn(BATCH, N_HEADS, Q_LEN, D_HEAD, device="cuda", dtype=torch.bfloat16)
    k = torch.randn(BATCH, N_HEADS, K_LEN, D_HEAD, device="cuda", dtype=torch.bfloat16)
    v = torch.randn(BATCH, N_HEADS, K_LEN, D_HEAD, device="cuda", dtype=torch.bfloat16)
    # softmax scale
    sm_scale = 1.0 / (D_HEAD**0.5)
    downsample_factor = BLOCK
    downsample_len = math.ceil(K_LEN / downsample_factor)  # number of blocks along one dimension
    x_ds = torch.randn(BATCH, N_HEADS, downsample_len, downsample_len,
                       device='cuda', dtype=torch.bfloat16)
    x_ds = torch.randn(BATCH, N_HEADS, downsample_len, downsample_len, device="cuda", dtype=torch.bfloat16)
    # Force the first column to be high so that the first block is always selected.
    x_ds[:, :, :, 0] = 100
    block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK)
...
@@ -340,26 +319,25 @@ def test_topk_sparse_attention_qlt_kl():
    past_len = K_LEN - Q_LEN
    attn = torch.einsum('bhsd,bhtd->bhst', q, k) * sm_scale
    attn = torch.einsum("bhsd,bhtd->bhst", q, k) * sm_scale
    full_mask_full = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK, device='cuda')).bool()
    full_mask_full = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK, device="cuda")).bool()
    full_mask_full = full_mask_full[..., :K_LEN, :K_LEN]
    effective_mask = full_mask_full[..., past_len:K_LEN, :]  # shape: (B, H, Q_LEN, K_LEN)
    i_global = torch.arange(past_len, K_LEN, device=k.device).unsqueeze(1)  # shape: (Q_LEN, 1)
    j_global = torch.arange(K_LEN, device=k.device).unsqueeze(0)  # shape: (1, K_LEN)
    causal_mask = (j_global <= i_global)  # shape: (Q_LEN, K_LEN)
    causal_mask = j_global <= i_global  # shape: (Q_LEN, K_LEN)
    final_mask = effective_mask & causal_mask  # shape: (B, H, Q_LEN, K_LEN)
    attn = attn.masked_fill(~final_mask, float('-inf'))
    attn = attn.masked_fill(~final_mask, float("-inf"))
    attn = F.softmax(attn, dim=-1)
    ref_output = torch.einsum('bhst,bhtd->bhsd', attn, v)
    ref_output = torch.einsum("bhst,bhtd->bhsd", attn, v)
    # Verify accuracy.
    assert torch.allclose(triton_output, ref_output, atol=1e-2, rtol=1e-2), \
        "Triton output doesn't match reference when qlen < klen"
    assert torch.allclose(triton_output, ref_output, atol=1e-2, rtol=1e-2), "Triton output doesn't match reference when qlen < klen"
    print("Pass topk sparse attention test with qlen < klen")
...
examples/sparse_tensorcore/tilelang_example_sparse_tensorcore.py View file @ 29051439
...
@@ -29,23 +29,21 @@ def matmul_sp(
    @T.prim_func
    def main(
        A_sparse: T.Tensor(A_sparse_shape, in_dtype),
        E: T.Tensor((M, K // 8), 'uint8'),
        E: T.Tensor((M, K // 8), "uint8"),
        B: T.Tensor(B_shape, in_dtype),
        C: T.Tensor((M, N), out_dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
            A_shared = T.alloc_shared(A_shared_shape, in_dtype)
            B_shared = T.alloc_shared(B_shared_shape, in_dtype)
            E_shared = T.alloc_shared((block_M, block_K // 8), 'uint8')
            E_shared = T.alloc_shared((block_M, block_K // 8), "uint8")
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            T.annotate_layout({
                E: make_cutlass_metadata_layout(E, mma_dtype="float16", arch="9.0", block_k=block_K),
                E_shared: make_cutlass_metadata_layout(E_shared, mma_dtype="float16", arch="9.0", block_k=block_K),
            })
            T.annotate_layout(
                {
                    E: make_cutlass_metadata_layout(E, mma_dtype="float16", arch="9.0", block_k=block_K),
                    E_shared: make_cutlass_metadata_layout(E_shared, mma_dtype="float16", arch="9.0", block_k=block_K),
                }
            )
            T.clear(C_local)
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                T.copy(E[by * block_M, k * block_K // 8], E_shared)
...
@@ -57,7 +55,7 @@ def matmul_sp(
    return main
def generate_2_to_4_sparse_tensor(shape, dtype=torch.float32, device='cpu'):
def generate_2_to_4_sparse_tensor(shape, dtype=torch.float32, device="cpu"):
    if shape[-1] % 4 != 0:
        raise ValueError("Last dimension must be divisible by 4 for 2:4 sparsity.")
...
@@ -102,9 +100,9 @@ def run_gemm_sp(
        num_threads,
    )
    A = generate_2_to_4_sparse_tensor((M, K), dtype=torch.float16, device='cuda')
    A = generate_2_to_4_sparse_tensor((M, K), dtype=torch.float16, device="cuda")
    A_sparse, E = compress_sm90(A, block_k=block_K, transposed=False)
    B = torch.randn((K, N), device='cuda', dtype=torch.float16)
    B = torch.randn((K, N), device="cuda", dtype=torch.float16)
    C_sp = kernel(A_sparse, E, B).half()
    C = torch.matmul(A, B)
...
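The diff above only shows the signature of generate_2_to_4_sparse_tensor and its divisibility check; the body is elided. For readers unfamiliar with the 2:4 pattern, here is a hedged sketch of one way such a tensor could be produced (keep the two largest-magnitude values in every group of four along the last axis); this is an illustration, not the example's actual implementation:

import torch

def make_2_to_4_sparse(shape, dtype=torch.float16, device="cpu"):
    if shape[-1] % 4 != 0:
        raise ValueError("Last dimension must be divisible by 4 for 2:4 sparsity.")
    dense = torch.randn(shape, dtype=dtype, device=device)
    groups = dense.reshape(-1, 4)
    keep = groups.abs().float().topk(2, dim=-1).indices          # 2 survivors per group of 4
    mask = torch.zeros_like(groups, dtype=torch.bool).scatter_(-1, keep, True)
    return (groups * mask).reshape(shape)

A = make_2_to_4_sparse((64, 64))
print((A.reshape(-1, 4) != 0).sum(-1).unique())  # tensor([2]): every group keeps exactly two values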
examples/topk/example_topk.py View file @ 29051439
...
@@ -43,15 +43,12 @@ def tl_topk(
            T.reduce_max(logits_frag, max_val, dim=1, clear=True)
            for i, j in T.Parallel(blk_m, N):
                expand_max_idx[i, j] = T.if_then_else(max_val[i] == logits_frag[i, j], j,
                                                      expand_max_idx[i, j])
                expand_max_idx[i, j] = T.if_then_else(max_val[i] == logits_frag[i, j], j, expand_max_idx[i, j])
            T.reduce_max(expand_max_idx, max_idx, dim=1, clear=True)
            for i, j in T.Parallel(blk_m, N):
                logits_frag[i, j] = T.if_then_else(max_val[i] == logits_frag[i, j], -10000.0,
                                                   logits_frag[i, j])
                logits_frag[i, j] = T.if_then_else(max_val[i] == logits_frag[i, j], -10000.0, logits_frag[i, j])
            for i in T.Parallel(blk_m):
                topk_gates[bx * blk_m + i, k] = max_val[i]
...
@@ -61,7 +58,6 @@ def tl_topk(
def ref_program(logits, top_k):
    top_k_gates, top_k_indices = logits.topk(top_k, dim=1)
    return top_k_gates, top_k_indices.to(torch.int32)
...
examples/visual_layout_inference/visual_layout_inference.py View file @ 29051439
...
@@ -7,10 +7,10 @@ import tilelang.language as T
    out_idx=[-1],
    pass_configs={
        tilelang.PassConfigKey.TL_LAYOUT_VISUALIZATION_ENABLE: True,
        tilelang.PassConfigKey.TL_LAYOUT_VISUALIZATION_FORMATS: "svg"})
        tilelang.PassConfigKey.TL_LAYOUT_VISUALIZATION_FORMATS: "svg",
    },
)
def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
    @T.prim_func
    def gemm(
        A: T.Tensor((M, K), dtype),
...
@@ -49,12 +49,12 @@ def main():
    print("All check passed.")
    # print the layout visualization result and save figures to ./tmp.
    '''
    """
    C_local inferenced layout:
        Shape: [32, 32] -> [8]
        Thread: _j // 16 * 64 + _i // 16 * 32 + _i % 8 * 4 + _j % 8 // 2
        Index: [_j % 16 // 8 * 4 + _i % 16 // 8 * 2 + _j % 2]
    '''
    """
if __name__ == "__main__":
...