Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
tilelang
Commits
29051439
Unverified
Commit
29051439
authored
Dec 12, 2025
by
Lei Wang
Committed by
GitHub
Dec 12, 2025
Browse files
[Lint] Phaseout Yapf format and embrace ruff format (#1417)
parent
e84b24bc
Changes
426
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
627 additions
and
647 deletions
+627
-647
examples/hadamard_transform/example_hadamard.py
examples/hadamard_transform/example_hadamard.py
+15
-20
examples/lazy_jit/lazyjit.en.ipynb
examples/lazy_jit/lazyjit.en.ipynb
+41
-31
examples/lazy_jit/lazyjit.zh.ipynb
examples/lazy_jit/lazyjit.zh.ipynb
+41
-31
examples/linear_attention/example_linear_attn_bwd.py
examples/linear_attention/example_linear_attn_bwd.py
+55
-74
examples/linear_attention/example_linear_attn_fwd.py
examples/linear_attention/example_linear_attn_fwd.py
+40
-46
examples/linear_attention/example_mamba_chunk_scan.py
examples/linear_attention/example_mamba_chunk_scan.py
+110
-79
examples/linear_attention/example_mamba_chunk_state.py
examples/linear_attention/example_mamba_chunk_state.py
+57
-62
examples/linear_attention/example_retention_fwd.py
examples/linear_attention/example_retention_fwd.py
+22
-27
examples/minference/example_vertical_slash_sparse_attn.py
examples/minference/example_vertical_slash_sparse_attn.py
+123
-100
examples/norm/rms_norm.py
examples/norm/rms_norm.py
+2
-2
examples/norm/test_rms_norm.py
examples/norm/test_rms_norm.py
+2
-2
examples/online_softmax/online_softmax.py
examples/online_softmax/online_softmax.py
+6
-6
examples/plot_layout/fragment_mfma_load_a.py
examples/plot_layout/fragment_mfma_load_a.py
+7
-13
examples/plot_layout/fragment_mma_load_a.py
examples/plot_layout/fragment_mma_load_a.py
+4
-7
examples/quickstart.py
examples/quickstart.py
+3
-4
examples/seer_attention/block_sparse_attn_tilelang.py
examples/seer_attention/block_sparse_attn_tilelang.py
+48
-64
examples/seer_attention/block_sparse_attn_triton.py
examples/seer_attention/block_sparse_attn_triton.py
+24
-46
examples/sparse_tensorcore/tilelang_example_sparse_tensorcore.py
...s/sparse_tensorcore/tilelang_example_sparse_tensorcore.py
+14
-16
examples/topk/example_topk.py
examples/topk/example_topk.py
+5
-9
examples/visual_layout_inference/visual_layout_inference.py
examples/visual_layout_inference/visual_layout_inference.py
+8
-8
No files found.
Too many changes to show.
To preserve performance only
426 of 426+
files are displayed.
Plain diff
Email patch
examples/hadamard_transform/example_hadamard.py
View file @
29051439
...
@@ -17,7 +17,7 @@ def is_pow_of_2(n):
...
@@ -17,7 +17,7 @@ def is_pow_of_2(n):
def
hadamard
(
b
,
n
,
dtype
):
def
hadamard
(
b
,
n
,
dtype
):
assert
is_pow_of_2
(
n
),
"n must be a power of 2"
assert
is_pow_of_2
(
n
),
"n must be a power of 2"
assert
2
<=
n
<=
32768
,
"n must be in [2, 32768]"
assert
2
<=
n
<=
32768
,
"n must be in [2, 32768]"
elem_size
=
{
'
float32
'
:
4
,
'
float16
'
:
2
,
'
bfloat16
'
:
2
}[
dtype
]
elem_size
=
{
"
float32
"
:
4
,
"
float16
"
:
2
,
"
bfloat16
"
:
2
}[
dtype
]
logN
=
int
(
math
.
log2
(
n
))
logN
=
int
(
math
.
log2
(
n
))
threads
=
[
0
,
1
,
1
,
1
,
2
,
4
,
8
,
16
,
32
,
32
,
128
,
256
,
256
,
256
,
256
,
256
][
logN
]
threads
=
[
0
,
1
,
1
,
1
,
2
,
4
,
8
,
16
,
32
,
32
,
128
,
256
,
256
,
256
,
256
,
256
][
logN
]
...
@@ -40,23 +40,21 @@ def hadamard(b, n, dtype):
...
@@ -40,23 +40,21 @@ def hadamard(b, n, dtype):
# print(f'{exchange_round=}')
# print(f'{exchange_round=}')
@
T
.
macro
@
T
.
macro
def
warp_shfl
(
local
:
T
.
Tensor
((
thread_elem
,),
dtype
),
buf
:
T
.
Tensor
((
thread_elem
,),
dtype
),
def
warp_shfl
(
local
:
T
.
Tensor
((
thread_elem
,),
dtype
),
buf
:
T
.
Tensor
((
thread_elem
,),
dtype
),
round
:
int
):
round
:
int
):
tx
=
T
.
get_thread_binding
(
0
)
tx
=
T
.
get_thread_binding
(
0
)
for
i
in
T
.
serial
(
round
):
for
i
in
T
.
serial
(
round
):
tx_stride
=
1
<<
i
tx_stride
=
1
<<
i
another_tx
=
tx
^
tx_stride
another_tx
=
tx
^
tx_stride
sign
=
(
sign
=
(
tx
>>
i
)
&
1
# get i-th lowest bit of tx, which determines the operation type for shared[tx, :]
tx
>>
i
)
&
1
# get i-th lowest bit of tx, which determines the operation type for shared[tx, :]
for
j
in
T
.
Pipelined
(
thread_elem
,
num_stages
=
1
):
for
j
in
T
.
Pipelined
(
thread_elem
,
num_stages
=
1
):
buf
[
j
]
=
T
.
tvm_warp_shuffle
(
buf
[
j
]
=
T
.
tvm_warp_shuffle
(
0x
ffffffff
,
# mask of all threads
0x
FFFFFFFF
,
# mask of all threads
local
[
j
],
local
[
j
],
another_tx
%
warp_size
,
another_tx
%
warp_size
,
warp_size
,
warp_size
,
warp_size
)
warp_size
,
)
local
[
j
]
=
T
.
if_then_else
(
sign
==
0
,
local
[
j
]
+
buf
[
j
],
buf
[
j
]
-
local
[
j
])
local
[
j
]
=
T
.
if_then_else
(
sign
==
0
,
local
[
j
]
+
buf
[
j
],
buf
[
j
]
-
local
[
j
])
@
T
.
prim_func
@
T
.
prim_func
...
@@ -78,10 +76,8 @@ def hadamard(b, n, dtype):
...
@@ -78,10 +76,8 @@ def hadamard(b, n, dtype):
for
j
in
T
.
serial
(
chunknum
):
for
j
in
T
.
serial
(
chunknum
):
chunkbase
=
j
*
chunksize
chunkbase
=
j
*
chunksize
for
k
in
T
.
serial
(
chunksize
//
2
):
for
k
in
T
.
serial
(
chunksize
//
2
):
local
[
chunkbase
+
local
[
chunkbase
+
k
]
=
local
[
chunkbase
+
k
]
+
local
[
chunkbase
+
k
+
chunksize
//
2
]
k
]
=
local
[
chunkbase
+
k
]
+
local
[
chunkbase
+
k
+
chunksize
//
2
]
local
[
chunkbase
+
k
+
chunksize
//
2
]
=
local
[
chunkbase
+
k
]
-
2
*
local
[
chunkbase
+
k
+
chunksize
//
2
]
local
[
chunkbase
+
k
+
chunksize
//
2
]
=
local
[
chunkbase
+
k
]
-
2
*
local
[
chunkbase
+
k
+
chunksize
//
2
]
# 3. Hadamard inside warp, n<=512
# 3. Hadamard inside warp, n<=512
# In warp level, we rely on warp shuffle to exchange data inside each warp, without using shared memory
# In warp level, we rely on warp shuffle to exchange data inside each warp, without using shared memory
...
@@ -131,28 +127,27 @@ def ref_program(x: torch.Tensor):
...
@@ -131,28 +127,27 @@ def ref_program(x: torch.Tensor):
assert
x
.
ndim
==
2
assert
x
.
ndim
==
2
dim
=
x
.
shape
[
-
1
]
dim
=
x
.
shape
[
-
1
]
assert
is_pow_of_2
(
dim
)
assert
is_pow_of_2
(
dim
)
return
F
.
linear
(
return
F
.
linear
(
x
,
torch
.
tensor
(
scipy
.
linalg
.
hadamard
(
dim
,
dtype
=
float
),
dtype
=
x
.
dtype
,
device
=
x
.
device
))
x
,
torch
.
tensor
(
scipy
.
linalg
.
hadamard
(
dim
,
dtype
=
float
),
dtype
=
x
.
dtype
,
device
=
x
.
device
))
def
main
():
def
main
():
parser
=
argparse
.
ArgumentParser
()
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
'
--batch
'
,
type
=
int
,
default
=
64
,
help
=
'
Batch size
'
)
parser
.
add_argument
(
"
--batch
"
,
type
=
int
,
default
=
64
,
help
=
"
Batch size
"
)
parser
.
add_argument
(
'
--dim
'
,
type
=
int
,
default
=
32768
,
help
=
'
Dimension
'
)
parser
.
add_argument
(
"
--dim
"
,
type
=
int
,
default
=
32768
,
help
=
"
Dimension
"
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
B
,
D
=
args
.
batch
,
args
.
dim
B
,
D
=
args
.
batch
,
args
.
dim
x
=
torch
.
randn
((
B
,
D
),
device
=
'
cuda
'
)
x
=
torch
.
randn
((
B
,
D
),
device
=
"
cuda
"
)
kernel
=
hadamard
(
B
,
D
,
'
float32
'
)
kernel
=
hadamard
(
B
,
D
,
"
float32
"
)
y
=
kernel
(
x
)
y
=
kernel
(
x
)
y_ref
=
ref_program
(
x
)
y_ref
=
ref_program
(
x
)
torch
.
testing
.
assert_close
(
y
,
y_ref
,
atol
=
1e-2
,
rtol
=
1e-2
)
torch
.
testing
.
assert_close
(
y
,
y_ref
,
atol
=
1e-2
,
rtol
=
1e-2
)
print
(
'
All tests passed.
'
)
print
(
"
All tests passed.
"
)
profiler
=
kernel
.
get_profiler
(
tensor_supply_type
=
tilelang
.
TensorSupplyType
.
Auto
)
profiler
=
kernel
.
get_profiler
(
tensor_supply_type
=
tilelang
.
TensorSupplyType
.
Auto
)
latency
=
profiler
.
do_bench
(
warmup
=
100
)
latency
=
profiler
.
do_bench
(
warmup
=
100
)
print
(
"Tile-lang: {:.2f} ms"
.
format
(
latency
))
print
(
"Tile-lang: {:.2f} ms"
.
format
(
latency
))
if
__name__
==
'
__main__
'
:
if
__name__
==
"
__main__
"
:
main
()
main
()
examples/lazy_jit/lazyjit.en.ipynb
View file @
29051439
...
@@ -9,6 +9,7 @@
...
@@ -9,6 +9,7 @@
"source": [
"source": [
"import sys\n",
"import sys\n",
"from pathlib import Path\n",
"from pathlib import Path\n",
"\n",
"sys.path.insert(0, str(Path.cwd().parent.parent.absolute()))\n",
"sys.path.insert(0, str(Path.cwd().parent.parent.absolute()))\n",
"import tilelang\n",
"import tilelang\n",
"import torch\n",
"import torch\n",
...
@@ -61,7 +62,7 @@
...
@@ -61,7 +62,7 @@
" out_dtype: T.dtype = T.float32,\n",
" out_dtype: T.dtype = T.float32,\n",
" block_M: int = 128,\n",
" block_M: int = 128,\n",
" block_N: int = 128,\n",
" block_N: int = 128,\n",
" block_K: int = 32\n",
" block_K: int = 32
,
\n",
"):\n",
"):\n",
" M, K = A.shape\n",
" M, K = A.shape\n",
" K, N = B.shape\n",
" K, N = B.shape\n",
...
@@ -94,8 +95,8 @@
...
@@ -94,8 +95,8 @@
"metadata": {},
"metadata": {},
"outputs": [],
"outputs": [],
"source": [
"source": [
"A = torch.randn(1024, 512, dtype=torch.float16, device=
'
cuda
'
)\n",
"A = torch.randn(1024, 512, dtype=torch.float16, device=
\"
cuda
\"
)\n",
"B = torch.randn(512, 256, dtype=torch.float16, device=
'
cuda
'
)\n",
"B = torch.randn(512, 256, dtype=torch.float16, device=
\"
cuda
\"
)\n",
"C = gemm(A, B)\n",
"C = gemm(A, B)\n",
"\n",
"\n",
"# check output is correct\n",
"# check output is correct\n",
...
@@ -118,8 +119,8 @@
...
@@ -118,8 +119,8 @@
"metadata": {},
"metadata": {},
"outputs": [],
"outputs": [],
"source": [
"source": [
"A = torch.randn(1024, 512, dtype=torch.float16, device=
'
cuda
'
)\n",
"A = torch.randn(1024, 512, dtype=torch.float16, device=
\"
cuda
\"
)\n",
"B = torch.randn(512, 1024, dtype=torch.float16, device=
'
cuda
'
)\n",
"B = torch.randn(512, 1024, dtype=torch.float16, device=
\"
cuda
\"
)\n",
"C = gemm(A, B, block_M=64, block_N=64)"
"C = gemm(A, B, block_M=64, block_N=64)"
]
]
},
},
...
@@ -218,8 +219,8 @@
...
@@ -218,8 +219,8 @@
"source": [
"source": [
"@tilelang.lazy_jit\n",
"@tilelang.lazy_jit\n",
"def gemm_dyn_K(\n",
"def gemm_dyn_K(\n",
" A: T.Tensor[[int, T.dyn[
'K'
]], T.float16], # noqa: F821\n",
" A: T.Tensor[[int, T.dyn[
\"K\"
]], T.float16],
# noqa: F821\n",
" B: T.Tensor[[T.dyn[
'K'
], int], T.float16], # noqa: F821\n",
" B: T.Tensor[[T.dyn[
\"K\"
], int], T.float16],
# noqa: F821\n",
"):\n",
"):\n",
" M, K = A.shape\n",
" M, K = A.shape\n",
" K, N = B.shape\n",
" K, N = B.shape\n",
...
@@ -265,8 +266,8 @@
...
@@ -265,8 +266,8 @@
"metadata": {},
"metadata": {},
"outputs": [],
"outputs": [],
"source": [
"source": [
"A = torch.randn(1024, 512, dtype=torch.float16, device=
'
cuda
'
)\n",
"A = torch.randn(1024, 512, dtype=torch.float16, device=
\"
cuda
\"
)\n",
"B = torch.randn(512, 256, dtype=torch.float16, device=
'
cuda
'
)\n",
"B = torch.randn(512, 256, dtype=torch.float16, device=
\"
cuda
\"
)\n",
"C = gemm_dyn_K(A, B)\n",
"C = gemm_dyn_K(A, B)\n",
"C_ref = (A @ B).float()\n",
"C_ref = (A @ B).float()\n",
"torch.testing.assert_close(C, C_ref, rtol=1e-2, atol=1e-2)"
"torch.testing.assert_close(C, C_ref, rtol=1e-2, atol=1e-2)"
...
@@ -295,18 +296,17 @@
...
@@ -295,18 +296,17 @@
"source": [
"source": [
"from typing import Any\n",
"from typing import Any\n",
"\n",
"\n",
"\n",
"@tilelang.lazy_jit\n",
"@tilelang.lazy_jit\n",
"def as_contingious(\n",
"def as_contingious(A: T.StridedTensor[[T.dyn, T.dyn], [T.dyn, T.dyn], Any]):\n",
" A: T.StridedTensor[[T.dyn, T.dyn], [T.dyn, T.dyn], Any]\n",
"):\n",
" M, N = A.shape\n",
" M, N = A.shape\n",
" B = T.empty((M, N), A.dtype)\n",
" B = T.empty((M, N), A.dtype)\n",
" block_M = 128\n",
" block_M = 128\n",
" block_N = 128\n",
" block_N = 128\n",
" with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), threads=128) as (bx, by):\n",
" with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), threads=128) as (bx, by):\n",
" T.copy(\n",
" T.copy(\n",
" A[bx * block_M: (bx + 1) * block_M, by * block_N: (by + 1) * block_N],\n",
" A[bx * block_M
: (bx + 1) * block_M, by * block_N
: (by + 1) * block_N],\n",
" B[bx * block_M: (bx + 1) * block_M, by * block_N: (by + 1) * block_N]\n",
" B[bx * block_M
: (bx + 1) * block_M, by * block_N
: (by + 1) * block_N]
,
\n",
" )\n",
" )\n",
" return B"
" return B"
]
]
...
@@ -318,7 +318,7 @@
...
@@ -318,7 +318,7 @@
"metadata": {},
"metadata": {},
"outputs": [],
"outputs": [],
"source": [
"source": [
"A = torch.randn(1024, 1024, device=
'
cuda
'
)\n",
"A = torch.randn(1024, 1024, device=
\"
cuda
\"
)\n",
"B = as_contingious(A[::2, ::2])\n",
"B = as_contingious(A[::2, ::2])\n",
"B_ref = A[::2, ::2].contiguous()\n",
"B_ref = A[::2, ::2].contiguous()\n",
"torch.testing.assert_close(B, B_ref)"
"torch.testing.assert_close(B, B_ref)"
...
@@ -370,8 +370,8 @@
...
@@ -370,8 +370,8 @@
"metadata": {},
"metadata": {},
"outputs": [],
"outputs": [],
"source": [
"source": [
"A = torch.randn(1024, 512, dtype=torch.float16, device=
'
cuda
'
)\n",
"A = torch.randn(1024, 512, dtype=torch.float16, device=
\"
cuda
\"
)\n",
"B = torch.randn(512, 256, dtype=torch.float16, device=
'
cuda
'
)\n",
"B = torch.randn(512, 256, dtype=torch.float16, device=
\"
cuda
\"
)\n",
"C = gemm_ptr(A, B, 1024, 256, 512)\n",
"C = gemm_ptr(A, B, 1024, 256, 512)\n",
"C_ref = (A @ B).float()\n",
"C_ref = (A @ B).float()\n",
"torch.testing.assert_close(C, C_ref, rtol=1e-2, atol=1e-2)"
"torch.testing.assert_close(C, C_ref, rtol=1e-2, atol=1e-2)"
...
@@ -416,8 +416,8 @@
...
@@ -416,8 +416,8 @@
"metadata": {},
"metadata": {},
"outputs": [],
"outputs": [],
"source": [
"source": [
"A = torch.randn(1024, 512, dtype=torch.float16, device=
'
cuda
'
)\n",
"A = torch.randn(1024, 512, dtype=torch.float16, device=
\"
cuda
\"
)\n",
"B = torch.randn(512, 256, dtype=torch.float16, device=
'
cuda
'
)\n",
"B = torch.randn(512, 256, dtype=torch.float16, device=
\"
cuda
\"
)\n",
"C = gemm_ptr_dyn(A, B, 1024, 256, 512)\n",
"C = gemm_ptr_dyn(A, B, 1024, 256, 512)\n",
"C_ref = (A @ B).float()\n",
"C_ref = (A @ B).float()\n",
"torch.testing.assert_close(C, C_ref, rtol=1e-2, atol=1e-2)"
"torch.testing.assert_close(C, C_ref, rtol=1e-2, atol=1e-2)"
...
@@ -496,18 +496,20 @@
...
@@ -496,18 +496,20 @@
"source": [
"source": [
"from itertools import product\n",
"from itertools import product\n",
"\n",
"\n",
"\n",
"def get_configs():\n",
"def get_configs():\n",
" return [\n",
" return [\n",
" {\n",
" {\n",
"
'A'
: T.Tensor((1024, 1024), T.float32),\n",
"
\"A\"
: T.Tensor((1024, 1024), T.float32),\n",
"
'B'
: T.Tensor((1024, 1024), T.float32),\n",
"
\"B\"
: T.Tensor((1024, 1024), T.float32),\n",
"
'
block_M
'
: block_M,\n",
"
\"
block_M
\"
: block_M,\n",
"
'
block_N
'
: block_N,\n",
"
\"
block_N
\"
: block_N,\n",
"
'
block_K
'
: block_K,\n",
"
\"
block_K
\"
: block_K,\n",
" }\n",
" }\n",
" for block_M, block_N, block_K in product([32, 64], repeat=3)\n",
" for block_M, block_N, block_K in product([32, 64], repeat=3)\n",
" ]\n",
" ]\n",
"\n",
"\n",
"\n",
"gemm.par_compile(get_configs())"
"gemm.par_compile(get_configs())"
]
]
},
},
...
@@ -579,7 +581,8 @@
...
@@ -579,7 +581,8 @@
"source": [
"source": [
"@T.macro\n",
"@T.macro\n",
"def macro_with_ref(x: T.Ref):\n",
"def macro_with_ref(x: T.Ref):\n",
" x = 1 # noqa: F841\n",
" x = 1 # noqa: F841\n",
"\n",
"\n",
"\n",
"@T.prim_func\n",
"@T.prim_func\n",
"def foo(x: T.Tensor((2,))):\n",
"def foo(x: T.Tensor((2,))):\n",
...
@@ -591,6 +594,7 @@
...
@@ -591,6 +594,7 @@
" idx = T.alloc_var(T.int32, 0)\n",
" idx = T.alloc_var(T.int32, 0)\n",
" macro_with_ref(x[idx])\n",
" macro_with_ref(x[idx])\n",
"\n",
"\n",
"\n",
"foo"
"foo"
]
]
},
},
...
@@ -616,7 +620,7 @@
...
@@ -616,7 +620,7 @@
" A: T.Tensor[[T.dyn], Any],\n",
" A: T.Tensor[[T.dyn], Any],\n",
" fn,\n",
" fn,\n",
"):\n",
"):\n",
" N, = A.shape\n",
"
(
N,
)
= A.shape\n",
" B = T.empty((N,), dtype=A.dtype)\n",
" B = T.empty((N,), dtype=A.dtype)\n",
" block_N = 128\n",
" block_N = 128\n",
" with T.Kernel(T.ceildiv(N, block_N), threads=128) as bx:\n",
" with T.Kernel(T.ceildiv(N, block_N), threads=128) as bx:\n",
...
@@ -624,6 +628,8 @@
...
@@ -624,6 +628,8 @@
" idx = bx * block_N + i\n",
" idx = bx * block_N + i\n",
" B[idx] = fn(A[idx])\n",
" B[idx] = fn(A[idx])\n",
" return B\n",
" return B\n",
"\n",
"\n",
"@T.macro\n",
"@T.macro\n",
"def add_one(x):\n",
"def add_one(x):\n",
" return x + 1"
" return x + 1"
...
@@ -636,7 +642,7 @@
...
@@ -636,7 +642,7 @@
"metadata": {},
"metadata": {},
"outputs": [],
"outputs": [],
"source": [
"source": [
"A = torch.randn(1024, device=
'
cuda
'
)\n",
"A = torch.randn(1024, device=
\"
cuda
\"
)\n",
"B = element_wise(A, add_one)\n",
"B = element_wise(A, add_one)\n",
"B_ref = A + 1\n",
"B_ref = A + 1\n",
"torch.testing.assert_close(B, B_ref)"
"torch.testing.assert_close(B, B_ref)"
...
@@ -670,10 +676,11 @@
...
@@ -670,10 +676,11 @@
" var = var * 3 + 1\n",
" var = var * 3 + 1\n",
" n31(x * 3 + 1, var)\n",
" n31(x * 3 + 1, var)\n",
"\n",
"\n",
"\n",
"@tilelang.lazy_jit\n",
"@tilelang.lazy_jit\n",
"def foo(A: T.Tensor[[1], T.int32], n: int):\n",
"def foo(A: T.Tensor[[1], T.int32], n: int):\n",
" with T.Kernel(1) as _:\n",
" with T.Kernel(1) as _:\n",
" n31(n, A[0])
\n
"
" n31(n, A[0])"
]
]
},
},
{
{
...
@@ -694,7 +701,7 @@
...
@@ -694,7 +701,7 @@
}
}
],
],
"source": [
"source": [
"A = torch.tensor([100], dtype=torch.int32, device=
'
cuda
'
)\n",
"A = torch.tensor([100], dtype=torch.int32, device=
\"
cuda
\"
)\n",
"foo(A, 5)\n",
"foo(A, 5)\n",
"A"
"A"
]
]
...
@@ -745,12 +752,15 @@
...
@@ -745,12 +752,15 @@
"def sincos(x):\n",
"def sincos(x):\n",
" return T.sin(x), T.cos(x)\n",
" return T.sin(x), T.cos(x)\n",
"\n",
"\n",
"\n",
"@T.prim_func\n",
"@T.prim_func\n",
"def foo():\n",
"def foo():\n",
" with T.Kernel(32) as x:\n",
" with T.Kernel(32) as x:\n",
" s, c = sincos(x)\n",
" s, c = sincos(x)\n",
" a = s + c # noqa: F841\n",
" a = s + c # noqa: F841\n",
" b = s - c # noqa: F841\n",
" b = s - c # noqa: F841\n",
"\n",
"\n",
"foo"
"foo"
]
]
}
}
...
...
examples/lazy_jit/lazyjit.zh.ipynb
View file @
29051439
...
@@ -9,6 +9,7 @@
...
@@ -9,6 +9,7 @@
"source": [
"source": [
"import sys\n",
"import sys\n",
"from pathlib import Path\n",
"from pathlib import Path\n",
"\n",
"sys.path.insert(0, str(Path.cwd().parent.parent.absolute()))\n",
"sys.path.insert(0, str(Path.cwd().parent.parent.absolute()))\n",
"import tilelang\n",
"import tilelang\n",
"import torch\n",
"import torch\n",
...
@@ -61,7 +62,7 @@
...
@@ -61,7 +62,7 @@
" out_dtype: T.dtype = T.float32,\n",
" out_dtype: T.dtype = T.float32,\n",
" block_M: int = 128,\n",
" block_M: int = 128,\n",
" block_N: int = 128,\n",
" block_N: int = 128,\n",
" block_K: int = 32\n",
" block_K: int = 32
,
\n",
"):\n",
"):\n",
" M, K = A.shape\n",
" M, K = A.shape\n",
" K, N = B.shape\n",
" K, N = B.shape\n",
...
@@ -94,8 +95,8 @@
...
@@ -94,8 +95,8 @@
"metadata": {},
"metadata": {},
"outputs": [],
"outputs": [],
"source": [
"source": [
"A = torch.randn(1024, 512, dtype=torch.float16, device=
'
cuda
'
)\n",
"A = torch.randn(1024, 512, dtype=torch.float16, device=
\"
cuda
\"
)\n",
"B = torch.randn(512, 256, dtype=torch.float16, device=
'
cuda
'
)\n",
"B = torch.randn(512, 256, dtype=torch.float16, device=
\"
cuda
\"
)\n",
"C = gemm(A, B)\n",
"C = gemm(A, B)\n",
"\n",
"\n",
"# check output is correct\n",
"# check output is correct\n",
...
@@ -118,8 +119,8 @@
...
@@ -118,8 +119,8 @@
"metadata": {},
"metadata": {},
"outputs": [],
"outputs": [],
"source": [
"source": [
"A = torch.randn(1024, 512, dtype=torch.float16, device=
'
cuda
'
)\n",
"A = torch.randn(1024, 512, dtype=torch.float16, device=
\"
cuda
\"
)\n",
"B = torch.randn(512, 1024, dtype=torch.float16, device=
'
cuda
'
)\n",
"B = torch.randn(512, 1024, dtype=torch.float16, device=
\"
cuda
\"
)\n",
"C = gemm(A, B, block_M=64, block_N=64)"
"C = gemm(A, B, block_M=64, block_N=64)"
]
]
},
},
...
@@ -218,8 +219,8 @@
...
@@ -218,8 +219,8 @@
"source": [
"source": [
"@tilelang.lazy_jit\n",
"@tilelang.lazy_jit\n",
"def gemm_dyn_K(\n",
"def gemm_dyn_K(\n",
" A: T.Tensor[[int, T.dyn[
'K'
]], T.float16], # noqa: F821\n",
" A: T.Tensor[[int, T.dyn[
\"K\"
]], T.float16],
# noqa: F821\n",
" B: T.Tensor[[T.dyn[
'K'
], int], T.float16], # noqa: F821\n",
" B: T.Tensor[[T.dyn[
\"K\"
], int], T.float16],
# noqa: F821\n",
"):\n",
"):\n",
" M, K = A.shape\n",
" M, K = A.shape\n",
" K, N = B.shape\n",
" K, N = B.shape\n",
...
@@ -265,8 +266,8 @@
...
@@ -265,8 +266,8 @@
"metadata": {},
"metadata": {},
"outputs": [],
"outputs": [],
"source": [
"source": [
"A = torch.randn(1024, 512, dtype=torch.float16, device=
'
cuda
'
)\n",
"A = torch.randn(1024, 512, dtype=torch.float16, device=
\"
cuda
\"
)\n",
"B = torch.randn(512, 256, dtype=torch.float16, device=
'
cuda
'
)\n",
"B = torch.randn(512, 256, dtype=torch.float16, device=
\"
cuda
\"
)\n",
"C = gemm_dyn_K(A, B)\n",
"C = gemm_dyn_K(A, B)\n",
"C_ref = (A @ B).float()\n",
"C_ref = (A @ B).float()\n",
"torch.testing.assert_close(C, C_ref, rtol=1e-2, atol=1e-2)"
"torch.testing.assert_close(C, C_ref, rtol=1e-2, atol=1e-2)"
...
@@ -295,18 +296,17 @@
...
@@ -295,18 +296,17 @@
"source": [
"source": [
"from typing import Any\n",
"from typing import Any\n",
"\n",
"\n",
"\n",
"@tilelang.lazy_jit\n",
"@tilelang.lazy_jit\n",
"def as_contingious(\n",
"def as_contingious(A: T.StridedTensor[[T.dyn, T.dyn], [T.dyn, T.dyn], Any]):\n",
" A: T.StridedTensor[[T.dyn, T.dyn], [T.dyn, T.dyn], Any]\n",
"):\n",
" M, N = A.shape\n",
" M, N = A.shape\n",
" B = T.empty((M, N), A.dtype)\n",
" B = T.empty((M, N), A.dtype)\n",
" block_M = 128\n",
" block_M = 128\n",
" block_N = 128\n",
" block_N = 128\n",
" with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), threads=128) as (bx, by):\n",
" with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), threads=128) as (bx, by):\n",
" T.copy(\n",
" T.copy(\n",
" A[bx * block_M: (bx + 1) * block_M, by * block_N: (by + 1) * block_N],\n",
" A[bx * block_M
: (bx + 1) * block_M, by * block_N
: (by + 1) * block_N],\n",
" B[bx * block_M: (bx + 1) * block_M, by * block_N: (by + 1) * block_N]\n",
" B[bx * block_M
: (bx + 1) * block_M, by * block_N
: (by + 1) * block_N]
,
\n",
" )\n",
" )\n",
" return B"
" return B"
]
]
...
@@ -318,7 +318,7 @@
...
@@ -318,7 +318,7 @@
"metadata": {},
"metadata": {},
"outputs": [],
"outputs": [],
"source": [
"source": [
"A = torch.randn(1024, 1024, device=
'
cuda
'
)\n",
"A = torch.randn(1024, 1024, device=
\"
cuda
\"
)\n",
"B = as_contingious(A[::2, ::2])\n",
"B = as_contingious(A[::2, ::2])\n",
"B_ref = A[::2, ::2].contiguous()\n",
"B_ref = A[::2, ::2].contiguous()\n",
"torch.testing.assert_close(B, B_ref)"
"torch.testing.assert_close(B, B_ref)"
...
@@ -370,8 +370,8 @@
...
@@ -370,8 +370,8 @@
"metadata": {},
"metadata": {},
"outputs": [],
"outputs": [],
"source": [
"source": [
"A = torch.randn(1024, 512, dtype=torch.float16, device=
'
cuda
'
)\n",
"A = torch.randn(1024, 512, dtype=torch.float16, device=
\"
cuda
\"
)\n",
"B = torch.randn(512, 256, dtype=torch.float16, device=
'
cuda
'
)\n",
"B = torch.randn(512, 256, dtype=torch.float16, device=
\"
cuda
\"
)\n",
"C = gemm_ptr(A, B, 1024, 256, 512)\n",
"C = gemm_ptr(A, B, 1024, 256, 512)\n",
"C_ref = (A @ B).float()\n",
"C_ref = (A @ B).float()\n",
"torch.testing.assert_close(C, C_ref, rtol=1e-2, atol=1e-2)"
"torch.testing.assert_close(C, C_ref, rtol=1e-2, atol=1e-2)"
...
@@ -416,8 +416,8 @@
...
@@ -416,8 +416,8 @@
"metadata": {},
"metadata": {},
"outputs": [],
"outputs": [],
"source": [
"source": [
"A = torch.randn(1024, 512, dtype=torch.float16, device=
'
cuda
'
)\n",
"A = torch.randn(1024, 512, dtype=torch.float16, device=
\"
cuda
\"
)\n",
"B = torch.randn(512, 256, dtype=torch.float16, device=
'
cuda
'
)\n",
"B = torch.randn(512, 256, dtype=torch.float16, device=
\"
cuda
\"
)\n",
"C = gemm_ptr_dyn(A, B, 1024, 256, 512)\n",
"C = gemm_ptr_dyn(A, B, 1024, 256, 512)\n",
"C_ref = (A @ B).float()\n",
"C_ref = (A @ B).float()\n",
"torch.testing.assert_close(C, C_ref, rtol=1e-2, atol=1e-2)"
"torch.testing.assert_close(C, C_ref, rtol=1e-2, atol=1e-2)"
...
@@ -496,18 +496,20 @@
...
@@ -496,18 +496,20 @@
"source": [
"source": [
"from itertools import product\n",
"from itertools import product\n",
"\n",
"\n",
"\n",
"def get_configs():\n",
"def get_configs():\n",
" return [\n",
" return [\n",
" {\n",
" {\n",
"
'A'
: T.Tensor((1024, 1024), T.float32),\n",
"
\"A\"
: T.Tensor((1024, 1024), T.float32),\n",
"
'B'
: T.Tensor((1024, 1024), T.float32),\n",
"
\"B\"
: T.Tensor((1024, 1024), T.float32),\n",
"
'
block_M
'
: block_M,\n",
"
\"
block_M
\"
: block_M,\n",
"
'
block_N
'
: block_N,\n",
"
\"
block_N
\"
: block_N,\n",
"
'
block_K
'
: block_K,\n",
"
\"
block_K
\"
: block_K,\n",
" }\n",
" }\n",
" for block_M, block_N, block_K in product([32, 64], repeat=3)\n",
" for block_M, block_N, block_K in product([32, 64], repeat=3)\n",
" ]\n",
" ]\n",
"\n",
"\n",
"\n",
"gemm.par_compile(get_configs())"
"gemm.par_compile(get_configs())"
]
]
},
},
...
@@ -579,7 +581,8 @@
...
@@ -579,7 +581,8 @@
"source": [
"source": [
"@T.macro\n",
"@T.macro\n",
"def macro_with_ref(x: T.Ref):\n",
"def macro_with_ref(x: T.Ref):\n",
" x = 1 # noqa: F841\n",
" x = 1 # noqa: F841\n",
"\n",
"\n",
"\n",
"@T.prim_func\n",
"@T.prim_func\n",
"def foo(x: T.Tensor((2,))):\n",
"def foo(x: T.Tensor((2,))):\n",
...
@@ -591,6 +594,7 @@
...
@@ -591,6 +594,7 @@
" idx = T.alloc_var(T.int32, 0)\n",
" idx = T.alloc_var(T.int32, 0)\n",
" macro_with_ref(x[idx])\n",
" macro_with_ref(x[idx])\n",
"\n",
"\n",
"\n",
"foo"
"foo"
]
]
},
},
...
@@ -616,7 +620,7 @@
...
@@ -616,7 +620,7 @@
" A: T.Tensor[[T.dyn], Any],\n",
" A: T.Tensor[[T.dyn], Any],\n",
" fn,\n",
" fn,\n",
"):\n",
"):\n",
" N, = A.shape\n",
"
(
N,
)
= A.shape\n",
" B = T.empty((N,), dtype=A.dtype)\n",
" B = T.empty((N,), dtype=A.dtype)\n",
" block_N = 128\n",
" block_N = 128\n",
" with T.Kernel(T.ceildiv(N, block_N), threads=128) as bx:\n",
" with T.Kernel(T.ceildiv(N, block_N), threads=128) as bx:\n",
...
@@ -624,6 +628,8 @@
...
@@ -624,6 +628,8 @@
" idx = bx * block_N + i\n",
" idx = bx * block_N + i\n",
" B[idx] = fn(A[idx])\n",
" B[idx] = fn(A[idx])\n",
" return B\n",
" return B\n",
"\n",
"\n",
"@T.macro\n",
"@T.macro\n",
"def add_one(x):\n",
"def add_one(x):\n",
" return x + 1"
" return x + 1"
...
@@ -636,7 +642,7 @@
...
@@ -636,7 +642,7 @@
"metadata": {},
"metadata": {},
"outputs": [],
"outputs": [],
"source": [
"source": [
"A = torch.randn(1024, device=
'
cuda
'
)\n",
"A = torch.randn(1024, device=
\"
cuda
\"
)\n",
"B = element_wise(A, add_one)\n",
"B = element_wise(A, add_one)\n",
"B_ref = A + 1\n",
"B_ref = A + 1\n",
"torch.testing.assert_close(B, B_ref)"
"torch.testing.assert_close(B, B_ref)"
...
@@ -670,10 +676,11 @@
...
@@ -670,10 +676,11 @@
" var = var * 3 + 1\n",
" var = var * 3 + 1\n",
" n31(x * 3 + 1, var)\n",
" n31(x * 3 + 1, var)\n",
"\n",
"\n",
"\n",
"@tilelang.lazy_jit\n",
"@tilelang.lazy_jit\n",
"def foo(A: T.Tensor[[1], T.int32], n: int):\n",
"def foo(A: T.Tensor[[1], T.int32], n: int):\n",
" with T.Kernel(1) as _:\n",
" with T.Kernel(1) as _:\n",
" n31(n, A[0])
\n
"
" n31(n, A[0])"
]
]
},
},
{
{
...
@@ -694,7 +701,7 @@
...
@@ -694,7 +701,7 @@
}
}
],
],
"source": [
"source": [
"A = torch.tensor([100], dtype=torch.int32, device=
'
cuda
'
)\n",
"A = torch.tensor([100], dtype=torch.int32, device=
\"
cuda
\"
)\n",
"foo(A, 5)\n",
"foo(A, 5)\n",
"A"
"A"
]
]
...
@@ -745,12 +752,15 @@
...
@@ -745,12 +752,15 @@
"def sincos(x):\n",
"def sincos(x):\n",
" return T.sin(x), T.cos(x)\n",
" return T.sin(x), T.cos(x)\n",
"\n",
"\n",
"\n",
"@T.prim_func\n",
"@T.prim_func\n",
"def foo():\n",
"def foo():\n",
" with T.Kernel(32) as x:\n",
" with T.Kernel(32) as x:\n",
" s, c = sincos(x)\n",
" s, c = sincos(x)\n",
" a = s + c # noqa: F841\n",
" a = s + c # noqa: F841\n",
" b = s - c # noqa: F841\n",
" b = s - c # noqa: F841\n",
"\n",
"\n",
"foo"
"foo"
]
]
}
}
...
...
examples/linear_attention/example_linear_attn_bwd.py
View file @
29051439
...
@@ -13,20 +13,20 @@ from typing import Optional, Tuple
...
@@ -13,20 +13,20 @@ from typing import Optional, Tuple
pass_configs
=
{
pass_configs
=
{
tilelang
.
PassConfigKey
.
TL_DISABLE_TMA_LOWER
:
True
,
tilelang
.
PassConfigKey
.
TL_DISABLE_TMA_LOWER
:
True
,
tilelang
.
PassConfigKey
.
TL_DISABLE_WARP_SPECIALIZED
:
True
,
tilelang
.
PassConfigKey
.
TL_DISABLE_WARP_SPECIALIZED
:
True
,
})
}
)
def
tl_fused_chunk_bwd_kernel
(
def
tl_fused_chunk_bwd_kernel
(
B
,
B
,
S
,
S
,
H
,
H
,
DK
,
DK
,
DV
,
DV
,
dtype
:
str
=
'
float16
'
,
dtype
:
str
=
"
float16
"
,
scale
:
float
=
None
,
scale
:
float
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
if
scale
is
None
:
if
scale
is
None
:
scale
=
DK
**-
0.5
scale
=
DK
**-
0.5
accum_dtype
=
'
float
'
accum_dtype
=
"
float
"
chunk_size
=
64
chunk_size
=
64
BK
=
BV
=
64
# Set to 128 can be faster, but has some numerical differences with FLA
BK
=
BV
=
64
# Set to 128 can be faster, but has some numerical differences with FLA
...
@@ -37,13 +37,13 @@ def tl_fused_chunk_bwd_kernel(
...
@@ -37,13 +37,13 @@ def tl_fused_chunk_bwd_kernel(
@
T
.
prim_func
@
T
.
prim_func
def
fused_chunk_linear_attn_bwd
(
def
fused_chunk_linear_attn_bwd
(
Q
:
T
.
Tensor
([
B
,
S
,
H
,
DK
],
dtype
),
# type: ignore
Q
:
T
.
Tensor
([
B
,
S
,
H
,
DK
],
dtype
),
# type: ignore
K
:
T
.
Tensor
([
B
,
S
,
H
,
DK
],
dtype
),
# type: ignore
K
:
T
.
Tensor
([
B
,
S
,
H
,
DK
],
dtype
),
# type: ignore
V
:
T
.
Tensor
([
B
,
S
,
H
,
DV
],
dtype
),
# type: ignore
V
:
T
.
Tensor
([
B
,
S
,
H
,
DV
],
dtype
),
# type: ignore
dO
:
T
.
Tensor
([
B
,
S
,
H
,
DV
],
dtype
),
# type: ignore
dO
:
T
.
Tensor
([
B
,
S
,
H
,
DV
],
dtype
),
# type: ignore
dQ
:
T
.
Tensor
([
B
,
S
,
H
,
DK
],
accum_dtype
),
# type: ignore
dQ
:
T
.
Tensor
([
B
,
S
,
H
,
DK
],
accum_dtype
),
# type: ignore
dK
:
T
.
Tensor
([
B
,
S
,
H
,
DK
],
accum_dtype
),
# type: ignore
dK
:
T
.
Tensor
([
B
,
S
,
H
,
DK
],
accum_dtype
),
# type: ignore
dV
:
T
.
Tensor
([
B
,
S
,
H
,
DV
],
accum_dtype
),
# type: ignore
dV
:
T
.
Tensor
([
B
,
S
,
H
,
DV
],
accum_dtype
),
# type: ignore
):
):
with
T
.
Kernel
(
NV
,
NK
,
B
*
H
)
as
(
i_v
,
i_k
,
i_bh
):
with
T
.
Kernel
(
NV
,
NK
,
B
*
H
)
as
(
i_v
,
i_k
,
i_bh
):
i_b
=
i_bh
//
H
i_b
=
i_bh
//
H
...
@@ -66,11 +66,13 @@ def tl_fused_chunk_bwd_kernel(
...
@@ -66,11 +66,13 @@ def tl_fused_chunk_bwd_kernel(
dh
=
T
.
alloc_fragment
([
BK
,
BV
],
accum_dtype
)
dh
=
T
.
alloc_fragment
([
BK
,
BV
],
accum_dtype
)
dh_shared
=
T
.
alloc_shared
([
BK
,
BV
],
dtype
)
dh_shared
=
T
.
alloc_shared
([
BK
,
BV
],
dtype
)
T
.
annotate_layout
({
T
.
annotate_layout
(
dq_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
dq_shared
),
{
dk_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
dk_shared
),
dq_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
dq_shared
),
dv_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
dv_shared
)
dk_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
dk_shared
),
})
dv_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
dv_shared
),
}
)
T
.
use_swizzle
(
10
)
T
.
use_swizzle
(
10
)
T
.
clear
(
h
)
T
.
clear
(
h
)
...
@@ -78,10 +80,9 @@ def tl_fused_chunk_bwd_kernel(
...
@@ -78,10 +80,9 @@ def tl_fused_chunk_bwd_kernel(
# Calculate dQ
# Calculate dQ
for
i
in
T
.
Pipelined
(
0
,
NT
):
for
i
in
T
.
Pipelined
(
0
,
NT
):
T
.
copy
(
K
[
i_b
,
i
*
chunk_size
:(
i
+
1
)
*
chunk_size
,
i_h
,
i_k
*
BK
:(
i_k
+
1
)
*
BK
],
k
)
T
.
copy
(
K
[
i_b
,
i
*
chunk_size
:
(
i
+
1
)
*
chunk_size
,
i_h
,
i_k
*
BK
:
(
i_k
+
1
)
*
BK
],
k
)
T
.
copy
(
V
[
i_b
,
i
*
chunk_size
:(
i
+
1
)
*
chunk_size
,
i_h
,
i_v
*
BV
:(
i_v
+
1
)
*
BV
],
v
)
T
.
copy
(
V
[
i_b
,
i
*
chunk_size
:
(
i
+
1
)
*
chunk_size
,
i_h
,
i_v
*
BV
:
(
i_v
+
1
)
*
BV
],
v
)
T
.
copy
(
dO
[
i_b
,
i
*
chunk_size
:(
i
+
1
)
*
chunk_size
,
i_h
,
i_v
*
BV
:(
i_v
+
1
)
*
BV
],
T
.
copy
(
dO
[
i_b
,
i
*
chunk_size
:
(
i
+
1
)
*
chunk_size
,
i_h
,
i_v
*
BV
:
(
i_v
+
1
)
*
BV
],
do
)
do
)
T
.
gemm
(
do
,
v
,
ds
,
transpose_B
=
True
,
clear_accum
=
True
)
T
.
gemm
(
do
,
v
,
ds
,
transpose_B
=
True
,
clear_accum
=
True
)
for
row
,
col
in
T
.
Parallel
(
chunk_size
,
chunk_size
):
for
row
,
col
in
T
.
Parallel
(
chunk_size
,
chunk_size
):
...
@@ -94,29 +95,19 @@ def tl_fused_chunk_bwd_kernel(
...
@@ -94,29 +95,19 @@ def tl_fused_chunk_bwd_kernel(
for
row
,
col
in
T
.
Parallel
(
chunk_size
,
BK
):
for
row
,
col
in
T
.
Parallel
(
chunk_size
,
BK
):
dq
[
row
,
col
]
*=
scale
dq
[
row
,
col
]
*=
scale
T
.
copy
(
dq
,
dq_shared
)
T
.
copy
(
dq
,
dq_shared
)
T
.
atomic_add
(
T
.
atomic_add
(
dQ
[
i_b
,
i
*
chunk_size
:
(
i
+
1
)
*
chunk_size
,
i_h
,
i_k
*
BK
:
(
i_k
+
1
)
*
BK
],
dq_shared
)
dQ
[
i_b
,
i
*
chunk_size
:(
i
+
1
)
*
chunk_size
,
i_h
,
i_k
*
BK
:(
i_k
+
1
)
*
BK
],
dq_shared
)
# Calculate dK, dV (reversely)
# Calculate dK, dV (reversely)
for
i
in
T
.
Pipelined
(
1
,
NT
+
1
):
for
i
in
T
.
Pipelined
(
1
,
NT
+
1
):
start
=
NT
-
i
start
=
NT
-
i
for
row
,
col
in
T
.
Parallel
(
chunk_size
,
BK
):
for
row
,
col
in
T
.
Parallel
(
chunk_size
,
BK
):
q
[
row
,
col
]
=
Q
[
i_b
,
start
*
chunk_size
+
row
,
i_h
,
i_k
*
BK
+
col
]
*
scale
q
[
row
,
col
]
=
Q
[
i_b
,
start
*
chunk_size
+
row
,
i_h
,
i_k
*
BK
+
col
]
*
scale
T
.
copy
(
T
.
copy
(
K
[
i_b
,
start
*
chunk_size
:
(
start
+
1
)
*
chunk_size
,
i_h
,
i_k
*
BK
:
(
i_k
+
1
)
*
BK
],
k
)
K
[
i_b
,
start
*
chunk_size
:(
start
+
1
)
*
chunk_size
,
i_h
,
T
.
copy
(
V
[
i_b
,
start
*
chunk_size
:
(
start
+
1
)
*
chunk_size
,
i_h
,
i_v
*
BV
:
(
i_v
+
1
)
*
BV
],
v
)
i_k
*
BK
:(
i_k
+
1
)
*
BK
],
k
)
T
.
copy
(
dO
[
i_b
,
start
*
chunk_size
:
(
start
+
1
)
*
chunk_size
,
i_h
,
i_v
*
BV
:
(
i_v
+
1
)
*
BV
],
do
)
T
.
copy
(
V
[
i_b
,
start
*
chunk_size
:(
start
+
1
)
*
chunk_size
,
i_h
,
i_v
*
BV
:(
i_v
+
1
)
*
BV
],
v
)
T
.
copy
(
dO
[
i_b
,
start
*
chunk_size
:(
start
+
1
)
*
chunk_size
,
i_h
,
i_v
*
BV
:(
i_v
+
1
)
*
BV
],
do
)
# Calculate dk
# Calculate dk
T
.
gemm
(
T
.
gemm
(
v
,
do
,
ds
,
transpose_B
=
True
,
clear_accum
=
True
)
# ds here actually means `s`, but we simply reuse the buffer `ds`
v
,
do
,
ds
,
transpose_B
=
True
,
clear_accum
=
True
)
# ds here actually means `s`, but we simply reuse the buffer `ds`
for
row
,
col
in
T
.
Parallel
(
chunk_size
,
chunk_size
):
for
row
,
col
in
T
.
Parallel
(
chunk_size
,
chunk_size
):
ds_shared
[
row
,
col
]
=
T
.
if_then_else
(
row
<=
col
,
ds
[
row
,
col
],
0
)
ds_shared
[
row
,
col
]
=
T
.
if_then_else
(
row
<=
col
,
ds
[
row
,
col
],
0
)
T
.
gemm
(
ds_shared
,
q
,
dk
,
clear_accum
=
True
)
T
.
gemm
(
ds_shared
,
q
,
dk
,
clear_accum
=
True
)
...
@@ -134,13 +125,9 @@ def tl_fused_chunk_bwd_kernel(
...
@@ -134,13 +125,9 @@ def tl_fused_chunk_bwd_kernel(
T
.
gemm
(
q
,
do
,
dh
,
transpose_A
=
True
)
T
.
gemm
(
q
,
do
,
dh
,
transpose_A
=
True
)
T
.
copy
(
dk
,
dk_shared
)
T
.
copy
(
dk
,
dk_shared
)
T
.
atomic_add
(
T
.
atomic_add
(
dK
[
i_b
,
start
*
chunk_size
:
(
start
+
1
)
*
chunk_size
,
i_h
,
i_k
*
BK
:
(
i_k
+
1
)
*
BK
],
dk_shared
)
dK
[
i_b
,
start
*
chunk_size
:(
start
+
1
)
*
chunk_size
,
i_h
,
i_k
*
BK
:(
i_k
+
1
)
*
BK
],
dk_shared
)
T
.
copy
(
dv
,
dv_shared
)
T
.
copy
(
dv
,
dv_shared
)
T
.
atomic_add
(
T
.
atomic_add
(
dV
[
i_b
,
start
*
chunk_size
:
(
start
+
1
)
*
chunk_size
,
i_h
,
i_v
*
BV
:
(
i_v
+
1
)
*
BV
],
dv_shared
)
dV
[
i_b
,
start
*
chunk_size
:(
start
+
1
)
*
chunk_size
,
i_h
,
i_v
*
BV
:(
i_v
+
1
)
*
BV
],
dv_shared
)
return
fused_chunk_linear_attn_bwd
return
fused_chunk_linear_attn_bwd
...
@@ -155,34 +142,31 @@ def tl_fused_chunk_bwd(Q, K, V, dO):
...
@@ -155,34 +142,31 @@ def tl_fused_chunk_bwd(Q, K, V, dO):
return
dQ
.
to
(
torch
.
float16
),
dK
.
to
(
torch
.
float16
),
dV
.
to
(
torch
.
float16
)
return
dQ
.
to
(
torch
.
float16
),
dK
.
to
(
torch
.
float16
),
dV
.
to
(
torch
.
float16
)
def
ref_program
(
q
:
torch
.
Tensor
,
def
ref_program
(
q
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
v
:
torch
.
Tensor
,
scale
:
Optional
[
float
]
=
None
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
k
:
torch
.
Tensor
,
v
:
torch
.
Tensor
,
scale
:
Optional
[
float
]
=
None
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
q
,
k
,
v
=
q
.
float
(),
k
.
float
(),
v
.
float
()
q
,
k
,
v
=
q
.
float
(),
k
.
float
(),
v
.
float
()
if
scale
is
None
:
if
scale
is
None
:
scale
=
q
.
shape
[
-
1
]
**-
0.5
scale
=
q
.
shape
[
-
1
]
**
-
0.5
chunk_size
=
64
chunk_size
=
64
q
=
rearrange
(
q
,
'
b (n c) h d -> b h n c d
'
,
c
=
chunk_size
)
*
scale
q
=
rearrange
(
q
,
"
b (n c) h d -> b h n c d
"
,
c
=
chunk_size
)
*
scale
k
=
rearrange
(
k
,
'
b (n c) h d -> b h n c d
'
,
c
=
chunk_size
)
k
=
rearrange
(
k
,
"
b (n c) h d -> b h n c d
"
,
c
=
chunk_size
)
v
=
rearrange
(
v
,
'
b (n c) h d -> b h n c d
'
,
c
=
chunk_size
)
v
=
rearrange
(
v
,
"
b (n c) h d -> b h n c d
"
,
c
=
chunk_size
)
kv
=
k
.
transpose
(
-
1
,
-
2
)
@
v
kv
=
k
.
transpose
(
-
1
,
-
2
)
@
v
kv
=
kv
.
cumsum
(
2
)
kv
=
kv
.
cumsum
(
2
)
h
=
kv
[:,
:,
-
1
,
:,
:]
h
=
kv
[:,
:,
-
1
,
:,
:]
kv
=
torch
.
cat
([
torch
.
zeros_like
(
kv
[:,
:,
:
1
]),
kv
[:,
:,
:
-
1
]],
dim
=
2
)
kv
=
torch
.
cat
([
torch
.
zeros_like
(
kv
[:,
:,
:
1
]),
kv
[:,
:,
:
-
1
]],
dim
=
2
)
inter
=
q
@
kv
inter
=
q
@
kv
intra
=
(
(
q
@
k
.
transpose
(
-
1
,
-
2
)).
masked_fill_
(
intra
=
(
torch
.
triu
(
torch
.
ones
(
chunk_size
,
chunk_size
,
dtype
=
bool
,
device
=
q
.
device
),
diagonal
=
1
),
(
q
@
k
.
transpose
(
-
1
,
-
2
)).
masked_fill_
(
torch
.
triu
(
torch
.
ones
(
chunk_size
,
chunk_size
,
dtype
=
bool
,
device
=
q
.
device
),
diagonal
=
1
),
0
)
0
)
)
@
v
)
@
v
o
=
inter
+
intra
o
=
inter
+
intra
return
rearrange
(
o
,
'
b h n c d -> b (n c) h d
'
),
h
return
rearrange
(
o
,
"
b h n c d -> b (n c) h d
"
),
h
def
main
(
B
=
1
,
S
=
1024
,
H
=
16
,
D
=
128
):
def
main
(
B
=
1
,
S
=
1024
,
H
=
16
,
D
=
128
):
q
=
torch
.
randn
((
B
,
S
,
H
,
D
),
device
=
'
cuda
'
,
dtype
=
torch
.
float16
,
requires_grad
=
True
)
q
=
torch
.
randn
((
B
,
S
,
H
,
D
),
device
=
"
cuda
"
,
dtype
=
torch
.
float16
,
requires_grad
=
True
)
k
=
torch
.
randn
((
B
,
S
,
H
,
D
),
device
=
'
cuda
'
,
dtype
=
torch
.
float16
,
requires_grad
=
True
)
k
=
torch
.
randn
((
B
,
S
,
H
,
D
),
device
=
"
cuda
"
,
dtype
=
torch
.
float16
,
requires_grad
=
True
)
v
=
torch
.
randn
((
B
,
S
,
H
,
D
),
device
=
'
cuda
'
,
dtype
=
torch
.
float16
,
requires_grad
=
True
)
v
=
torch
.
randn
((
B
,
S
,
H
,
D
),
device
=
"
cuda
"
,
dtype
=
torch
.
float16
,
requires_grad
=
True
)
do
=
torch
.
randn
((
B
,
S
,
H
,
D
),
device
=
'
cuda
'
,
dtype
=
torch
.
float16
)
do
=
torch
.
randn
((
B
,
S
,
H
,
D
),
device
=
"
cuda
"
,
dtype
=
torch
.
float16
)
# qk norm is necessary for linear attn
# qk norm is necessary for linear attn
q
=
l2norm_fwd
(
q
)[
0
].
requires_grad_
(
True
)
q
=
l2norm_fwd
(
q
)[
0
].
requires_grad_
(
True
)
...
@@ -193,30 +177,27 @@ def main(B=1, S=1024, H=16, D=128):
...
@@ -193,30 +177,27 @@ def main(B=1, S=1024, H=16, D=128):
o_ref
,
_
=
ref_program
(
q
,
k
,
v
)
o_ref
,
_
=
ref_program
(
q
,
k
,
v
)
o_ref
.
backward
(
do
,
retain_graph
=
True
)
o_ref
.
backward
(
do
,
retain_graph
=
True
)
assert
torch
.
allclose
(
assert
torch
.
allclose
(
dq
,
q
.
grad
,
atol
=
1e-2
,
rtol
=
1e-2
),
f
"dq max err:
{
(
dq
-
q
.
grad
).
abs
().
max
()
}
"
dq
,
q
.
grad
,
atol
=
1e-2
,
rtol
=
1e-2
),
f
'dq max err:
{
(
dq
-
q
.
grad
).
abs
().
max
()
}
'
assert
torch
.
allclose
(
dk
,
k
.
grad
,
atol
=
1e-2
,
rtol
=
1e-2
),
f
"dk max err:
{
(
dk
-
k
.
grad
).
abs
().
max
()
}
"
assert
torch
.
allclose
(
assert
torch
.
allclose
(
dv
,
v
.
grad
,
atol
=
1e-2
,
rtol
=
1e-2
),
f
"dv max err:
{
(
dv
-
v
.
grad
).
abs
().
max
()
}
"
dk
,
k
.
grad
,
atol
=
1e-2
,
rtol
=
1e-2
),
f
'dk max err:
{
(
dk
-
k
.
grad
).
abs
().
max
()
}
'
print
(
"Passed all tests!✅"
)
assert
torch
.
allclose
(
dv
,
v
.
grad
,
atol
=
1e-2
,
rtol
=
1e-2
),
f
'dv max err:
{
(
dv
-
v
.
grad
).
abs
().
max
()
}
'
print
(
'Passed all tests!✅'
)
# Benchmark
# Benchmark
q
.
grad
=
k
.
grad
=
v
.
grad
=
None
q
.
grad
=
k
.
grad
=
v
.
grad
=
None
o_ref
,
_
=
fused_chunk_linear_attn
(
q
,
k
,
v
,
output_final_state
=
True
,
normalize
=
False
)
o_ref
,
_
=
fused_chunk_linear_attn
(
q
,
k
,
v
,
output_final_state
=
True
,
normalize
=
False
)
t1
=
do_bench
(
lambda
:
o_ref
.
backward
(
do
,
retain_graph
=
True
),
backend
=
'
cupti
'
)
t1
=
do_bench
(
lambda
:
o_ref
.
backward
(
do
,
retain_graph
=
True
),
backend
=
"
cupti
"
)
t2
=
do_bench
(
lambda
:
tl_fused_chunk_bwd
(
q
,
k
,
v
,
do
),
backend
=
'
cupti
'
)
t2
=
do_bench
(
lambda
:
tl_fused_chunk_bwd
(
q
,
k
,
v
,
do
),
backend
=
"
cupti
"
)
print
(
f
'
Triton latency:
{
t1
:.
3
f
}
ms
'
)
print
(
f
"
Triton latency:
{
t1
:.
3
f
}
ms
"
)
print
(
f
'
TileLang latency:
{
t2
:.
3
f
}
ms
'
)
print
(
f
"
TileLang latency:
{
t2
:.
3
f
}
ms
"
)
print
(
f
'
Speedup:
{
t1
/
t2
:.
3
f
}
x
'
)
print
(
f
"
Speedup:
{
t1
/
t2
:.
3
f
}
x
"
)
if
__name__
==
'
__main__
'
:
if
__name__
==
"
__main__
"
:
parser
=
argparse
.
ArgumentParser
()
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
'
--B
'
,
type
=
int
,
default
=
8
,
help
=
'
Batch size
'
)
parser
.
add_argument
(
"
--B
"
,
type
=
int
,
default
=
8
,
help
=
"
Batch size
"
)
parser
.
add_argument
(
'
--S
'
,
type
=
int
,
default
=
1024
,
help
=
'
Seq len
'
)
parser
.
add_argument
(
"
--S
"
,
type
=
int
,
default
=
1024
,
help
=
"
Seq len
"
)
parser
.
add_argument
(
'
--H
'
,
type
=
int
,
default
=
32
,
help
=
'
Num heads
'
)
parser
.
add_argument
(
"
--H
"
,
type
=
int
,
default
=
32
,
help
=
"
Num heads
"
)
parser
.
add_argument
(
'
--D
'
,
type
=
int
,
default
=
128
,
help
=
'
Head dim
'
)
parser
.
add_argument
(
"
--D
"
,
type
=
int
,
default
=
128
,
help
=
"
Head dim
"
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
main
(
args
.
B
,
args
.
S
,
args
.
H
,
args
.
D
)
main
(
args
.
B
,
args
.
S
,
args
.
H
,
args
.
D
)
examples/linear_attention/example_linear_attn_fwd.py
View file @
29051439
...
@@ -14,20 +14,20 @@ from typing import Optional, Tuple
...
@@ -14,20 +14,20 @@ from typing import Optional, Tuple
pass_configs
=
{
pass_configs
=
{
tilelang
.
PassConfigKey
.
TL_DISABLE_TMA_LOWER
:
True
,
tilelang
.
PassConfigKey
.
TL_DISABLE_TMA_LOWER
:
True
,
tilelang
.
PassConfigKey
.
TL_DISABLE_WARP_SPECIALIZED
:
True
,
tilelang
.
PassConfigKey
.
TL_DISABLE_WARP_SPECIALIZED
:
True
,
})
},
)
def
tl_fused_chunk_fwd_kernel
(
def
tl_fused_chunk_fwd_kernel
(
B
,
B
,
S
,
S
,
H
,
H
,
DK
,
DK
,
DV
,
DV
,
dtype
:
str
=
'
float16
'
,
dtype
:
str
=
"
float16
"
,
scale
:
float
=
None
,
scale
:
float
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
if
scale
is
None
:
if
scale
is
None
:
scale
=
DK
**-
0.5
scale
=
DK
**-
0.5
accum_dtype
=
'
float
'
accum_dtype
=
"
float
"
chunk_size
=
64
chunk_size
=
64
BK
=
BV
=
64
# Set to 128 can be faster, but has some numerical differences with FLA
BK
=
BV
=
64
# Set to 128 can be faster, but has some numerical differences with FLA
...
@@ -38,11 +38,12 @@ def tl_fused_chunk_fwd_kernel(
...
@@ -38,11 +38,12 @@ def tl_fused_chunk_fwd_kernel(
@
T
.
prim_func
@
T
.
prim_func
def
fused_chunk_linear_attn_fwd
(
def
fused_chunk_linear_attn_fwd
(
Q
:
T
.
Tensor
([
B
,
S
,
H
,
DK
],
dtype
),
# type: ignore
Q
:
T
.
Tensor
([
B
,
S
,
H
,
DK
],
dtype
),
# type: ignore
K
:
T
.
Tensor
([
B
,
S
,
H
,
DK
],
dtype
),
# type: ignore
K
:
T
.
Tensor
([
B
,
S
,
H
,
DK
],
dtype
),
# type: ignore
V
:
T
.
Tensor
([
B
,
S
,
H
,
DV
],
dtype
),
# type: ignore
V
:
T
.
Tensor
([
B
,
S
,
H
,
DV
],
dtype
),
# type: ignore
O
:
T
.
Tensor
([
B
,
S
,
H
,
DV
],
accum_dtype
),
# type: ignore
O
:
T
.
Tensor
([
B
,
S
,
H
,
DV
],
accum_dtype
),
# type: ignore
final_state
:
T
.
Tensor
([
B
,
H
,
DK
,
DV
],
accum_dtype
)):
# type: ignore
final_state
:
T
.
Tensor
([
B
,
H
,
DK
,
DV
],
accum_dtype
),
):
# type: ignore
with
T
.
Kernel
(
NV
,
NK
,
B
*
H
)
as
(
i_v
,
i_k
,
i_bh
):
with
T
.
Kernel
(
NV
,
NK
,
B
*
H
)
as
(
i_v
,
i_k
,
i_bh
):
i_b
=
i_bh
//
H
i_b
=
i_bh
//
H
i_h
=
i_bh
%
H
i_h
=
i_bh
%
H
...
@@ -65,8 +66,8 @@ def tl_fused_chunk_fwd_kernel(
...
@@ -65,8 +66,8 @@ def tl_fused_chunk_fwd_kernel(
for
i
in
T
.
Pipelined
(
0
,
NT
):
for
i
in
T
.
Pipelined
(
0
,
NT
):
for
row
,
col
in
T
.
Parallel
(
chunk_size
,
BK
):
for
row
,
col
in
T
.
Parallel
(
chunk_size
,
BK
):
q
[
row
,
col
]
=
Q
[
i_b
,
i
*
chunk_size
+
row
,
i_h
,
i_k
*
BK
+
col
]
*
scale
q
[
row
,
col
]
=
Q
[
i_b
,
i
*
chunk_size
+
row
,
i_h
,
i_k
*
BK
+
col
]
*
scale
T
.
copy
(
K
[
i_b
,
i
*
chunk_size
:
(
i
+
1
)
*
chunk_size
,
i_h
,
i_k
*
BK
:
(
i_k
+
1
)
*
BK
],
k
)
T
.
copy
(
K
[
i_b
,
i
*
chunk_size
:
(
i
+
1
)
*
chunk_size
,
i_h
,
i_k
*
BK
:
(
i_k
+
1
)
*
BK
],
k
)
T
.
copy
(
V
[
i_b
,
i
*
chunk_size
:
(
i
+
1
)
*
chunk_size
,
i_h
,
i_v
*
BV
:
(
i_v
+
1
)
*
BV
],
v
)
T
.
copy
(
V
[
i_b
,
i
*
chunk_size
:
(
i
+
1
)
*
chunk_size
,
i_h
,
i_v
*
BV
:
(
i_v
+
1
)
*
BV
],
v
)
T
.
gemm
(
q
,
k
,
s
,
clear_accum
=
True
,
transpose_B
=
True
)
T
.
gemm
(
q
,
k
,
s
,
clear_accum
=
True
,
transpose_B
=
True
)
for
row
,
col
in
T
.
Parallel
(
chunk_size
,
chunk_size
):
for
row
,
col
in
T
.
Parallel
(
chunk_size
,
chunk_size
):
...
@@ -77,12 +78,10 @@ def tl_fused_chunk_fwd_kernel(
...
@@ -77,12 +78,10 @@ def tl_fused_chunk_fwd_kernel(
T
.
gemm
(
k
,
v
,
h
,
transpose_A
=
True
)
T
.
gemm
(
k
,
v
,
h
,
transpose_A
=
True
)
T
.
gemm
(
q
,
h_shared
,
o
)
T
.
gemm
(
q
,
h_shared
,
o
)
T
.
copy
(
o
,
o_shared
)
T
.
copy
(
o
,
o_shared
)
T
.
atomic_add
(
T
.
atomic_add
(
O
[
i_b
,
i
*
chunk_size
:
(
i
+
1
)
*
chunk_size
,
i_h
,
i_v
*
BV
:
(
i_v
+
1
)
*
BV
],
o_shared
)
O
[
i_b
,
i
*
chunk_size
:(
i
+
1
)
*
chunk_size
,
i_h
,
i_v
*
BV
:(
i_v
+
1
)
*
BV
],
o_shared
)
# Output final state
# Output final state
T
.
copy
(
h
,
final_state
[
i_b
,
i_h
,
i_k
*
BK
:
(
i_k
+
1
)
*
BK
,
i_v
*
BV
:
(
i_v
+
1
)
*
BV
])
T
.
copy
(
h
,
final_state
[
i_b
,
i_h
,
i_k
*
BK
:
(
i_k
+
1
)
*
BK
,
i_v
*
BV
:
(
i_v
+
1
)
*
BV
])
return
fused_chunk_linear_attn_fwd
return
fused_chunk_linear_attn_fwd
...
@@ -91,38 +90,35 @@ def tl_fused_chunk_fwd(q, k, v):
...
@@ -91,38 +90,35 @@ def tl_fused_chunk_fwd(q, k, v):
B
,
S
,
H
,
D
=
q
.
shape
B
,
S
,
H
,
D
=
q
.
shape
kernel
=
tl_fused_chunk_fwd_kernel
(
B
,
S
,
H
,
D
,
D
)
kernel
=
tl_fused_chunk_fwd_kernel
(
B
,
S
,
H
,
D
,
D
)
print
(
kernel
.
get_kernel_source
())
print
(
kernel
.
get_kernel_source
())
o
=
torch
.
zeros
((
B
,
S
,
H
,
D
),
device
=
'
cuda
'
,
dtype
=
torch
.
float32
)
o
=
torch
.
zeros
((
B
,
S
,
H
,
D
),
device
=
"
cuda
"
,
dtype
=
torch
.
float32
)
h
=
kernel
(
q
,
k
,
v
,
o
)
h
=
kernel
(
q
,
k
,
v
,
o
)
return
o
,
h
return
o
,
h
def
ref_program
(
q
:
torch
.
Tensor
,
def
ref_program
(
q
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
v
:
torch
.
Tensor
,
scale
:
Optional
[
float
]
=
None
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
k
:
torch
.
Tensor
,
v
:
torch
.
Tensor
,
scale
:
Optional
[
float
]
=
None
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
q
,
k
,
v
=
q
.
float
(),
k
.
float
(),
v
.
float
()
q
,
k
,
v
=
q
.
float
(),
k
.
float
(),
v
.
float
()
if
scale
is
None
:
if
scale
is
None
:
scale
=
q
.
shape
[
-
1
]
**-
0.5
scale
=
q
.
shape
[
-
1
]
**
-
0.5
chunk_size
=
64
chunk_size
=
64
q
=
rearrange
(
q
,
'
b (n c) h d -> b h n c d
'
,
c
=
chunk_size
)
*
scale
q
=
rearrange
(
q
,
"
b (n c) h d -> b h n c d
"
,
c
=
chunk_size
)
*
scale
k
=
rearrange
(
k
,
'
b (n c) h d -> b h n c d
'
,
c
=
chunk_size
)
k
=
rearrange
(
k
,
"
b (n c) h d -> b h n c d
"
,
c
=
chunk_size
)
v
=
rearrange
(
v
,
'
b (n c) h d -> b h n c d
'
,
c
=
chunk_size
)
v
=
rearrange
(
v
,
"
b (n c) h d -> b h n c d
"
,
c
=
chunk_size
)
kv
=
k
.
transpose
(
-
1
,
-
2
)
@
v
kv
=
k
.
transpose
(
-
1
,
-
2
)
@
v
kv
=
kv
.
cumsum
(
2
)
kv
=
kv
.
cumsum
(
2
)
h
=
kv
[:,
:,
-
1
,
:,
:]
h
=
kv
[:,
:,
-
1
,
:,
:]
kv
=
torch
.
cat
([
torch
.
zeros_like
(
kv
[:,
:,
:
1
]),
kv
[:,
:,
:
-
1
]],
dim
=
2
)
kv
=
torch
.
cat
([
torch
.
zeros_like
(
kv
[:,
:,
:
1
]),
kv
[:,
:,
:
-
1
]],
dim
=
2
)
inter
=
q
@
kv
inter
=
q
@
kv
intra
=
(
(
q
@
k
.
transpose
(
-
1
,
-
2
)).
masked_fill_
(
intra
=
(
torch
.
triu
(
torch
.
ones
(
chunk_size
,
chunk_size
,
dtype
=
bool
,
device
=
q
.
device
),
diagonal
=
1
),
(
q
@
k
.
transpose
(
-
1
,
-
2
)).
masked_fill_
(
torch
.
triu
(
torch
.
ones
(
chunk_size
,
chunk_size
,
dtype
=
bool
,
device
=
q
.
device
),
diagonal
=
1
),
0
)
0
)
)
@
v
)
@
v
o
=
inter
+
intra
o
=
inter
+
intra
return
rearrange
(
o
,
'
b h n c d -> b (n c) h d
'
),
h
return
rearrange
(
o
,
"
b h n c d -> b (n c) h d
"
),
h
def
main
(
B
=
1
,
S
=
512
,
H
=
16
,
D
=
128
):
def
main
(
B
=
1
,
S
=
512
,
H
=
16
,
D
=
128
):
q
=
torch
.
randn
((
B
,
S
,
H
,
D
),
device
=
'
cuda
'
,
dtype
=
torch
.
float16
)
q
=
torch
.
randn
((
B
,
S
,
H
,
D
),
device
=
"
cuda
"
,
dtype
=
torch
.
float16
)
k
=
torch
.
randn
((
B
,
S
,
H
,
D
),
device
=
'
cuda
'
,
dtype
=
torch
.
float16
)
k
=
torch
.
randn
((
B
,
S
,
H
,
D
),
device
=
"
cuda
"
,
dtype
=
torch
.
float16
)
v
=
torch
.
randn
((
B
,
S
,
H
,
D
),
device
=
'
cuda
'
,
dtype
=
torch
.
float16
)
v
=
torch
.
randn
((
B
,
S
,
H
,
D
),
device
=
"
cuda
"
,
dtype
=
torch
.
float16
)
# qk norm is necessary for linear attn
# qk norm is necessary for linear attn
q
,
_
=
l2norm_fwd
(
q
)
q
,
_
=
l2norm_fwd
(
q
)
...
@@ -131,25 +127,23 @@ def main(B=1, S=512, H=16, D=128):
...
@@ -131,25 +127,23 @@ def main(B=1, S=512, H=16, D=128):
o
,
h
=
tl_fused_chunk_fwd
(
q
,
k
,
v
)
o
,
h
=
tl_fused_chunk_fwd
(
q
,
k
,
v
)
o_ref
,
h_ref
=
ref_program
(
q
,
k
,
v
)
o_ref
,
h_ref
=
ref_program
(
q
,
k
,
v
)
assert
torch
.
allclose
(
o
,
o_ref
,
atol
=
1e-2
,
rtol
=
1e-2
),
f
'
o max err:
{
(
o
-
o_ref
).
abs
().
max
()
}
'
assert
torch
.
allclose
(
o
,
o_ref
,
atol
=
1e-2
,
rtol
=
1e-2
),
f
"
o max err:
{
(
o
-
o_ref
).
abs
().
max
()
}
"
assert
torch
.
allclose
(
h
,
h_ref
,
atol
=
1e-2
,
rtol
=
1e-2
),
f
'
h max err:
{
(
h
-
h_ref
).
abs
().
max
()
}
'
assert
torch
.
allclose
(
h
,
h_ref
,
atol
=
1e-2
,
rtol
=
1e-2
),
f
"
h max err:
{
(
h
-
h_ref
).
abs
().
max
()
}
"
print
(
'
Passed all tests!✅
'
)
print
(
"
Passed all tests!✅
"
)
t1
=
do_bench
(
t1
=
do_bench
(
lambda
:
fused_chunk_linear_attn
(
q
,
k
,
v
,
output_final_state
=
True
,
normalize
=
False
),
backend
=
"cupti"
)
lambda
:
fused_chunk_linear_attn
(
q
,
k
,
v
,
output_final_state
=
True
,
normalize
=
False
),
t2
=
do_bench
(
lambda
:
tl_fused_chunk_fwd
(
q
,
k
,
v
),
backend
=
"cupti"
)
backend
=
'cupti'
)
print
(
f
"Triton latency:
{
t1
:.
3
f
}
ms"
)
t2
=
do_bench
(
lambda
:
tl_fused_chunk_fwd
(
q
,
k
,
v
),
backend
=
'cupti'
)
print
(
f
"TileLang latency:
{
t2
:.
3
f
}
ms"
)
print
(
f
'Triton latency:
{
t1
:.
3
f
}
ms'
)
print
(
f
"Speedup:
{
t1
/
t2
:.
3
f
}
x"
)
print
(
f
'TileLang latency:
{
t2
:.
3
f
}
ms'
)
print
(
f
'Speedup:
{
t1
/
t2
:.
3
f
}
x'
)
if
__name__
==
'
__main__
'
:
if
__name__
==
"
__main__
"
:
parser
=
argparse
.
ArgumentParser
()
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
'
--B
'
,
type
=
int
,
default
=
8
,
help
=
'
Batch size
'
)
parser
.
add_argument
(
"
--B
"
,
type
=
int
,
default
=
8
,
help
=
"
Batch size
"
)
parser
.
add_argument
(
'
--S
'
,
type
=
int
,
default
=
1024
,
help
=
'
Seq len
'
)
parser
.
add_argument
(
"
--S
"
,
type
=
int
,
default
=
1024
,
help
=
"
Seq len
"
)
parser
.
add_argument
(
'
--H
'
,
type
=
int
,
default
=
32
,
help
=
'
Num heads
'
)
parser
.
add_argument
(
"
--H
"
,
type
=
int
,
default
=
32
,
help
=
"
Num heads
"
)
parser
.
add_argument
(
'
--D
'
,
type
=
int
,
default
=
128
,
help
=
'
Head dim
'
)
parser
.
add_argument
(
"
--D
"
,
type
=
int
,
default
=
128
,
help
=
"
Head dim
"
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
main
(
args
.
B
,
args
.
S
,
args
.
H
,
args
.
D
)
main
(
args
.
B
,
args
.
S
,
args
.
H
,
args
.
D
)
examples/linear_attention/example_mamba_chunk_scan.py
View file @
29051439
...
@@ -9,6 +9,7 @@ import itertools
...
@@ -9,6 +9,7 @@ import itertools
def
chunk_scan_triton
(
cb
,
x
,
dt
,
dA_cumsum
,
C
,
states
,
D
):
def
chunk_scan_triton
(
cb
,
x
,
dt
,
dA_cumsum
,
C
,
states
,
D
):
from
mamba_ssm.ops.triton.ssd_chunk_scan
import
_chunk_scan_fwd
from
mamba_ssm.ops.triton.ssd_chunk_scan
import
_chunk_scan_fwd
out
,
_
=
_chunk_scan_fwd
(
cb
,
x
,
dt
,
dA_cumsum
,
C
,
states
,
D
)
out
,
_
=
_chunk_scan_fwd
(
cb
,
x
,
dt
,
dA_cumsum
,
C
,
states
,
D
)
return
out
return
out
...
@@ -43,14 +44,15 @@ def ref_program(cb, x, dt, dA_cumsum, C, prev_states, D):
...
@@ -43,14 +44,15 @@ def ref_program(cb, x, dt, dA_cumsum, C, prev_states, D):
dt_segment_sum
=
dA_cumsum
[:,
:,
:,
:,
None
]
-
dA_cumsum
[:,
:,
:,
None
,
:]
dt_segment_sum
=
dA_cumsum
[:,
:,
:,
:,
None
]
-
dA_cumsum
[:,
:,
:,
None
,
:]
decay
=
torch
.
exp
(
dt_segment_sum
)
decay
=
torch
.
exp
(
dt_segment_sum
)
scores_decay
=
cb
*
rearrange
(
decay
,
"b h c l s -> b c h l s"
)
scores_decay
=
cb
*
rearrange
(
decay
,
"b h c l s -> b c h l s"
)
causal_mask
=
torch
.
tril
(
causal_mask
=
torch
.
tril
(
torch
.
ones
(
chunk_size
,
chunk_size
,
device
=
x
.
device
,
dtype
=
bool
),
diagonal
=
0
)
torch
.
ones
(
chunk_size
,
chunk_size
,
device
=
x
.
device
,
dtype
=
bool
),
diagonal
=
0
)
scores_decay
=
scores_decay
.
masked_fill
(
~
causal_mask
,
0
)
scores_decay
=
scores_decay
.
masked_fill
(
~
causal_mask
,
0
)
out
=
torch
.
einsum
(
'bchls,bhcs,bcshp->bclhp'
,
scores_decay
.
to
(
x
.
dtype
),
dt
.
to
(
x
.
dtype
),
out
=
torch
.
einsum
(
rearrange
(
x
,
"b (c s) h p -> b c s h p"
,
c
=
nchunks
))
"bchls,bhcs,bcshp->bclhp"
,
scores_decay
.
to
(
x
.
dtype
),
dt
.
to
(
x
.
dtype
),
rearrange
(
x
,
"b (c s) h p -> b c s h p"
,
c
=
nchunks
)
)
state_decay_out
=
torch
.
exp
(
rearrange
(
dA_cumsum
,
"b h c l -> b c l h 1"
))
state_decay_out
=
torch
.
exp
(
rearrange
(
dA_cumsum
,
"b h c l -> b c l h 1"
))
out_prev
=
torch
.
einsum
(
'bclhn,bchpn->bclhp'
,
rearrange
(
out_prev
=
(
C
,
"b (c l) h n -> b c l h n"
,
c
=
nchunks
),
prev_states
.
to
(
C
.
dtype
))
*
state_decay_out
torch
.
einsum
(
"bclhn,bchpn->bclhp"
,
rearrange
(
C
,
"b (c l) h n -> b c l h n"
,
c
=
nchunks
),
prev_states
.
to
(
C
.
dtype
))
*
state_decay_out
)
out
=
out
+
out_prev
out
=
out
+
out_prev
out
=
rearrange
(
out
,
"b c l h p -> b (c l) h p"
)
out
=
rearrange
(
out
,
"b c l h p -> b (c l) h p"
)
if
D
is
not
None
:
if
D
is
not
None
:
...
@@ -61,12 +63,7 @@ def ref_program(cb, x, dt, dA_cumsum, C, prev_states, D):
...
@@ -61,12 +63,7 @@ def ref_program(cb, x, dt, dA_cumsum, C, prev_states, D):
def
get_configs
():
def
get_configs
():
iter_params
=
dict
(
iter_params
=
dict
(
block_M
=
[
64
,
128
,
256
],
block_N
=
[
32
,
64
],
block_K
=
[
64
,
128
,
256
],
block_Dstate
=
[
128
],
num_stages
=
[
1
,
2
,
3
,
4
,
5
])
block_M
=
[
64
,
128
,
256
],
block_N
=
[
32
,
64
],
block_K
=
[
64
,
128
,
256
],
block_Dstate
=
[
128
],
num_stages
=
[
1
,
2
,
3
,
4
,
5
])
return
[
dict
(
zip
(
iter_params
,
values
))
for
values
in
itertools
.
product
(
*
iter_params
.
values
())]
return
[
dict
(
zip
(
iter_params
,
values
))
for
values
in
itertools
.
product
(
*
iter_params
.
values
())]
...
@@ -77,19 +74,21 @@ def get_configs():
...
@@ -77,19 +74,21 @@ def get_configs():
tilelang
.
PassConfigKey
.
TL_ENABLE_FAST_MATH
:
True
,
tilelang
.
PassConfigKey
.
TL_ENABLE_FAST_MATH
:
True
,
},
},
)
)
def
chunk_scan_fwd
(
batch
,
def
chunk_scan_fwd
(
seqlen
,
batch
,
chunk_size
,
seqlen
,
ngroups
,
chunk_size
,
nheads
,
ngroups
,
headdim
,
nheads
,
dstate
,
headdim
,
block_M
=
64
,
dstate
,
block_N
=
64
,
block_M
=
64
,
block_K
=
64
,
block_N
=
64
,
block_Dstate
=
128
,
block_K
=
64
,
num_stages
=
2
,
block_Dstate
=
128
,
threads
=
128
):
num_stages
=
2
,
threads
=
128
,
):
dtype
=
"float16"
dtype
=
"float16"
accum_dtype
=
"float"
accum_dtype
=
"float"
nchunks
=
T
.
ceildiv
(
seqlen
,
chunk_size
)
nchunks
=
T
.
ceildiv
(
seqlen
,
chunk_size
)
...
@@ -97,20 +96,20 @@ def chunk_scan_fwd(batch,
...
@@ -97,20 +96,20 @@ def chunk_scan_fwd(batch,
@
T
.
prim_func
@
T
.
prim_func
def
main
(
def
main
(
cb
:
T
.
Tensor
((
batch
,
nchunks
,
ngroups
,
chunk_size
,
chunk_size
),
dtype
),
# type: ignore
cb
:
T
.
Tensor
((
batch
,
nchunks
,
ngroups
,
chunk_size
,
chunk_size
),
dtype
),
# type: ignore
x
:
T
.
Tensor
((
batch
,
seqlen
,
nheads
,
headdim
),
dtype
),
# type: ignore
x
:
T
.
Tensor
((
batch
,
seqlen
,
nheads
,
headdim
),
dtype
),
# type: ignore
dt
:
T
.
Tensor
((
batch
,
nheads
,
nchunks
,
chunk_size
),
dtype
),
# type: ignore
dt
:
T
.
Tensor
((
batch
,
nheads
,
nchunks
,
chunk_size
),
dtype
),
# type: ignore
dA_cumsum
:
T
.
Tensor
((
batch
,
nheads
,
nchunks
,
chunk_size
),
dtype
),
# type: ignore
dA_cumsum
:
T
.
Tensor
((
batch
,
nheads
,
nchunks
,
chunk_size
),
dtype
),
# type: ignore
C
:
T
.
Tensor
((
batch
,
seqlen
,
ngroups
,
dstate
),
dtype
),
# type: ignore
C
:
T
.
Tensor
((
batch
,
seqlen
,
ngroups
,
dstate
),
dtype
),
# type: ignore
prev_states
:
T
.
Tensor
((
batch
,
nchunks
,
nheads
,
headdim
,
dstate
),
dtype
),
# type: ignore
prev_states
:
T
.
Tensor
((
batch
,
nchunks
,
nheads
,
headdim
,
dstate
),
dtype
),
# type: ignore
D
:
T
.
Tensor
((
nheads
),
dtype
),
# type: ignore
D
:
T
.
Tensor
((
nheads
),
dtype
),
# type: ignore
Output
:
T
.
Tensor
((
batch
,
seqlen
,
nheads
,
headdim
),
dtype
)
# type: ignore
Output
:
T
.
Tensor
((
batch
,
seqlen
,
nheads
,
headdim
),
dtype
)
,
# type: ignore
):
):
with
T
.
Kernel
(
with
T
.
Kernel
(
nheads
,
T
.
ceildiv
(
chunk_size
,
block_M
)
*
T
.
ceildiv
(
headdim
,
block_N
),
batch
*
nchunks
,
threads
=
threads
)
as
(
nheads
,
bz
,
T
.
ceildiv
(
chunk_size
,
block_M
)
*
T
.
ceildiv
(
headdim
,
block_N
)
,
bx
,
batch
*
nchunks
,
by
,
threads
=
threads
)
as
(
bz
,
bx
,
by
):
):
acc_o
=
T
.
alloc_fragment
((
block_M
,
block_N
),
accum_dtype
)
acc_o
=
T
.
alloc_fragment
((
block_M
,
block_N
),
accum_dtype
)
acc_o_shared
=
T
.
alloc_shared
((
block_M
,
block_N
),
dtype
)
acc_o_shared
=
T
.
alloc_shared
((
block_M
,
block_N
),
dtype
)
cb_shared
=
T
.
alloc_shared
((
block_M
,
block_K
),
dtype
,
scope
=
"shared.dyn"
)
cb_shared
=
T
.
alloc_shared
((
block_M
,
block_K
),
dtype
,
scope
=
"shared.dyn"
)
...
@@ -136,27 +135,32 @@ def chunk_scan_fwd(batch,
...
@@ -136,27 +135,32 @@ def chunk_scan_fwd(batch,
m_idx
=
bx
//
T
.
ceildiv
(
headdim
,
block_N
)
m_idx
=
bx
//
T
.
ceildiv
(
headdim
,
block_N
)
n_idx
=
bx
%
T
.
ceildiv
(
headdim
,
block_N
)
n_idx
=
bx
%
T
.
ceildiv
(
headdim
,
block_N
)
T
.
annotate_layout
({
T
.
annotate_layout
(
acc_o_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
acc_o_shared
),
{
cb_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
cb_shared
),
acc_o_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
acc_o_shared
),
x_residual_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
x_residual_shared
)
cb_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
cb_shared
),
})
x_residual_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
x_residual_shared
),
}
)
T
.
no_set_max_nreg
()
T
.
no_set_max_nreg
()
T
.
copy
(
dA_cumsum
[
batch_idx
,
bz
,
chunk_idx
,
m_idx
*
block_M
:(
m_idx
+
1
)
*
block_M
],
T
.
copy
(
dA_cumsum
[
batch_idx
,
bz
,
chunk_idx
,
m_idx
*
block_M
:
(
m_idx
+
1
)
*
block_M
],
dA_cs_m_shared
)
dA_cs_m_shared
)
T
.
copy
(
dA_cs_m_shared
,
dA_cs_m_local
)
T
.
copy
(
dA_cs_m_shared
,
dA_cs_m_local
)
T
.
clear
(
acc_o
)
T
.
clear
(
acc_o
)
for
i
in
T
.
Parallel
(
block_M
):
for
i
in
T
.
Parallel
(
block_M
):
scale_m_local
[
i
]
=
T
.
exp2
(
dA_cs_m_local
[
i
]
*
p
)
scale_m_local
[
i
]
=
T
.
exp2
(
dA_cs_m_local
[
i
]
*
p
)
T
.
copy
(
T
.
copy
(
C
[
batch_idx
,
chunk_idx
*
chunk_size
+
m_idx
*
block_M
:
chunk_idx
*
chunk_size
+
C
[
(
m_idx
+
1
)
*
block_M
,
bz
//
(
nheads
//
ngroups
),
0
:
block_Dstate
],
C_shared
)
batch_idx
,
T
.
copy
(
chunk_idx
*
chunk_size
+
m_idx
*
block_M
:
chunk_idx
*
chunk_size
+
(
m_idx
+
1
)
*
block_M
,
prev_states
[
batch_idx
,
chunk_idx
,
bz
,
n_idx
*
block_N
:(
n_idx
+
1
)
*
block_N
,
bz
//
(
nheads
//
ngroups
),
0
:
block_Dstate
],
prev_state_shared
)
0
:
block_Dstate
,
],
C_shared
,
)
T
.
copy
(
prev_states
[
batch_idx
,
chunk_idx
,
bz
,
n_idx
*
block_N
:
(
n_idx
+
1
)
*
block_N
,
0
:
block_Dstate
],
prev_state_shared
)
T
.
gemm
(
C_shared
,
prev_state_shared
,
acc_o
,
transpose_B
=
True
)
T
.
gemm
(
C_shared
,
prev_state_shared
,
acc_o
,
transpose_B
=
True
)
for
i
,
j
in
T
.
Parallel
(
block_M
,
block_N
):
for
i
,
j
in
T
.
Parallel
(
block_M
,
block_N
):
acc_o
[
i
,
j
]
*=
scale_m_local
[
i
]
acc_o
[
i
,
j
]
*=
scale_m_local
[
i
]
...
@@ -165,34 +169,47 @@ def chunk_scan_fwd(batch,
...
@@ -165,34 +169,47 @@ def chunk_scan_fwd(batch,
for
k
in
T
.
Pipelined
(
loop_range
,
num_stages
=
num_stages
):
for
k
in
T
.
Pipelined
(
loop_range
,
num_stages
=
num_stages
):
T
.
copy
(
T
.
copy
(
cb
[
batch_idx
,
chunk_idx
,
bz
//
(
nheads
//
ngroups
),
cb
[
m_idx
*
block_M
:(
m_idx
+
1
)
*
block_M
,
k
*
block_K
:(
k
+
1
)
*
block_K
],
batch_idx
,
cb_shared
)
chunk_idx
,
bz
//
(
nheads
//
ngroups
),
m_idx
*
block_M
:
(
m_idx
+
1
)
*
block_M
,
k
*
block_K
:
(
k
+
1
)
*
block_K
,
],
cb_shared
,
)
T
.
copy
(
cb_shared
,
cb_local
)
T
.
copy
(
cb_shared
,
cb_local
)
T
.
copy
(
dA_cumsum
[
batch_idx
,
bz
,
chunk_idx
,
k
*
block_K
:(
k
+
1
)
*
block_K
],
T
.
copy
(
dA_cumsum
[
batch_idx
,
bz
,
chunk_idx
,
k
*
block_K
:
(
k
+
1
)
*
block_K
],
dA_cs_k_shared
)
dA_cs_k_shared
)
T
.
copy
(
dA_cs_k_shared
,
dA_cs_k_local
)
T
.
copy
(
dA_cs_k_shared
,
dA_cs_k_local
)
for
i
,
j
in
T
.
Parallel
(
block_M
,
block_K
):
for
i
,
j
in
T
.
Parallel
(
block_M
,
block_K
):
cb_local
[
i
,
cb_local
[
i
,
j
]
=
cb_local
[
i
,
j
]
*
T
.
exp2
(
dA_cs_m_local
[
i
]
*
p
-
dA_cs_k_local
[
j
]
*
p
)
j
]
=
cb_local
[
i
,
T
.
copy
(
dt
[
batch_idx
,
bz
,
chunk_idx
,
k
*
block_K
:
(
k
+
1
)
*
block_K
],
dt_shared
)
j
]
*
T
.
exp2
(
dA_cs_m_local
[
i
]
*
p
-
dA_cs_k_local
[
j
]
*
p
)
T
.
copy
(
dt
[
batch_idx
,
bz
,
chunk_idx
,
k
*
block_K
:(
k
+
1
)
*
block_K
],
dt_shared
)
T
.
copy
(
dt_shared
,
dt_local
)
T
.
copy
(
dt_shared
,
dt_local
)
for
i
,
j
in
T
.
Parallel
(
block_M
,
block_K
):
for
i
,
j
in
T
.
Parallel
(
block_M
,
block_K
):
cb_local
[
i
,
j
]
*=
dt_local
[
j
]
cb_local
[
i
,
j
]
*=
dt_local
[
j
]
for
i
,
j
in
T
.
Parallel
(
block_M
,
block_K
):
for
i
,
j
in
T
.
Parallel
(
block_M
,
block_K
):
cb_local
[
i
,
j
]
=
T
.
if_then_else
(
m_idx
*
block_M
+
i
>=
k
*
block_K
+
j
,
cb_local
[
i
,
j
]
=
T
.
if_then_else
(
m_idx
*
block_M
+
i
>=
k
*
block_K
+
j
,
cb_local
[
i
,
j
],
0
)
cb_local
[
i
,
j
],
0
)
T
.
copy
(
T
.
copy
(
x
[
batch_idx
,
chunk_idx
*
chunk_size
+
k
*
block_K
:
chunk_idx
*
chunk_size
+
x
[
(
k
+
1
)
*
block_K
,
bz
,
n_idx
*
block_N
:(
n_idx
+
1
)
*
block_N
],
x_shared
)
batch_idx
,
chunk_idx
*
chunk_size
+
k
*
block_K
:
chunk_idx
*
chunk_size
+
(
k
+
1
)
*
block_K
,
bz
,
n_idx
*
block_N
:
(
n_idx
+
1
)
*
block_N
,
],
x_shared
,
)
T
.
gemm
(
cb_local
,
x_shared
,
acc_o
)
T
.
gemm
(
cb_local
,
x_shared
,
acc_o
)
D_local
[
0
]
=
D
[
bz
]
D_local
[
0
]
=
D
[
bz
]
T
.
copy
(
T
.
copy
(
x
[
batch_idx
,
chunk_idx
*
chunk_size
+
m_idx
*
block_M
:
chunk_idx
*
chunk_size
+
x
[
(
m_idx
+
1
)
*
block_M
,
bz
,
n_idx
*
block_N
:(
n_idx
+
1
)
*
block_N
],
batch_idx
,
x_residual_shared
)
chunk_idx
*
chunk_size
+
m_idx
*
block_M
:
chunk_idx
*
chunk_size
+
(
m_idx
+
1
)
*
block_M
,
bz
,
n_idx
*
block_N
:
(
n_idx
+
1
)
*
block_N
,
],
x_residual_shared
,
)
T
.
copy
(
x_residual_shared
,
x_residual_local
)
T
.
copy
(
x_residual_shared
,
x_residual_local
)
for
i
,
j
in
T
.
Parallel
(
block_M
,
block_N
):
for
i
,
j
in
T
.
Parallel
(
block_M
,
block_N
):
acc_o
[
i
,
j
]
+=
x_residual_local
[
i
,
j
]
*
D_local
[
0
]
acc_o
[
i
,
j
]
+=
x_residual_local
[
i
,
j
]
*
D_local
[
0
]
...
@@ -200,27 +217,40 @@ def chunk_scan_fwd(batch,
...
@@ -200,27 +217,40 @@ def chunk_scan_fwd(batch,
T
.
copy
(
acc_o
,
acc_o_shared
)
T
.
copy
(
acc_o
,
acc_o_shared
)
T
.
copy
(
T
.
copy
(
acc_o_shared
,
acc_o_shared
,
Output
[
batch_idx
,
chunk_idx
*
chunk_size
+
m_idx
*
block_M
:
chunk_idx
*
chunk_size
+
Output
[
(
m_idx
+
1
)
*
block_M
,
bz
,
n_idx
*
block_N
:(
n_idx
+
1
)
*
block_N
])
batch_idx
,
chunk_idx
*
chunk_size
+
m_idx
*
block_M
:
chunk_idx
*
chunk_size
+
(
m_idx
+
1
)
*
block_M
,
bz
,
n_idx
*
block_N
:
(
n_idx
+
1
)
*
block_N
,
],
)
return
main
return
main
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
parser
=
argparse
.
ArgumentParser
()
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
'
--batch
'
,
type
=
int
,
default
=
8
,
help
=
'
batch size
'
)
parser
.
add_argument
(
"
--batch
"
,
type
=
int
,
default
=
8
,
help
=
"
batch size
"
)
parser
.
add_argument
(
'
--heads
'
,
type
=
int
,
default
=
80
,
help
=
'
heads
'
)
parser
.
add_argument
(
"
--heads
"
,
type
=
int
,
default
=
80
,
help
=
"
heads
"
)
parser
.
add_argument
(
'
--groups
'
,
type
=
int
,
default
=
1
,
help
=
'
groups
'
)
parser
.
add_argument
(
"
--groups
"
,
type
=
int
,
default
=
1
,
help
=
"
groups
"
)
parser
.
add_argument
(
'
--seq_len
'
,
type
=
int
,
default
=
4096
,
help
=
'
sequence length
'
)
parser
.
add_argument
(
"
--seq_len
"
,
type
=
int
,
default
=
4096
,
help
=
"
sequence length
"
)
parser
.
add_argument
(
'
--chunk_size
'
,
type
=
int
,
default
=
256
,
help
=
'
chunk size
'
)
parser
.
add_argument
(
"
--chunk_size
"
,
type
=
int
,
default
=
256
,
help
=
"
chunk size
"
)
parser
.
add_argument
(
'
--dim
'
,
type
=
int
,
default
=
64
,
help
=
'
dim
'
)
parser
.
add_argument
(
"
--dim
"
,
type
=
int
,
default
=
64
,
help
=
"
dim
"
)
parser
.
add_argument
(
'
--dstate
'
,
type
=
int
,
default
=
128
,
help
=
'
dstate
'
)
parser
.
add_argument
(
"
--dstate
"
,
type
=
int
,
default
=
128
,
help
=
"
dstate
"
)
parser
.
add_argument
(
'
--tune
'
,
action
=
'
store_true
'
,
help
=
'
tune configs
'
)
parser
.
add_argument
(
"
--tune
"
,
action
=
"
store_true
"
,
help
=
"
tune configs
"
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
batch
,
heads
,
groups
,
seq_len
,
chunk_size
,
dim
,
dstate
=
args
.
batch
,
args
.
heads
,
args
.
groups
,
args
.
seq_len
,
args
.
chunk_size
,
args
.
dim
,
args
.
dstate
batch
,
heads
,
groups
,
seq_len
,
chunk_size
,
dim
,
dstate
=
(
args
.
batch
,
args
.
heads
,
args
.
groups
,
args
.
seq_len
,
args
.
chunk_size
,
args
.
dim
,
args
.
dstate
,
)
total_flops
=
2
*
batch
*
seq_len
*
chunk_size
*
heads
*
dim
*
0.5
+
2
*
batch
*
seq_len
*
heads
*
dim
*
dstate
total_flops
=
2
*
batch
*
seq_len
*
chunk_size
*
heads
*
dim
*
0.5
+
2
*
batch
*
seq_len
*
heads
*
dim
*
dstate
if
(
not
args
.
tune
)
:
if
not
args
.
tune
:
kernel
=
chunk_scan_fwd
(
kernel
=
chunk_scan_fwd
(
batch
,
batch
,
seq_len
,
seq_len
,
...
@@ -234,7 +264,8 @@ if __name__ == "__main__":
...
@@ -234,7 +264,8 @@ if __name__ == "__main__":
block_K
=
64
,
block_K
=
64
,
block_Dstate
=
128
,
block_Dstate
=
128
,
num_stages
=
2
,
num_stages
=
2
,
threads
=
128
)
threads
=
128
,
)
profiler
=
kernel
.
get_profiler
(
tilelang
.
TensorSupplyType
.
Normal
)
profiler
=
kernel
.
get_profiler
(
tilelang
.
TensorSupplyType
.
Normal
)
profiler
.
assert_allclose
(
ref_program
,
rtol
=
0.01
,
atol
=
0.01
)
profiler
.
assert_allclose
(
ref_program
,
rtol
=
0.01
,
atol
=
0.01
)
print
(
"All checks pass."
)
print
(
"All checks pass."
)
...
...
examples/linear_attention/example_mamba_chunk_state.py
View file @
29051439
...
@@ -10,6 +10,7 @@ import itertools
...
@@ -10,6 +10,7 @@ import itertools
def
chunk_state_triton
(
B
,
x
,
dt
,
dA_cumsum
):
def
chunk_state_triton
(
B
,
x
,
dt
,
dA_cumsum
):
from
mamba_ssm.ops.triton.ssd_chunk_state
import
_chunk_state_fwd
from
mamba_ssm.ops.triton.ssd_chunk_state
import
_chunk_state_fwd
return
_chunk_state_fwd
(
B
,
x
,
dt
,
dA_cumsum
,
states_in_fp32
=
False
)
return
_chunk_state_fwd
(
B
,
x
,
dt
,
dA_cumsum
,
states_in_fp32
=
False
)
...
@@ -41,46 +42,33 @@ def ref_program(B, x, dt, dA_cumsum):
...
@@ -41,46 +42,33 @@ def ref_program(B, x, dt, dA_cumsum):
x
=
rearrange
(
x
,
"b (c l) h p -> b c l h p"
,
l
=
chunk_size
)
x
=
rearrange
(
x
,
"b (c l) h p -> b c l h p"
,
l
=
chunk_size
)
B
=
rearrange
(
B
,
"b (c l) ... -> b c l ..."
,
l
=
chunk_size
)
B
=
rearrange
(
B
,
"b (c l) ... -> b c l ..."
,
l
=
chunk_size
)
decay_states
=
torch
.
exp
((
dA_cumsum
[:,
:,
:,
-
1
:]
-
dA_cumsum
))
decay_states
=
torch
.
exp
((
dA_cumsum
[:,
:,
:,
-
1
:]
-
dA_cumsum
))
return
torch
.
einsum
(
"bclhn,bhcl,bhcl,bclhp->bchpn"
,
B
.
to
(
x
.
dtype
),
decay_states
.
to
(
x
.
dtype
),
return
torch
.
einsum
(
"bclhn,bhcl,bhcl,bclhp->bchpn"
,
B
.
to
(
x
.
dtype
),
decay_states
.
to
(
x
.
dtype
),
dt
.
to
(
x
.
dtype
),
x
)
dt
.
to
(
x
.
dtype
),
x
)
def
get_configs
():
def
get_configs
():
iter_params
=
dict
(
iter_params
=
dict
(
block_M
=
[
64
,
128
],
block_N
=
[
32
,
64
,
128
],
block_K
=
[
32
,
64
],
num_stages
=
[
1
,
2
,
3
,
4
,
5
])
block_M
=
[
64
,
128
],
block_N
=
[
32
,
64
,
128
],
block_K
=
[
32
,
64
],
num_stages
=
[
1
,
2
,
3
,
4
,
5
])
return
[
dict
(
zip
(
iter_params
,
values
))
for
values
in
itertools
.
product
(
*
iter_params
.
values
())]
return
[
dict
(
zip
(
iter_params
,
values
))
for
values
in
itertools
.
product
(
*
iter_params
.
values
())]
@
autotune
(
configs
=
get_configs
(),
warmup
=
10
,
rep
=
10
)
@
autotune
(
configs
=
get_configs
(),
warmup
=
10
,
rep
=
10
)
@
tilelang
.
jit
(
out_idx
=
[
4
])
@
tilelang
.
jit
(
out_idx
=
[
4
])
def
chunk_state_fwd
(
batch
,
def
chunk_state_fwd
(
seqlen
,
batch
,
seqlen
,
chunk_size
,
ngroups
,
nheads
,
headdim
,
dstate
,
block_M
=
64
,
block_N
=
64
,
block_K
=
64
,
num_stages
=
2
,
threads
=
128
chunk_size
,
):
ngroups
,
nheads
,
headdim
,
dstate
,
block_M
=
64
,
block_N
=
64
,
block_K
=
64
,
num_stages
=
2
,
threads
=
128
):
dtype
=
"float16"
dtype
=
"float16"
accum_dtype
=
"float"
accum_dtype
=
"float"
nchunks
=
T
.
ceildiv
(
seqlen
,
chunk_size
)
nchunks
=
T
.
ceildiv
(
seqlen
,
chunk_size
)
p
=
1.44269504
p
=
1.44269504
@
T
.
prim_func
@
T
.
prim_func
def
main
(
B
:
T
.
Tensor
((
batch
,
seqlen
,
ngroups
,
dstate
),
dtype
),
x
:
T
.
Tensor
(
def
main
(
(
batch
,
seqlen
,
nheads
,
headdim
),
dtype
),
dt
:
T
.
Tensor
(
B
:
T
.
Tensor
((
batch
,
seqlen
,
ngroups
,
dstate
),
dtype
),
(
batch
,
nheads
,
nchunks
,
chunk_size
),
dtype
),
dA_cumsum
:
T
.
Tensor
(
x
:
T
.
Tensor
((
batch
,
seqlen
,
nheads
,
headdim
),
dtype
),
(
batch
,
nheads
,
nchunks
,
chunk_size
),
dtype
),
Output
:
T
.
Tensor
(
dt
:
T
.
Tensor
((
batch
,
nheads
,
nchunks
,
chunk_size
),
dtype
),
(
batch
,
nchunks
,
nheads
,
headdim
,
dstate
),
dtype
)):
dA_cumsum
:
T
.
Tensor
((
batch
,
nheads
,
nchunks
,
chunk_size
),
dtype
),
with
T
.
Kernel
(
Output
:
T
.
Tensor
((
batch
,
nchunks
,
nheads
,
headdim
,
dstate
),
dtype
),
nheads
,
):
T
.
ceildiv
(
headdim
,
block_M
)
*
T
.
ceildiv
(
dstate
,
block_N
),
with
T
.
Kernel
(
nheads
,
T
.
ceildiv
(
headdim
,
block_M
)
*
T
.
ceildiv
(
dstate
,
block_N
),
batch
*
nchunks
,
threads
=
threads
)
as
(
bz
,
bx
,
by
):
batch
*
nchunks
,
threads
=
threads
)
as
(
bz
,
bx
,
by
):
x_shared
=
T
.
alloc_shared
((
block_K
,
block_M
),
dtype
)
x_shared
=
T
.
alloc_shared
((
block_K
,
block_M
),
dtype
)
x_local
=
T
.
alloc_fragment
((
block_K
,
block_M
),
dtype
)
x_local
=
T
.
alloc_fragment
((
block_K
,
block_M
),
dtype
)
xt_local
=
T
.
alloc_fragment
((
block_M
,
block_K
),
dtype
)
xt_local
=
T
.
alloc_fragment
((
block_M
,
block_K
),
dtype
)
...
@@ -101,20 +89,24 @@ def chunk_state_fwd(batch,
...
@@ -101,20 +89,24 @@ def chunk_state_fwd(batch,
m_idx
=
bx
//
T
.
ceildiv
(
dstate
,
block_N
)
m_idx
=
bx
//
T
.
ceildiv
(
dstate
,
block_N
)
n_idx
=
bx
%
T
.
ceildiv
(
dstate
,
block_N
)
n_idx
=
bx
%
T
.
ceildiv
(
dstate
,
block_N
)
T
.
annotate_layout
({
T
.
annotate_layout
(
x_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
x_shared
),
{
x_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
x_shared
),
acc_o_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
acc_o_shared
)}
acc_o_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
acc_o_shared
)
)
})
dA_cs_last
[
0
]
=
dA_cumsum
[
batch_idx
,
bz
,
chunk_idx
,
chunk_size
-
1
]
dA_cs_last
[
0
]
=
dA_cumsum
[
batch_idx
,
bz
,
chunk_idx
,
chunk_size
-
1
]
T
.
clear
(
acc_o
)
T
.
clear
(
acc_o
)
for
k
in
T
.
Pipelined
(
loop_range
,
num_stages
=
num_stages
):
for
k
in
T
.
Pipelined
(
loop_range
,
num_stages
=
num_stages
):
T
.
copy
(
T
.
copy
(
x
[
batch_idx
,
chunk_idx
*
chunk_size
+
k
*
block_K
:
chunk_idx
*
chunk_size
+
x
[
(
k
+
1
)
*
block_K
,
bz
,
m_idx
*
block_M
:(
m_idx
+
1
)
*
block_M
],
x_shared
)
batch_idx
,
T
.
copy
(
dA_cumsum
[
batch_idx
,
bz
,
chunk_idx
,
k
*
block_K
:(
k
+
1
)
*
block_K
],
chunk_idx
*
chunk_size
+
k
*
block_K
:
chunk_idx
*
chunk_size
+
(
k
+
1
)
*
block_K
,
dA_cumsum_shared
)
bz
,
T
.
copy
(
dt
[
batch_idx
,
bz
,
chunk_idx
,
k
*
block_K
:(
k
+
1
)
*
block_K
],
dt_shared
)
m_idx
*
block_M
:
(
m_idx
+
1
)
*
block_M
,
],
x_shared
,
)
T
.
copy
(
dA_cumsum
[
batch_idx
,
bz
,
chunk_idx
,
k
*
block_K
:
(
k
+
1
)
*
block_K
],
dA_cumsum_shared
)
T
.
copy
(
dt
[
batch_idx
,
bz
,
chunk_idx
,
k
*
block_K
:
(
k
+
1
)
*
block_K
],
dt_shared
)
T
.
copy
(
dA_cumsum_shared
,
dA_cumsum_local
)
T
.
copy
(
dA_cumsum_shared
,
dA_cumsum_local
)
T
.
copy
(
dt_shared
,
dt_local
)
T
.
copy
(
dt_shared
,
dt_local
)
for
i
in
T
.
Parallel
(
block_K
):
for
i
in
T
.
Parallel
(
block_K
):
...
@@ -123,47 +115,50 @@ def chunk_state_fwd(batch,
...
@@ -123,47 +115,50 @@ def chunk_state_fwd(batch,
for
i
,
j
in
T
.
Parallel
(
block_M
,
block_K
):
for
i
,
j
in
T
.
Parallel
(
block_M
,
block_K
):
xt_local
[
i
,
j
]
=
x_local
[
j
,
i
]
*
scale
[
j
]
xt_local
[
i
,
j
]
=
x_local
[
j
,
i
]
*
scale
[
j
]
T
.
copy
(
T
.
copy
(
B
[
batch_idx
,
chunk_idx
*
chunk_size
+
k
*
block_K
:
chunk_idx
*
chunk_size
+
B
[
(
k
+
1
)
*
block_K
,
bz
//
(
nheads
//
ngroups
),
batch_idx
,
n_idx
*
block_N
:(
n_idx
+
1
)
*
block_N
],
B_shared
)
chunk_idx
*
chunk_size
+
k
*
block_K
:
chunk_idx
*
chunk_size
+
(
k
+
1
)
*
block_K
,
bz
//
(
nheads
//
ngroups
),
n_idx
*
block_N
:
(
n_idx
+
1
)
*
block_N
,
],
B_shared
,
)
T
.
gemm
(
xt_local
,
B_shared
,
acc_o
)
T
.
gemm
(
xt_local
,
B_shared
,
acc_o
)
T
.
copy
(
acc_o
,
acc_o_shared
)
T
.
copy
(
acc_o
,
acc_o_shared
)
T
.
copy
(
T
.
copy
(
acc_o_shared
,
acc_o_shared
,
Output
[
batch_idx
,
chunk_idx
,
bz
,
m_idx
*
block_M
:
(
m_idx
+
1
)
*
block_M
,
Output
[
batch_idx
,
chunk_idx
,
bz
,
m_idx
*
block_M
:
(
m_idx
+
1
)
*
block_M
,
n_idx
*
block_N
:
(
n_idx
+
1
)
*
block_N
],
n_idx
*
block_N
:(
n_idx
+
1
)
*
block_N
]
)
)
return
main
return
main
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
parser
=
argparse
.
ArgumentParser
()
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
'
--batch
'
,
type
=
int
,
default
=
8
,
help
=
'
batch size
'
)
parser
.
add_argument
(
"
--batch
"
,
type
=
int
,
default
=
8
,
help
=
"
batch size
"
)
parser
.
add_argument
(
'
--heads
'
,
type
=
int
,
default
=
80
,
help
=
'
heads
'
)
parser
.
add_argument
(
"
--heads
"
,
type
=
int
,
default
=
80
,
help
=
"
heads
"
)
parser
.
add_argument
(
'
--groups
'
,
type
=
int
,
default
=
1
,
help
=
'
groups
'
)
parser
.
add_argument
(
"
--groups
"
,
type
=
int
,
default
=
1
,
help
=
"
groups
"
)
parser
.
add_argument
(
'
--seq_len
'
,
type
=
int
,
default
=
4096
,
help
=
'
sequence length
'
)
parser
.
add_argument
(
"
--seq_len
"
,
type
=
int
,
default
=
4096
,
help
=
"
sequence length
"
)
parser
.
add_argument
(
'
--chunk_size
'
,
type
=
int
,
default
=
256
,
help
=
'
chunk size
'
)
parser
.
add_argument
(
"
--chunk_size
"
,
type
=
int
,
default
=
256
,
help
=
"
chunk size
"
)
parser
.
add_argument
(
'
--dim
'
,
type
=
int
,
default
=
64
,
help
=
'
dim
'
)
parser
.
add_argument
(
"
--dim
"
,
type
=
int
,
default
=
64
,
help
=
"
dim
"
)
parser
.
add_argument
(
'
--dstate
'
,
type
=
int
,
default
=
128
,
help
=
'
dstate
'
)
parser
.
add_argument
(
"
--dstate
"
,
type
=
int
,
default
=
128
,
help
=
"
dstate
"
)
parser
.
add_argument
(
'
--tune
'
,
action
=
'
store_true
'
,
help
=
'
tune configs
'
)
parser
.
add_argument
(
"
--tune
"
,
action
=
"
store_true
"
,
help
=
"
tune configs
"
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
batch
,
heads
,
groups
,
seq_len
,
chunk_size
,
dim
,
dstate
=
args
.
batch
,
args
.
heads
,
args
.
groups
,
args
.
seq_len
,
args
.
chunk_size
,
args
.
dim
,
args
.
dstate
batch
,
heads
,
groups
,
seq_len
,
chunk_size
,
dim
,
dstate
=
(
args
.
batch
,
args
.
heads
,
args
.
groups
,
args
.
seq_len
,
args
.
chunk_size
,
args
.
dim
,
args
.
dstate
,
)
total_flops
=
2
*
batch
*
seq_len
*
heads
*
dim
*
dstate
total_flops
=
2
*
batch
*
seq_len
*
heads
*
dim
*
dstate
if
(
not
args
.
tune
)
:
if
not
args
.
tune
:
kernel
=
chunk_state_fwd
(
kernel
=
chunk_state_fwd
(
batch
,
batch
,
seq_len
,
chunk_size
,
groups
,
heads
,
dim
,
dstate
,
block_M
=
64
,
block_N
=
128
,
block_K
=
64
,
num_stages
=
4
,
threads
=
128
seq_len
,
)
chunk_size
,
groups
,
heads
,
dim
,
dstate
,
block_M
=
64
,
block_N
=
128
,
block_K
=
64
,
num_stages
=
4
,
threads
=
128
)
profiler
=
kernel
.
get_profiler
(
tilelang
.
TensorSupplyType
.
Normal
)
profiler
=
kernel
.
get_profiler
(
tilelang
.
TensorSupplyType
.
Normal
)
profiler
.
assert_allclose
(
ref_program
,
rtol
=
0.01
,
atol
=
0.01
)
profiler
.
assert_allclose
(
ref_program
,
rtol
=
0.01
,
atol
=
0.01
)
print
(
"All checks pass."
)
print
(
"All checks pass."
)
...
...
examples/linear_attention/example_retention_fwd.py
View file @
29051439
...
@@ -13,13 +13,12 @@ def chunk_retention_fwd_kernel(
...
@@ -13,13 +13,12 @@ def chunk_retention_fwd_kernel(
H
,
H
,
DK
,
DK
,
DV
,
DV
,
dtype
:
str
=
'
float16
'
,
dtype
:
str
=
"
float16
"
,
scale
:
float
=
None
,
scale
:
float
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
if
scale
is
None
:
if
scale
is
None
:
scale
=
DK
**-
0.5
scale
=
DK
**-
0.5
accum_dtype
=
'
float
'
accum_dtype
=
"
float
"
chunk_size
=
64
chunk_size
=
64
BK
=
BV
=
64
# Set to 128 can be faster, but has some numerical differences with FLA
BK
=
BV
=
64
# Set to 128 can be faster, but has some numerical differences with FLA
...
@@ -30,16 +29,16 @@ def chunk_retention_fwd_kernel(
...
@@ -30,16 +29,16 @@ def chunk_retention_fwd_kernel(
@
T
.
prim_func
@
T
.
prim_func
def
chunk_retention_fwd
(
def
chunk_retention_fwd
(
Q
:
T
.
Tensor
([
B
,
S
,
H
,
DK
],
dtype
),
# type: ignore
Q
:
T
.
Tensor
([
B
,
S
,
H
,
DK
],
dtype
),
# type: ignore
K
:
T
.
Tensor
([
B
,
S
,
H
,
DK
],
dtype
),
# type: ignore
K
:
T
.
Tensor
([
B
,
S
,
H
,
DK
],
dtype
),
# type: ignore
V
:
T
.
Tensor
([
B
,
S
,
H
,
DV
],
dtype
),
# type: ignore
V
:
T
.
Tensor
([
B
,
S
,
H
,
DV
],
dtype
),
# type: ignore
O
:
T
.
Tensor
([
NK
,
B
,
S
,
H
,
DV
],
dtype
),
# type: ignore
O
:
T
.
Tensor
([
NK
,
B
,
S
,
H
,
DV
],
dtype
),
# type: ignore
):
):
with
T
.
Kernel
(
NV
,
NK
,
B
*
H
)
as
(
i_v
,
i_k
,
i_bh
):
with
T
.
Kernel
(
NV
,
NK
,
B
*
H
)
as
(
i_v
,
i_k
,
i_bh
):
i_b
=
i_bh
//
H
i_b
=
i_bh
//
H
i_h
=
i_bh
%
H
i_h
=
i_bh
%
H
log_decay
=
T
.
alloc_var
(
'
float32
'
)
log_decay
=
T
.
alloc_var
(
"
float32
"
)
log_decay
=
T
.
log2
(
1
-
T
.
exp2
(
-
5.
-
1.
*
i_h
))
# Head-specific log decay
log_decay
=
T
.
log2
(
1
-
T
.
exp2
(
-
5.
0
-
1.
0
*
i_h
))
# Head-specific log decay
q
=
T
.
alloc_shared
([
chunk_size
,
BK
],
dtype
)
q
=
T
.
alloc_shared
([
chunk_size
,
BK
],
dtype
)
k
=
T
.
alloc_shared
([
chunk_size
,
BK
],
dtype
)
k
=
T
.
alloc_shared
([
chunk_size
,
BK
],
dtype
)
...
@@ -56,14 +55,12 @@ def chunk_retention_fwd_kernel(
...
@@ -56,14 +55,12 @@ def chunk_retention_fwd_kernel(
for
i
in
T
.
Pipelined
(
0
,
NT
):
for
i
in
T
.
Pipelined
(
0
,
NT
):
for
row
,
col
in
T
.
Parallel
(
chunk_size
,
BK
):
for
row
,
col
in
T
.
Parallel
(
chunk_size
,
BK
):
q
[
row
,
col
]
=
Q
[
i_b
,
i
*
chunk_size
+
row
,
i_h
,
i_k
*
BK
+
col
]
*
scale
q
[
row
,
col
]
=
Q
[
i_b
,
i
*
chunk_size
+
row
,
i_h
,
i_k
*
BK
+
col
]
*
scale
T
.
copy
(
K
[
i_b
,
i
*
chunk_size
:
(
i
+
1
)
*
chunk_size
,
i_h
,
i_k
*
BK
:
(
i_k
+
1
)
*
BK
],
k
)
T
.
copy
(
K
[
i_b
,
i
*
chunk_size
:
(
i
+
1
)
*
chunk_size
,
i_h
,
i_k
*
BK
:
(
i_k
+
1
)
*
BK
],
k
)
T
.
copy
(
V
[
i_b
,
i
*
chunk_size
:
(
i
+
1
)
*
chunk_size
,
i_h
,
i_v
*
BV
:
(
i_v
+
1
)
*
BV
],
v
)
T
.
copy
(
V
[
i_b
,
i
*
chunk_size
:
(
i
+
1
)
*
chunk_size
,
i_h
,
i_v
*
BV
:
(
i_v
+
1
)
*
BV
],
v
)
T
.
gemm
(
q
,
k
,
s
,
clear_accum
=
True
,
transpose_B
=
True
)
T
.
gemm
(
q
,
k
,
s
,
clear_accum
=
True
,
transpose_B
=
True
)
for
row
,
col
in
T
.
Parallel
(
chunk_size
,
chunk_size
):
for
row
,
col
in
T
.
Parallel
(
chunk_size
,
chunk_size
):
s_shared
[
row
,
s_shared
[
row
,
col
]
=
T
.
if_then_else
(
row
>=
col
,
s
[
row
,
col
]
*
T
.
exp2
((
row
-
col
)
*
log_decay
),
0
)
col
]
=
T
.
if_then_else
(
row
>=
col
,
s
[
row
,
col
]
*
T
.
exp2
(
(
row
-
col
)
*
log_decay
),
0
)
T
.
copy
(
h
,
h_shared
)
T
.
copy
(
h
,
h_shared
)
T
.
gemm
(
q
,
h_shared
,
o
,
clear_accum
=
True
)
T
.
gemm
(
q
,
h_shared
,
o
,
clear_accum
=
True
)
...
@@ -75,9 +72,7 @@ def chunk_retention_fwd_kernel(
...
@@ -75,9 +72,7 @@ def chunk_retention_fwd_kernel(
v
[
row
,
col
]
=
v
[
row
,
col
]
*
T
.
exp2
((
chunk_size
-
row
-
1
)
*
log_decay
)
v
[
row
,
col
]
=
v
[
row
,
col
]
*
T
.
exp2
((
chunk_size
-
row
-
1
)
*
log_decay
)
for
row
,
col
in
T
.
Parallel
(
BK
,
BV
):
for
row
,
col
in
T
.
Parallel
(
BK
,
BV
):
h
[
row
,
col
]
=
T
.
exp2
(
chunk_size
*
log_decay
)
*
h
[
row
,
col
]
h
[
row
,
col
]
=
T
.
exp2
(
chunk_size
*
log_decay
)
*
h
[
row
,
col
]
T
.
copy
(
T
.
copy
(
o
,
O
[
i_k
,
i_b
,
i
*
chunk_size
:
(
i
+
1
)
*
chunk_size
,
i_h
,
i_v
*
BV
:
(
i_v
+
1
)
*
BV
])
o
,
O
[
i_k
,
i_b
,
i
*
chunk_size
:(
i
+
1
)
*
chunk_size
,
i_h
,
i_v
*
BV
:(
i_v
+
1
)
*
BV
])
T
.
gemm
(
k
,
v
,
h
,
transpose_A
=
True
)
T
.
gemm
(
k
,
v
,
h
,
transpose_A
=
True
)
return
chunk_retention_fwd
return
chunk_retention_fwd
...
@@ -89,24 +84,24 @@ def postprocess(o):
...
@@ -89,24 +84,24 @@ def postprocess(o):
def
main
():
def
main
():
parser
=
argparse
.
ArgumentParser
()
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
'
--B
'
,
type
=
int
,
default
=
8
,
help
=
'
Batch size
'
)
parser
.
add_argument
(
"
--B
"
,
type
=
int
,
default
=
8
,
help
=
"
Batch size
"
)
parser
.
add_argument
(
'
--S
'
,
type
=
int
,
default
=
4096
,
help
=
'
Seq len
'
)
parser
.
add_argument
(
"
--S
"
,
type
=
int
,
default
=
4096
,
help
=
"
Seq len
"
)
parser
.
add_argument
(
'
--H
'
,
type
=
int
,
default
=
32
,
help
=
'
Num heads
'
)
parser
.
add_argument
(
"
--H
"
,
type
=
int
,
default
=
32
,
help
=
"
Num heads
"
)
parser
.
add_argument
(
'
--D
'
,
type
=
int
,
default
=
128
,
help
=
'
Head dim
'
)
parser
.
add_argument
(
"
--D
"
,
type
=
int
,
default
=
128
,
help
=
"
Head dim
"
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
B
,
S
,
H
,
D
=
args
.
B
,
args
.
S
,
args
.
H
,
args
.
D
B
,
S
,
H
,
D
=
args
.
B
,
args
.
S
,
args
.
H
,
args
.
D
total_flops
=
2.0
*
B
*
S
*
S
*
H
*
D
# causal
total_flops
=
2.0
*
B
*
S
*
S
*
H
*
D
# causal
q
=
torch
.
randn
((
B
,
S
,
H
,
D
),
device
=
'
cuda
'
,
dtype
=
torch
.
float16
)
q
=
torch
.
randn
((
B
,
S
,
H
,
D
),
device
=
"
cuda
"
,
dtype
=
torch
.
float16
)
k
=
torch
.
randn
((
B
,
S
,
H
,
D
),
device
=
'
cuda
'
,
dtype
=
torch
.
float16
)
k
=
torch
.
randn
((
B
,
S
,
H
,
D
),
device
=
"
cuda
"
,
dtype
=
torch
.
float16
)
v
=
torch
.
randn
((
B
,
S
,
H
,
D
),
device
=
'
cuda
'
,
dtype
=
torch
.
float16
)
v
=
torch
.
randn
((
B
,
S
,
H
,
D
),
device
=
"
cuda
"
,
dtype
=
torch
.
float16
)
kernel
=
chunk_retention_fwd_kernel
(
B
,
S
,
H
,
D
,
D
)
kernel
=
chunk_retention_fwd_kernel
(
B
,
S
,
H
,
D
,
D
)
t
=
do_bench
(
lambda
:
postprocess
(
kernel
(
q
,
k
,
v
)),
warmup
=
25
,
rep
=
100
)
t
=
do_bench
(
lambda
:
postprocess
(
kernel
(
q
,
k
,
v
)),
warmup
=
25
,
rep
=
100
)
print
(
f
'
Tilelang latency:
{
t
:.
3
f
}
ms
'
)
print
(
f
"
Tilelang latency:
{
t
:.
3
f
}
ms
"
)
print
(
f
'
Tilelang TFLOPs:
{
total_flops
/
t
*
1e-9
}
'
)
print
(
f
"
Tilelang TFLOPs:
{
total_flops
/
t
*
1e-9
}
"
)
if
__name__
==
'
__main__
'
:
if
__name__
==
"
__main__
"
:
main
()
main
()
examples/minference/example_vertical_slash_sparse_attn.py
View file @
29051439
...
@@ -15,12 +15,11 @@ from tilelang.profiler import do_bench
...
@@ -15,12 +15,11 @@ from tilelang.profiler import do_bench
@
tilelang
.
jit
(
out_idx
=
[
3
])
@
tilelang
.
jit
(
out_idx
=
[
3
])
def
_tl_vs_sparse_flashattn
(
batch
,
heads
,
seq_len
,
dim
,
vertical_size
,
slash_size
):
def
_tl_vs_sparse_flashattn
(
batch
,
heads
,
seq_len
,
dim
,
vertical_size
,
slash_size
):
block_M
=
64
block_M
=
64
block_N
=
64
block_N
=
64
num_stages
=
2
num_stages
=
2
threads
=
128
threads
=
128
scale
=
(
1.0
/
dim
)
**
0.5
*
1.44269504
scale
=
(
1.0
/
dim
)
**
0.5
*
1.44269504
shape
=
[
batch
,
heads
,
seq_len
,
dim
]
shape
=
[
batch
,
heads
,
seq_len
,
dim
]
seq_blocks
=
(
seq_len
+
block_M
-
1
)
//
block_M
seq_blocks
=
(
seq_len
+
block_M
-
1
)
//
block_M
...
@@ -30,15 +29,13 @@ def _tl_vs_sparse_flashattn(batch, heads, seq_len, dim, vertical_size, slash_siz
...
@@ -30,15 +29,13 @@ def _tl_vs_sparse_flashattn(batch, heads, seq_len, dim, vertical_size, slash_siz
offset_shape
=
count_shape
+
[
slash_size
]
offset_shape
=
count_shape
+
[
slash_size
]
index_shape
=
count_shape
+
[
vertical_size
]
index_shape
=
count_shape
+
[
vertical_size
]
vertical_size_round
,
slash_size_round
=
tilelang
.
next_power_of_2
(
vertical_size_round
,
slash_size_round
=
tilelang
.
next_power_of_2
(
vertical_size
),
tilelang
.
next_power_of_2
(
slash_size
)
vertical_size
),
tilelang
.
next_power_of_2
(
slash_size
)
dtype
=
"float16"
dtype
=
"float16"
accum_dtype
=
"float"
accum_dtype
=
"float"
int_dtype
=
"int32"
int_dtype
=
"int32"
def
kernel_func
(
block_M
,
block_N
,
num_stages
,
threads
):
def
kernel_func
(
block_M
,
block_N
,
num_stages
,
threads
):
@
T
.
macro
@
T
.
macro
def
Prefetch
(
def
Prefetch
(
K
:
T
.
Tensor
(
shape
,
dtype
),
K
:
T
.
Tensor
(
shape
,
dtype
),
...
@@ -53,32 +50,30 @@ def _tl_vs_sparse_flashattn(batch, heads, seq_len, dim, vertical_size, slash_siz
...
@@ -53,32 +50,30 @@ def _tl_vs_sparse_flashattn(batch, heads, seq_len, dim, vertical_size, slash_siz
):
):
with
T
.
attr
(
"default"
,
"async_scope"
,
1
):
with
T
.
attr
(
"default"
,
"async_scope"
,
1
):
for
i
,
j
in
T
.
Parallel
(
block_N
,
dim
):
for
i
,
j
in
T
.
Parallel
(
block_N
,
dim
):
K_shared
[
i
,
j
]
=
T
.
if_then_else
(
k
+
i
<
column_count
,
K_shared
[
i
,
j
]
=
T
.
if_then_else
(
k
+
i
<
column_count
,
K
[
bz
,
by
,
column_index
[
k
+
i
],
j
],
0
)
K
[
bz
,
by
,
column_index
[
k
+
i
],
j
],
0
)
with
T
.
attr
(
"default"
,
"async_scope"
,
1
):
with
T
.
attr
(
"default"
,
"async_scope"
,
1
):
for
i
,
j
in
T
.
Parallel
(
block_N
,
dim
):
for
i
,
j
in
T
.
Parallel
(
block_N
,
dim
):
V_shared
[
i
,
j
]
=
T
.
if_then_else
(
k
+
i
<
column_count
,
V_shared
[
i
,
j
]
=
T
.
if_then_else
(
k
+
i
<
column_count
,
V
[
bz
,
by
,
column_index
[
k
+
i
],
j
],
0
)
V
[
bz
,
by
,
column_index
[
k
+
i
],
j
],
0
)
T
.
ptx_commit_group
()
T
.
ptx_commit_group
()
@
T
.
macro
@
T
.
macro
def
Compute
(
def
Compute
(
acc_s
:
T
.
FragmentBuffer
([
block_M
,
block_N
],
accum_dtype
),
acc_s
:
T
.
FragmentBuffer
([
block_M
,
block_N
],
accum_dtype
),
acc_s_cast
:
T
.
FragmentBuffer
([
block_M
,
block_N
],
dtype
),
acc_s_cast
:
T
.
FragmentBuffer
([
block_M
,
block_N
],
dtype
),
acc_o
:
T
.
FragmentBuffer
([
block_M
,
dim
],
accum_dtype
),
acc_o
:
T
.
FragmentBuffer
([
block_M
,
dim
],
accum_dtype
),
scores_max
:
T
.
FragmentBuffer
([
block_M
],
accum_dtype
),
scores_max
:
T
.
FragmentBuffer
([
block_M
],
accum_dtype
),
scores_max_prev
:
T
.
FragmentBuffer
([
block_M
],
accum_dtype
),
scores_max_prev
:
T
.
FragmentBuffer
([
block_M
],
accum_dtype
),
k
:
T
.
int32
,
k
:
T
.
int32
,
column_count
:
T
.
int32
,
column_count
:
T
.
int32
,
Q_shared
:
T
.
SharedBuffer
([
block_M
,
dim
],
dtype
),
Q_shared
:
T
.
SharedBuffer
([
block_M
,
dim
],
dtype
),
K_shared
:
T
.
SharedBuffer
([
block_N
,
dim
],
dtype
),
K_shared
:
T
.
SharedBuffer
([
block_N
,
dim
],
dtype
),
V_shared
:
T
.
SharedBuffer
([
block_N
,
dim
],
dtype
),
V_shared
:
T
.
SharedBuffer
([
block_N
,
dim
],
dtype
),
scores_scale
:
T
.
FragmentBuffer
([
block_M
],
accum_dtype
),
scores_scale
:
T
.
FragmentBuffer
([
block_M
],
accum_dtype
),
scores_sum
:
T
.
FragmentBuffer
([
block_M
],
accum_dtype
),
scores_sum
:
T
.
FragmentBuffer
([
block_M
],
accum_dtype
),
logsum
:
T
.
FragmentBuffer
([
block_M
],
accum_dtype
),
logsum
:
T
.
FragmentBuffer
([
block_M
],
accum_dtype
),
count
:
T
.
int32
,
count
:
T
.
int32
,
):
):
T
.
ptx_wait_group
(
count
)
T
.
ptx_wait_group
(
count
)
for
i
,
j
in
T
.
Parallel
(
block_M
,
block_N
):
for
i
,
j
in
T
.
Parallel
(
block_M
,
block_N
):
...
@@ -108,17 +103,16 @@ def _tl_vs_sparse_flashattn(batch, heads, seq_len, dim, vertical_size, slash_siz
...
@@ -108,17 +103,16 @@ def _tl_vs_sparse_flashattn(batch, heads, seq_len, dim, vertical_size, slash_siz
@
T
.
prim_func
@
T
.
prim_func
def
vs_sparse_flashattn_ws
(
def
vs_sparse_flashattn_ws
(
Q
:
T
.
Tensor
(
shape
,
dtype
),
Q
:
T
.
Tensor
(
shape
,
dtype
),
K
:
T
.
Tensor
(
shape
,
dtype
),
K
:
T
.
Tensor
(
shape
,
dtype
),
V
:
T
.
Tensor
(
shape
,
dtype
),
V
:
T
.
Tensor
(
shape
,
dtype
),
Output
:
T
.
Tensor
(
shape
,
dtype
),
Output
:
T
.
Tensor
(
shape
,
dtype
),
BlockCount
:
T
.
Tensor
(
count_shape
,
int_dtype
),
BlockCount
:
T
.
Tensor
(
count_shape
,
int_dtype
),
BlockOffset
:
T
.
Tensor
(
offset_shape
,
int_dtype
),
BlockOffset
:
T
.
Tensor
(
offset_shape
,
int_dtype
),
ColumnCount
:
T
.
Tensor
(
count_shape
,
int_dtype
),
ColumnCount
:
T
.
Tensor
(
count_shape
,
int_dtype
),
ColumnIndex
:
T
.
Tensor
(
index_shape
,
int_dtype
),
ColumnIndex
:
T
.
Tensor
(
index_shape
,
int_dtype
),
):
):
with
T
.
Kernel
(
T
.
ceildiv
(
seq_len
,
block_M
),
heads
,
batch
,
threads
=
256
)
as
(
bc
,
by
,
bz
):
with
T
.
Kernel
(
T
.
ceildiv
(
seq_len
,
block_M
),
heads
,
batch
,
threads
=
256
)
as
(
bc
,
by
,
bz
):
bx
=
T
.
ceildiv
(
seq_len
,
block_M
)
-
1
-
bc
bx
=
T
.
ceildiv
(
seq_len
,
block_M
)
-
1
-
bc
Q_shared
=
T
.
alloc_shared
([
block_M
,
dim
],
dtype
)
Q_shared
=
T
.
alloc_shared
([
block_M
,
dim
],
dtype
)
K_shared
=
T
.
alloc_shared
([
2
,
block_N
,
dim
],
dtype
)
K_shared
=
T
.
alloc_shared
([
2
,
block_N
,
dim
],
dtype
)
...
@@ -143,9 +137,11 @@ def _tl_vs_sparse_flashattn(batch, heads, seq_len, dim, vertical_size, slash_siz
...
@@ -143,9 +137,11 @@ def _tl_vs_sparse_flashattn(batch, heads, seq_len, dim, vertical_size, slash_siz
T
.
create_list_of_mbarrier
([
128
]
*
9
)
T
.
create_list_of_mbarrier
([
128
]
*
9
)
T
.
annotate_layout
({
T
.
annotate_layout
(
O_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
O_shared
),
{
})
O_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
O_shared
),
}
)
block_count
[
0
]
=
BlockCount
[
bz
,
by
,
bx
]
block_count
[
0
]
=
BlockCount
[
bz
,
by
,
bx
]
column_count
[
0
]
=
ColumnCount
[
bz
,
by
,
bx
]
column_count
[
0
]
=
ColumnCount
[
bz
,
by
,
bx
]
...
@@ -162,15 +158,15 @@ def _tl_vs_sparse_flashattn(batch, heads, seq_len, dim, vertical_size, slash_siz
...
@@ -162,15 +158,15 @@ def _tl_vs_sparse_flashattn(batch, heads, seq_len, dim, vertical_size, slash_siz
if
tid
>=
128
:
if
tid
>=
128
:
T
.
annotate_producer_reg_dealloc
()
T
.
annotate_producer_reg_dealloc
()
T
.
copy
(
Q
[
bz
,
by
,
bx
*
block_M
:
(
bx
+
1
)
*
block_M
,
:],
Q_shared
)
T
.
copy
(
Q
[
bz
,
by
,
bx
*
block_M
:
(
bx
+
1
)
*
block_M
,
:],
Q_shared
)
T
.
mbarrier_arrive
(
mbarrier
=
8
)
T
.
mbarrier_arrive
(
mbarrier
=
8
)
for
bi
in
T
.
serial
(
block_count
[
0
]):
for
bi
in
T
.
serial
(
block_count
[
0
]):
k
=
block_offset
[
bi
]
k
=
block_offset
[
bi
]
T
.
mbarrier_wait_parity
(
mbarrier
=
bi
%
2
+
4
,
parity
=
(((
bi
&
3
)
>>
1
)
^
1
))
T
.
mbarrier_wait_parity
(
mbarrier
=
bi
%
2
+
4
,
parity
=
(((
bi
&
3
)
>>
1
)
^
1
))
T
.
copy
(
K
[
bz
,
by
,
k
:
k
+
block_N
,
:],
K_shared
[
bi
%
2
,
:,
:])
T
.
copy
(
K
[
bz
,
by
,
k
:
k
+
block_N
,
:],
K_shared
[
bi
%
2
,
:,
:])
T
.
mbarrier_arrive
(
mbarrier
=
bi
%
2
)
T
.
mbarrier_arrive
(
mbarrier
=
bi
%
2
)
T
.
mbarrier_wait_parity
(
mbarrier
=
bi
%
2
+
6
,
parity
=
(((
bi
&
3
)
>>
1
)
^
1
))
T
.
mbarrier_wait_parity
(
mbarrier
=
bi
%
2
+
6
,
parity
=
(((
bi
&
3
)
>>
1
)
^
1
))
T
.
copy
(
V
[
bz
,
by
,
k
:
k
+
block_N
,
:],
V_shared
[
bi
%
2
,
:,
:])
T
.
copy
(
V
[
bz
,
by
,
k
:
k
+
block_N
,
:],
V_shared
[
bi
%
2
,
:,
:])
T
.
mbarrier_arrive
(
mbarrier
=
bi
%
2
+
2
)
T
.
mbarrier_arrive
(
mbarrier
=
bi
%
2
+
2
)
else
:
else
:
T
.
annotate_consumer_reg_alloc
()
T
.
annotate_consumer_reg_alloc
()
...
@@ -181,16 +177,10 @@ def _tl_vs_sparse_flashattn(batch, heads, seq_len, dim, vertical_size, slash_siz
...
@@ -181,16 +177,10 @@ def _tl_vs_sparse_flashattn(batch, heads, seq_len, dim, vertical_size, slash_siz
for
bi
in
T
.
serial
(
block_count
[
0
]):
for
bi
in
T
.
serial
(
block_count
[
0
]):
k
=
block_offset
[
bi
]
k
=
block_offset
[
bi
]
for
i
,
j
in
T
.
Parallel
(
block_M
,
block_N
):
for
i
,
j
in
T
.
Parallel
(
block_M
,
block_N
):
acc_s
[
i
,
j
]
=
T
.
if_then_else
(
bx
*
block_M
+
i
>=
k
+
j
,
0
,
acc_s
[
i
,
j
]
=
T
.
if_then_else
(
bx
*
block_M
+
i
>=
k
+
j
,
0
,
-
T
.
infinity
(
acc_s
.
dtype
))
-
T
.
infinity
(
acc_s
.
dtype
))
T
.
mbarrier_wait_parity
(
mbarrier
=
bi
%
2
,
parity
=
((
bi
&
3
)
>>
1
))
T
.
mbarrier_wait_parity
(
mbarrier
=
bi
%
2
,
parity
=
((
bi
&
3
)
>>
1
))
T
.
gemm
(
T
.
gemm
(
Q_shared
,
K_shared
[
bi
%
2
,
:,
:],
acc_s
,
transpose_B
=
True
,
policy
=
T
.
GemmWarpPolicy
.
FullRow
)
Q_shared
,
K_shared
[
bi
%
2
,
:,
:],
acc_s
,
transpose_B
=
True
,
policy
=
T
.
GemmWarpPolicy
.
FullRow
)
T
.
mbarrier_arrive
(
mbarrier
=
bi
%
2
+
4
)
T
.
mbarrier_arrive
(
mbarrier
=
bi
%
2
+
4
)
T
.
copy
(
scores_max
,
scores_max_prev
)
T
.
copy
(
scores_max
,
scores_max_prev
)
...
@@ -200,20 +190,15 @@ def _tl_vs_sparse_flashattn(batch, heads, seq_len, dim, vertical_size, slash_siz
...
@@ -200,20 +190,15 @@ def _tl_vs_sparse_flashattn(batch, heads, seq_len, dim, vertical_size, slash_siz
scores_max
[
i
]
=
T
.
max
(
scores_max
[
i
],
scores_max_prev
[
i
])
scores_max
[
i
]
=
T
.
max
(
scores_max
[
i
],
scores_max_prev
[
i
])
for
i
in
T
.
Parallel
(
block_M
):
for
i
in
T
.
Parallel
(
block_M
):
scores_scale
[
i
]
=
T
.
exp2
(
scores_max_prev
[
i
]
*
scale
-
scores_scale
[
i
]
=
T
.
exp2
(
scores_max_prev
[
i
]
*
scale
-
scores_max
[
i
]
*
scale
)
scores_max
[
i
]
*
scale
)
for
i
,
j
in
T
.
Parallel
(
block_M
,
block_N
):
for
i
,
j
in
T
.
Parallel
(
block_M
,
block_N
):
acc_s
[
i
,
j
]
=
T
.
exp2
(
acc_s
[
i
,
j
]
*
scale
-
scores_max
[
i
]
*
scale
)
acc_s
[
i
,
j
]
=
T
.
exp2
(
acc_s
[
i
,
j
]
*
scale
-
scores_max
[
i
]
*
scale
)
for
i
,
j
in
T
.
Parallel
(
block_M
,
dim
):
for
i
,
j
in
T
.
Parallel
(
block_M
,
dim
):
acc_o
[
i
,
j
]
=
acc_o
[
i
,
j
]
*
scores_scale
[
i
]
acc_o
[
i
,
j
]
=
acc_o
[
i
,
j
]
*
scores_scale
[
i
]
T
.
copy
(
acc_s
,
acc_s_cast
)
T
.
copy
(
acc_s
,
acc_s_cast
)
T
.
mbarrier_wait_parity
(
mbarrier
=
bi
%
2
+
2
,
parity
=
(((
bi
&
3
)
>>
1
)))
T
.
mbarrier_wait_parity
(
mbarrier
=
bi
%
2
+
2
,
parity
=
((
bi
&
3
)
>>
1
))
T
.
gemm
(
T
.
gemm
(
acc_s_cast
,
V_shared
[
bi
%
2
,
:,
:],
acc_o
,
policy
=
T
.
GemmWarpPolicy
.
FullRow
)
acc_s_cast
,
V_shared
[
bi
%
2
,
:,
:],
acc_o
,
policy
=
T
.
GemmWarpPolicy
.
FullRow
)
T
.
mbarrier_arrive
(
mbarrier
=
bi
%
2
+
6
)
T
.
mbarrier_arrive
(
mbarrier
=
bi
%
2
+
6
)
...
@@ -223,38 +208,85 @@ def _tl_vs_sparse_flashattn(batch, heads, seq_len, dim, vertical_size, slash_siz
...
@@ -223,38 +208,85 @@ def _tl_vs_sparse_flashattn(batch, heads, seq_len, dim, vertical_size, slash_siz
logsum
[
i
]
=
logsum
[
i
]
*
scores_scale
[
i
]
+
scores_sum
[
i
]
logsum
[
i
]
=
logsum
[
i
]
*
scores_scale
[
i
]
+
scores_sum
[
i
]
if
column_count
[
0
]
!=
0
:
if
column_count
[
0
]
!=
0
:
Prefetch
(
K
,
V
,
K_shared_1
,
V_shared_1
,
column_index
,
column_count
[
0
],
0
,
bz
,
Prefetch
(
K
,
V
,
K_shared_1
,
V_shared_1
,
column_index
,
column_count
[
0
],
0
,
bz
,
by
)
by
)
for
bi
in
T
.
serial
(
T
.
ceildiv
(
column_count
[
0
],
block_N
)
-
1
):
for
bi
in
T
.
serial
(
T
.
ceildiv
(
column_count
[
0
],
block_N
)
-
1
):
k
=
bi
*
block_N
k
=
bi
*
block_N
if
bi
%
2
==
0
:
if
bi
%
2
==
0
:
Prefetch
(
K
,
V
,
K_shared_2
,
V_shared_2
,
column_index
,
Prefetch
(
K
,
V
,
K_shared_2
,
V_shared_2
,
column_index
,
column_count
[
0
],
k
+
block_N
,
bz
,
by
)
column_count
[
0
],
k
+
block_N
,
bz
,
by
)
Compute
(
Compute
(
acc_s
,
acc_s_cast
,
acc_o
,
scores_max
,
scores_max_prev
,
k
,
acc_s
,
column_count
[
0
],
Q_shared
,
K_shared_1
,
V_shared_1
,
acc_s_cast
,
scores_scale
,
scores_sum
,
logsum
,
1
)
acc_o
,
scores_max
,
scores_max_prev
,
k
,
column_count
[
0
],
Q_shared
,
K_shared_1
,
V_shared_1
,
scores_scale
,
scores_sum
,
logsum
,
1
,
)
else
:
else
:
Prefetch
(
K
,
V
,
K_shared_1
,
V_shared_1
,
column_index
,
Prefetch
(
K
,
V
,
K_shared_1
,
V_shared_1
,
column_index
,
column_count
[
0
],
k
+
block_N
,
bz
,
by
)
column_count
[
0
],
k
+
block_N
,
bz
,
by
)
Compute
(
Compute
(
acc_s
,
acc_s_cast
,
acc_o
,
scores_max
,
scores_max_prev
,
k
,
acc_s
,
column_count
[
0
],
Q_shared
,
K_shared_2
,
V_shared_2
,
acc_s_cast
,
scores_scale
,
scores_sum
,
logsum
,
1
)
acc_o
,
scores_max
,
scores_max_prev
,
k
,
column_count
[
0
],
Q_shared
,
K_shared_2
,
V_shared_2
,
scores_scale
,
scores_sum
,
logsum
,
1
,
)
if
T
.
ceildiv
(
column_count
[
0
],
block_N
)
%
2
==
0
:
if
T
.
ceildiv
(
column_count
[
0
],
block_N
)
%
2
==
0
:
Compute
(
acc_s
,
acc_s_cast
,
acc_o
,
scores_max
,
scores_max_prev
,
Compute
(
T
.
ceildiv
(
column_count
[
0
],
block_N
)
*
block_N
-
block_N
,
acc_s
,
column_count
[
0
],
Q_shared
,
K_shared_2
,
V_shared_2
,
scores_scale
,
acc_s_cast
,
scores_sum
,
logsum
,
0
)
acc_o
,
scores_max
,
scores_max_prev
,
T
.
ceildiv
(
column_count
[
0
],
block_N
)
*
block_N
-
block_N
,
column_count
[
0
],
Q_shared
,
K_shared_2
,
V_shared_2
,
scores_scale
,
scores_sum
,
logsum
,
0
,
)
else
:
else
:
Compute
(
acc_s
,
acc_s_cast
,
acc_o
,
scores_max
,
scores_max_prev
,
Compute
(
T
.
ceildiv
(
column_count
[
0
],
block_N
)
*
block_N
-
block_N
,
acc_s
,
column_count
[
0
],
Q_shared
,
K_shared_1
,
V_shared_1
,
scores_scale
,
acc_s_cast
,
scores_sum
,
logsum
,
0
)
acc_o
,
scores_max
,
scores_max_prev
,
T
.
ceildiv
(
column_count
[
0
],
block_N
)
*
block_N
-
block_N
,
column_count
[
0
],
Q_shared
,
K_shared_1
,
V_shared_1
,
scores_scale
,
scores_sum
,
logsum
,
0
,
)
for
i
,
j
in
T
.
Parallel
(
block_M
,
dim
):
for
i
,
j
in
T
.
Parallel
(
block_M
,
dim
):
acc_o
[
i
,
j
]
/=
logsum
[
i
]
acc_o
[
i
,
j
]
/=
logsum
[
i
]
T
.
copy
(
acc_o
,
O_shared
)
T
.
copy
(
acc_o
,
O_shared
)
T
.
copy
(
O_shared
,
Output
[
bz
,
by
,
bx
*
block_M
:
(
bx
+
1
)
*
block_M
,
:])
T
.
copy
(
O_shared
,
Output
[
bz
,
by
,
bx
*
block_M
:
(
bx
+
1
)
*
block_M
,
:])
return
vs_sparse_flashattn_ws
return
vs_sparse_flashattn_ws
...
@@ -470,11 +502,8 @@ def vertical_slash_sparse_attention(
...
@@ -470,11 +502,8 @@ def vertical_slash_sparse_attention(
import
os
import
os
current_dir
=
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
))
current_dir
=
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
))
sources
=
[
sources
=
[
os
.
path
.
join
(
current_dir
,
"ops"
,
"kernels.cpp"
),
os
.
path
.
join
(
current_dir
,
"ops"
,
"vertical_slash_index.cu"
)]
os
.
path
.
join
(
current_dir
,
'ops'
,
'kernels.cpp'
),
ops
=
load
(
name
=
"convert"
,
sources
=
sources
,
verbose
=
False
)
os
.
path
.
join
(
current_dir
,
'ops'
,
'vertical_slash_index.cu'
)
]
ops
=
load
(
name
=
'convert'
,
sources
=
sources
,
verbose
=
False
)
convert_vertical_slash_indexes
=
ops
.
convert_vertical_slash_indexes
convert_vertical_slash_indexes
=
ops
.
convert_vertical_slash_indexes
batch_size
,
num_heads
,
context_size
,
head_dim
=
query
.
shape
batch_size
,
num_heads
,
context_size
,
head_dim
=
query
.
shape
pad
=
(
block_size_M
-
context_size
)
&
(
block_size_M
-
1
)
pad
=
(
block_size_M
-
context_size
)
&
(
block_size_M
-
1
)
...
@@ -485,15 +514,13 @@ def vertical_slash_sparse_attention(
...
@@ -485,15 +514,13 @@ def vertical_slash_sparse_attention(
value
=
torch
.
nn
.
functional
.
pad
(
value
,
[
0
,
0
,
0
,
pad
,
0
,
0
,
0
,
0
])
value
=
torch
.
nn
.
functional
.
pad
(
value
,
[
0
,
0
,
0
,
pad
,
0
,
0
,
0
,
0
])
if
head_dim
not
in
[
16
,
32
,
64
,
128
,
256
,
512
]:
if
head_dim
not
in
[
16
,
32
,
64
,
128
,
256
,
512
]:
target_dim
=
2
**
math
.
ceil
(
math
.
log2
(
head_dim
))
-
head_dim
target_dim
=
2
**
math
.
ceil
(
math
.
log2
(
head_dim
))
-
head_dim
query
=
torch
.
nn
.
functional
.
pad
(
query
,
[
0
,
target_dim
,
0
,
0
,
0
,
0
,
0
,
0
])
query
=
torch
.
nn
.
functional
.
pad
(
query
,
[
0
,
target_dim
,
0
,
0
,
0
,
0
,
0
,
0
])
key
=
torch
.
nn
.
functional
.
pad
(
key
,
[
0
,
target_dim
,
0
,
0
,
0
,
0
,
0
,
0
])
key
=
torch
.
nn
.
functional
.
pad
(
key
,
[
0
,
target_dim
,
0
,
0
,
0
,
0
,
0
,
0
])
value
=
torch
.
nn
.
functional
.
pad
(
value
,
[
0
,
target_dim
,
0
,
0
,
0
,
0
,
0
,
0
])
value
=
torch
.
nn
.
functional
.
pad
(
value
,
[
0
,
target_dim
,
0
,
0
,
0
,
0
,
0
,
0
])
v_idx
=
v_idx
.
to
(
torch
.
int32
).
reshape
((
batch_size
,
num_heads
,
-
1
)).
sort
(
v_idx
=
v_idx
.
to
(
torch
.
int32
).
reshape
((
batch_size
,
num_heads
,
-
1
)).
sort
(
dim
=-
1
,
descending
=
False
)[
0
]
dim
=-
1
,
descending
=
False
)[
0
]
s_idx
=
s_idx
.
to
(
torch
.
int32
).
reshape
((
batch_size
,
num_heads
,
-
1
)).
sort
(
dim
=-
1
,
descending
=
True
)[
0
]
s_idx
=
s_idx
.
to
(
torch
.
int32
).
reshape
((
batch_size
,
num_heads
,
-
1
)).
sort
(
dim
=-
1
,
descending
=
True
)[
0
]
seqlens
=
torch
.
tensor
([
context_size
]
*
query
.
shape
[
0
],
dtype
=
torch
.
int32
,
device
=
query
.
device
)
seqlens
=
torch
.
tensor
([
context_size
]
*
query
.
shape
[
0
],
dtype
=
torch
.
int32
,
device
=
query
.
device
)
sm_scale
=
head_dim
**-
0.5
sm_scale
=
head_dim
**-
0.5
...
@@ -506,8 +533,7 @@ def vertical_slash_sparse_attention(
...
@@ -506,8 +533,7 @@ def vertical_slash_sparse_attention(
block_size_N
,
block_size_N
,
)
)
tl_kernel
=
_tl_vs_sparse_flashattn
(
batch_size
,
num_heads
,
context_size
,
head_dim
,
tl_kernel
=
_tl_vs_sparse_flashattn
(
batch_size
,
num_heads
,
context_size
,
head_dim
,
v_idx
.
shape
[
2
],
s_idx
.
shape
[
2
])
v_idx
.
shape
[
2
],
s_idx
.
shape
[
2
])
def
run
(
is_triton
:
bool
=
True
):
def
run
(
is_triton
:
bool
=
True
):
if
is_triton
:
if
is_triton
:
...
@@ -525,8 +551,7 @@ def vertical_slash_sparse_attention(
...
@@ -525,8 +551,7 @@ def vertical_slash_sparse_attention(
block_size_N
,
block_size_N
,
)
)
else
:
else
:
out
=
tl_kernel
(
query
,
key
,
value
,
block_count
,
block_offset
,
column_count
,
out
=
tl_kernel
(
query
,
key
,
value
,
block_count
,
block_offset
,
column_count
,
column_index
)
column_index
)
return
out
[...,
:
context_size
,
:
head_dim
]
return
out
[...,
:
context_size
,
:
head_dim
]
return
run
return
run
...
@@ -536,8 +561,7 @@ def sum_all_diagonal_matrix(mat: torch.tensor):
...
@@ -536,8 +561,7 @@ def sum_all_diagonal_matrix(mat: torch.tensor):
b
,
h
,
n
,
m
=
mat
.
shape
b
,
h
,
n
,
m
=
mat
.
shape
zero_mat
=
torch
.
zeros
((
b
,
h
,
n
,
n
)).
to
(
mat
.
device
)
# Zero matrix used for padding
zero_mat
=
torch
.
zeros
((
b
,
h
,
n
,
n
)).
to
(
mat
.
device
)
# Zero matrix used for padding
mat_padded
=
torch
.
cat
((
zero_mat
,
mat
,
zero_mat
),
-
1
)
# pads the matrix on left and right
mat_padded
=
torch
.
cat
((
zero_mat
,
mat
,
zero_mat
),
-
1
)
# pads the matrix on left and right
mat_strided
=
mat_padded
.
as_strided
(
mat_strided
=
mat_padded
.
as_strided
((
1
,
1
,
n
,
n
+
m
),
(
1
,
n
*
(
2
*
n
+
m
),
2
*
n
+
m
+
1
,
1
))
# Change the strides
(
1
,
1
,
n
,
n
+
m
),
(
1
,
n
*
(
2
*
n
+
m
),
2
*
n
+
m
+
1
,
1
))
# Change the strides
sum_diags
=
torch
.
sum
(
mat_strided
,
2
)
# Sums the resulting matrix's columns
sum_diags
=
torch
.
sum
(
mat_strided
,
2
)
# Sums the resulting matrix's columns
return
sum_diags
[:,
:,
1
:]
return
sum_diags
[:,
:,
1
:]
...
@@ -559,24 +583,23 @@ def main(argv=None):
...
@@ -559,24 +583,23 @@ def main(argv=None):
vertical_size
,
slash_size
=
args
.
vertical_size
,
args
.
slash_size
vertical_size
,
slash_size
=
args
.
vertical_size
,
args
.
slash_size
torch
.
manual_seed
(
0
)
torch
.
manual_seed
(
0
)
q
=
torch
.
randn
(
BATCH
,
N_HEADS
,
SEQ_LEN
,
D_HEAD
,
device
=
'
cuda
'
,
dtype
=
torch
.
float16
)
q
=
torch
.
randn
(
BATCH
,
N_HEADS
,
SEQ_LEN
,
D_HEAD
,
device
=
"
cuda
"
,
dtype
=
torch
.
float16
)
k
=
torch
.
randn
(
BATCH
,
N_HEADS
,
SEQ_LEN
,
D_HEAD
,
device
=
'
cuda
'
,
dtype
=
torch
.
float16
)
k
=
torch
.
randn
(
BATCH
,
N_HEADS
,
SEQ_LEN
,
D_HEAD
,
device
=
"
cuda
"
,
dtype
=
torch
.
float16
)
v
=
torch
.
randn
(
BATCH
,
N_HEADS
,
SEQ_LEN
,
D_HEAD
,
device
=
'
cuda
'
,
dtype
=
torch
.
float16
)
v
=
torch
.
randn
(
BATCH
,
N_HEADS
,
SEQ_LEN
,
D_HEAD
,
device
=
"
cuda
"
,
dtype
=
torch
.
float16
)
q_len
=
SEQ_LEN
q_len
=
SEQ_LEN
vertical_size
,
slash_size
=
min
(
q_len
,
vertical_size
),
min
(
q_len
,
slash_size
)
vertical_size
,
slash_size
=
min
(
q_len
,
vertical_size
),
min
(
q_len
,
slash_size
)
last_q
=
64
last_q
=
64
qk
=
torch
.
einsum
(
'
bhmk, bhnk -> bhmn
'
,
q
[:,
:,
-
last_q
:,
:],
k
)
qk
=
torch
.
einsum
(
"
bhmk, bhnk -> bhmn
"
,
q
[:,
:,
-
last_q
:,
:],
k
)
arange
=
torch
.
arange
(
last_q
,
device
=
"cuda"
)
arange
=
torch
.
arange
(
last_q
,
device
=
"cuda"
)
qk
[:,
:,
:,
-
last_q
:]
=
torch
.
where
(
arange
[
None
,
None
,
:,
None
]
>=
arange
[
None
,
None
,
None
,
:],
qk
[:,
:,
:,
-
last_q
:]
=
torch
.
where
(
arange
[
None
,
None
,
:,
None
]
>=
arange
[
None
,
None
,
None
,
:],
qk
[:,
:,
:,
-
last_q
:],
-
torch
.
inf
)
qk
[:,
:,
:,
-
last_q
:],
-
torch
.
inf
)
qk
=
torch
.
nn
.
functional
.
softmax
(
qk
,
dim
=-
1
,
dtype
=
torch
.
float32
)
qk
=
torch
.
nn
.
functional
.
softmax
(
qk
,
dim
=-
1
,
dtype
=
torch
.
float32
)
vertical
=
qk
.
sum
(
-
2
,
keepdim
=
True
)
vertical
=
qk
.
sum
(
-
2
,
keepdim
=
True
)
vertical
[...,
:
30
]
=
torch
.
inf
vertical
[...,
:
30
]
=
torch
.
inf
vertical_topk
=
torch
.
topk
(
vertical
,
vertical_size
,
-
1
).
indices
vertical_topk
=
torch
.
topk
(
vertical
,
vertical_size
,
-
1
).
indices
slash
=
sum_all_diagonal_matrix
(
qk
)[...,
:
-
last_q
+
1
]
slash
=
sum_all_diagonal_matrix
(
qk
)[...,
:
-
last_q
+
1
]
slash
[...,
-
30
:]
=
torch
.
inf
slash
[...,
-
30
:]
=
torch
.
inf
slash
=
(
q_len
-
1
)
-
torch
.
topk
(
slash
,
slash_size
,
-
1
).
indices
slash
=
(
q_len
-
1
)
-
torch
.
topk
(
slash
,
slash_size
,
-
1
).
indices
...
...
examples/norm/rms_norm.py
View file @
29051439
...
@@ -45,7 +45,7 @@ def rms_norm(M, N, blk_m):
...
@@ -45,7 +45,7 @@ def rms_norm(M, N, blk_m):
A_local
=
T
.
alloc_fragment
((
blk_m
,
N
),
dtype
)
A_local
=
T
.
alloc_fragment
((
blk_m
,
N
),
dtype
)
A_powsum
=
T
.
alloc_fragment
((
blk_m
,),
dtype
)
A_powsum
=
T
.
alloc_fragment
((
blk_m
,),
dtype
)
T
.
copy
(
A
[
bx
*
blk_m
:
(
bx
+
1
)
*
blk_m
,
:],
A_shared
)
T
.
copy
(
A
[
bx
*
blk_m
:
(
bx
+
1
)
*
blk_m
,
:],
A_shared
)
T
.
copy
(
A_shared
,
A_local
)
T
.
copy
(
A_shared
,
A_local
)
for
i
,
j
in
T
.
Parallel
(
blk_m
,
N
):
for
i
,
j
in
T
.
Parallel
(
blk_m
,
N
):
A_pow_local
[
i
,
j
]
=
A_local
[
i
,
j
]
*
A_local
[
i
,
j
]
A_pow_local
[
i
,
j
]
=
A_local
[
i
,
j
]
*
A_local
[
i
,
j
]
...
@@ -54,7 +54,7 @@ def rms_norm(M, N, blk_m):
...
@@ -54,7 +54,7 @@ def rms_norm(M, N, blk_m):
A_powsum
[
i
]
=
T
.
rsqrt
(
A_powsum
[
i
]
/
N
+
1e-12
)
A_powsum
[
i
]
=
T
.
rsqrt
(
A_powsum
[
i
]
/
N
+
1e-12
)
for
i
,
j
in
T
.
Parallel
(
blk_m
,
N
):
for
i
,
j
in
T
.
Parallel
(
blk_m
,
N
):
A_local
[
i
,
j
]
*=
A_powsum
[
i
]
A_local
[
i
,
j
]
*=
A_powsum
[
i
]
T
.
copy
(
A_local
,
B
[
bx
*
blk_m
:
(
bx
+
1
)
*
blk_m
,
:])
T
.
copy
(
A_local
,
B
[
bx
*
blk_m
:
(
bx
+
1
)
*
blk_m
,
:])
return
main
return
main
...
...
examples/norm/test_rms_norm.py
View file @
29051439
...
@@ -45,7 +45,7 @@ def rms_norm(M, N, blk_m):
...
@@ -45,7 +45,7 @@ def rms_norm(M, N, blk_m):
A_local
=
T
.
alloc_fragment
((
blk_m
,
N
),
dtype
)
A_local
=
T
.
alloc_fragment
((
blk_m
,
N
),
dtype
)
A_powsum
=
T
.
alloc_fragment
((
blk_m
,),
dtype
)
A_powsum
=
T
.
alloc_fragment
((
blk_m
,),
dtype
)
T
.
copy
(
A
[
bx
*
blk_m
:
(
bx
+
1
)
*
blk_m
,
:],
A_shared
)
T
.
copy
(
A
[
bx
*
blk_m
:
(
bx
+
1
)
*
blk_m
,
:],
A_shared
)
T
.
copy
(
A_shared
,
A_local
)
T
.
copy
(
A_shared
,
A_local
)
for
i
,
j
in
T
.
Parallel
(
blk_m
,
N
):
for
i
,
j
in
T
.
Parallel
(
blk_m
,
N
):
A_pow_local
[
i
,
j
]
=
A_local
[
i
,
j
]
*
A_local
[
i
,
j
]
A_pow_local
[
i
,
j
]
=
A_local
[
i
,
j
]
*
A_local
[
i
,
j
]
...
@@ -54,7 +54,7 @@ def rms_norm(M, N, blk_m):
...
@@ -54,7 +54,7 @@ def rms_norm(M, N, blk_m):
A_powsum
[
i
]
=
T
.
rsqrt
(
A_powsum
[
i
]
/
N
+
1e-12
)
A_powsum
[
i
]
=
T
.
rsqrt
(
A_powsum
[
i
]
/
N
+
1e-12
)
for
i
,
j
in
T
.
Parallel
(
blk_m
,
N
):
for
i
,
j
in
T
.
Parallel
(
blk_m
,
N
):
A_local
[
i
,
j
]
*=
A_powsum
[
i
]
A_local
[
i
,
j
]
*=
A_powsum
[
i
]
T
.
copy
(
A_local
,
B
[
bx
*
blk_m
:
(
bx
+
1
)
*
blk_m
,
:])
T
.
copy
(
A_local
,
B
[
bx
*
blk_m
:
(
bx
+
1
)
*
blk_m
,
:])
return
main
return
main
...
...
examples/online_softmax/online_softmax.py
View file @
29051439
...
@@ -20,8 +20,8 @@ def softmax_kernel(
...
@@ -20,8 +20,8 @@ def softmax_kernel(
@
T
.
prim_func
@
T
.
prim_func
def
main
(
def
main
(
X
:
T
.
Tensor
([
M
,
N
],
dtype
),
X
:
T
.
Tensor
([
M
,
N
],
dtype
),
Y
:
T
.
Tensor
([
M
,
N
],
dtype
),
Y
:
T
.
Tensor
([
M
,
N
],
dtype
),
):
):
with
T
.
Kernel
(
M
,
threads
=
128
)
as
(
i_m
):
with
T
.
Kernel
(
M
,
threads
=
128
)
as
(
i_m
):
x
=
T
.
alloc_fragment
([
BN
],
dtype
)
x
=
T
.
alloc_fragment
([
BN
],
dtype
)
...
@@ -33,7 +33,7 @@ def softmax_kernel(
...
@@ -33,7 +33,7 @@ def softmax_kernel(
T
.
fill
(
lse
,
-
T
.
infinity
(
accum_dtype
))
T
.
fill
(
lse
,
-
T
.
infinity
(
accum_dtype
))
for
i_n
in
T
.
Pipelined
(
0
,
NN
):
for
i_n
in
T
.
Pipelined
(
0
,
NN
):
T
.
copy
(
X
[
i_m
,
i_n
*
BN
:
(
i_n
+
1
)
*
BN
],
x
)
T
.
copy
(
X
[
i_m
,
i_n
*
BN
:
(
i_n
+
1
)
*
BN
],
x
)
T
.
reduce_max
(
x
,
max_x
,
dim
=
0
,
clear
=
True
)
T
.
reduce_max
(
x
,
max_x
,
dim
=
0
,
clear
=
True
)
...
@@ -45,12 +45,12 @@ def softmax_kernel(
...
@@ -45,12 +45,12 @@ def softmax_kernel(
lse
[
0
]
=
max_x
[
0
]
*
scale
+
T
.
log2
(
T
.
exp2
(
lse
[
0
]
-
max_x
[
0
]
*
scale
)
+
sum_exp_x
[
0
])
lse
[
0
]
=
max_x
[
0
]
*
scale
+
T
.
log2
(
T
.
exp2
(
lse
[
0
]
-
max_x
[
0
]
*
scale
)
+
sum_exp_x
[
0
])
for
i_n
in
T
.
Pipelined
(
0
,
NN
):
for
i_n
in
T
.
Pipelined
(
0
,
NN
):
T
.
copy
(
X
[
i_m
,
i_n
*
BN
:
(
i_n
+
1
)
*
BN
],
x
)
T
.
copy
(
X
[
i_m
,
i_n
*
BN
:
(
i_n
+
1
)
*
BN
],
x
)
for
j
in
T
.
Parallel
(
BN
):
for
j
in
T
.
Parallel
(
BN
):
y
[
j
]
=
T
.
exp2
(
x
[
j
]
*
scale
-
lse
[
0
])
y
[
j
]
=
T
.
exp2
(
x
[
j
]
*
scale
-
lse
[
0
])
T
.
copy
(
y
,
Y
[
i_m
,
i_n
*
BN
:
(
i_n
+
1
)
*
BN
])
T
.
copy
(
y
,
Y
[
i_m
,
i_n
*
BN
:
(
i_n
+
1
)
*
BN
])
return
main
return
main
...
@@ -69,4 +69,4 @@ t1 = do_bench(lambda: X.softmax(dim=1), warmup=25, rep=100)
...
@@ -69,4 +69,4 @@ t1 = do_bench(lambda: X.softmax(dim=1), warmup=25, rep=100)
t2
=
do_bench
(
lambda
:
kernel
(
X
),
warmup
=
25
,
rep
=
100
)
t2
=
do_bench
(
lambda
:
kernel
(
X
),
warmup
=
25
,
rep
=
100
)
print
(
f
"torch latency:
{
t1
:.
3
f
}
ms"
)
print
(
f
"torch latency:
{
t1
:.
3
f
}
ms"
)
print
(
f
"TileLang latency:
{
t2
:.
3
f
}
ms"
)
print
(
f
"TileLang latency:
{
t2
:.
3
f
}
ms"
)
print
(
f
"Speedup:
{
t1
/
t2
:.
3
f
}
x"
)
print
(
f
"Speedup:
{
t1
/
t2
:.
3
f
}
x"
)
examples/plot_layout/fragment_mfma_load_a.py
View file @
29051439
...
@@ -11,10 +11,9 @@ from tilelang.intrinsics.mfma_layout import (
...
@@ -11,10 +11,9 @@ from tilelang.intrinsics.mfma_layout import (
)
)
def
make_mfma_load_base_layout
(
dtype
:
str
=
"float16"
,
def
make_mfma_load_base_layout
(
matrix
:
Literal
[
"A"
,
"B"
]
=
"A"
,
dtype
:
str
=
"float16"
,
matrix
:
Literal
[
"A"
,
"B"
]
=
"A"
,
k_dim
:
int
=
16
,
transposed
:
bool
=
False
k_dim
:
int
=
16
,
)
->
T
.
Fragment
:
transposed
:
bool
=
False
)
->
T
.
Fragment
:
"""
"""
Create a layout function for storing MFMA results into a fragment buffer.
Create a layout function for storing MFMA results into a fragment buffer.
This layout is used in conjunction with `inverse_mfma_store_layout` to
This layout is used in conjunction with `inverse_mfma_store_layout` to
...
@@ -72,12 +71,10 @@ def make_mfma_load_base_layout(dtype: str = "float16",
...
@@ -72,12 +71,10 @@ def make_mfma_load_base_layout(dtype: str = "float16",
# so the b matrix expected a transposed basic layout
# so the b matrix expected a transposed basic layout
transform_func
:
Callable
=
None
transform_func
:
Callable
=
None
if
matrix
==
"A"
:
if
matrix
==
"A"
:
transform_func
=
transform_func_sr_a
if
is_sr_axis_order
else
lambda
i
,
j
:
transform_func_sr_a
(
transform_func
=
transform_func_sr_a
if
is_sr_axis_order
else
lambda
i
,
j
:
transform_func_sr_a
(
j
,
i
)
j
,
i
)
micro_size_s
,
micro_size_r
=
micro_size_x
,
micro_size_k
micro_size_s
,
micro_size_r
=
micro_size_x
,
micro_size_k
elif
matrix
==
"B"
:
elif
matrix
==
"B"
:
transform_func
=
transform_func_sr_b
if
is_sr_axis_order
else
lambda
i
,
j
:
transform_func_sr_b
(
transform_func
=
transform_func_sr_b
if
is_sr_axis_order
else
lambda
i
,
j
:
transform_func_sr_b
(
j
,
i
)
j
,
i
)
micro_size_s
,
micro_size_r
=
micro_size_k
,
micro_size_y
micro_size_s
,
micro_size_r
=
micro_size_k
,
micro_size_y
else
:
else
:
raise
ValueError
(
f
"Unsupported matrix
{
matrix
}
"
)
raise
ValueError
(
f
"Unsupported matrix
{
matrix
}
"
)
...
@@ -120,14 +117,11 @@ print(base_layout)
...
@@ -120,14 +117,11 @@ print(base_layout)
plot_layout
(
base_layout
,
name
=
"base_layout"
)
plot_layout
(
base_layout
,
name
=
"base_layout"
)
# warp layout 32x32
# warp layout 32x32
warp_layout
=
base_layout
.
repeat
([
warp_rows
,
warp_cols
],
warp_layout
=
base_layout
.
repeat
([
warp_rows
,
warp_cols
],
repeat_on_thread
=
False
,
lower_dim_first
=
False
)
repeat_on_thread
=
False
,
lower_dim_first
=
False
)
print
(
warp_layout
)
print
(
warp_layout
)
plot_layout
(
warp_layout
,
name
=
"warp_layout"
)
plot_layout
(
warp_layout
,
name
=
"warp_layout"
)
# block layout 64x32
# block layout 64x32
block_layout
=
warp_layout
.
repeat
([
block_rows
,
1
],
repeat_on_thread
=
True
,
block_layout
=
warp_layout
.
repeat
([
block_rows
,
1
],
repeat_on_thread
=
True
,
lower_dim_first
=
True
).
replicate
(
block_cols
)
lower_dim_first
=
True
).
replicate
(
block_cols
)
print
(
block_layout
)
print
(
block_layout
)
plot_layout
(
block_layout
,
name
=
"block_layout"
)
plot_layout
(
block_layout
,
name
=
"block_layout"
)
examples/plot_layout/fragment_mma_load_a.py
View file @
29051439
...
@@ -5,9 +5,7 @@ from tvm.tir import IndexMap
...
@@ -5,9 +5,7 @@ from tvm.tir import IndexMap
from
tilelang.intrinsics.utils
import
get_mma_micro_size
from
tilelang.intrinsics.utils
import
get_mma_micro_size
def
make_mma_load_base_layout
(
dtype
:
str
=
"float16"
,
def
make_mma_load_base_layout
(
dtype
:
str
=
"float16"
,
matrix
:
Literal
[
"A"
,
"B"
]
=
"A"
,
transposed
:
bool
=
False
)
->
T
.
Fragment
:
matrix
:
Literal
[
"A"
,
"B"
]
=
"A"
,
transposed
:
bool
=
False
)
->
T
.
Fragment
:
"""
"""
Create a layout function for storing MMA results into a fragment buffer.
Create a layout function for storing MMA results into a fragment buffer.
This layout is used in conjunction with `inverse_mma_store_layout` to
This layout is used in conjunction with `inverse_mma_store_layout` to
...
@@ -36,6 +34,7 @@ def make_mma_load_base_layout(dtype: str = "float16",
...
@@ -36,6 +34,7 @@ def make_mma_load_base_layout(dtype: str = "float16",
shared_16x16_to_mma_32x8_layout_sr_b
,
shared_16x16_to_mma_32x8_layout_sr_b
,
shared_16x32_to_mma_32x16_layout_sr_b
,
shared_16x32_to_mma_32x16_layout_sr_b
,
)
)
assert
matrix
in
[
"A"
,
"B"
],
"matrix should be either A or B"
assert
matrix
in
[
"A"
,
"B"
],
"matrix should be either A or B"
dtype_bits
=
DataType
(
dtype
).
bits
dtype_bits
=
DataType
(
dtype
).
bits
# s represents spatial axis
# s represents spatial axis
...
@@ -67,12 +66,10 @@ def make_mma_load_base_layout(dtype: str = "float16",
...
@@ -67,12 +66,10 @@ def make_mma_load_base_layout(dtype: str = "float16",
# so the b matrix expected a transposed basic layout
# so the b matrix expected a transposed basic layout
transform_func
:
Callable
=
None
transform_func
:
Callable
=
None
if
matrix
==
"A"
:
if
matrix
==
"A"
:
transform_func
=
transform_func_sr_a
if
is_sr_axis_order
else
lambda
i
,
j
:
transform_func_sr_a
(
transform_func
=
transform_func_sr_a
if
is_sr_axis_order
else
lambda
i
,
j
:
transform_func_sr_a
(
j
,
i
)
j
,
i
)
micro_size_s
,
micro_size_r
=
micro_size_x
,
micro_size_k
micro_size_s
,
micro_size_r
=
micro_size_x
,
micro_size_k
elif
matrix
==
"B"
:
elif
matrix
==
"B"
:
transform_func
=
transform_func_sr_b
if
is_sr_axis_order
else
lambda
i
,
j
:
transform_func_sr_b
(
transform_func
=
transform_func_sr_b
if
is_sr_axis_order
else
lambda
i
,
j
:
transform_func_sr_b
(
j
,
i
)
j
,
i
)
micro_size_s
,
micro_size_r
=
micro_size_k
,
micro_size_y
micro_size_s
,
micro_size_r
=
micro_size_k
,
micro_size_y
else
:
else
:
raise
ValueError
(
f
"Unsupported matrix
{
matrix
}
"
)
raise
ValueError
(
f
"Unsupported matrix
{
matrix
}
"
)
...
...
examples/quickstart.py
View file @
29051439
...
@@ -7,12 +7,11 @@ import tilelang.language as T
...
@@ -7,12 +7,11 @@ import tilelang.language as T
# if not specified, it will be inferred from the input tensors during compile time
# if not specified, it will be inferred from the input tensors during compile time
@
tilelang
.
jit
@
tilelang
.
jit
def
matmul
(
M
,
N
,
K
,
block_M
,
block_N
,
block_K
,
dtype
=
"float16"
,
accum_dtype
=
"float"
):
def
matmul
(
M
,
N
,
K
,
block_M
,
block_N
,
block_K
,
dtype
=
"float16"
,
accum_dtype
=
"float"
):
@
T
.
prim_func
@
T
.
prim_func
def
matmul_relu_kernel
(
def
matmul_relu_kernel
(
A
:
T
.
Tensor
((
M
,
K
),
dtype
),
A
:
T
.
Tensor
((
M
,
K
),
dtype
),
B
:
T
.
Tensor
((
K
,
N
),
dtype
),
B
:
T
.
Tensor
((
K
,
N
),
dtype
),
C
:
T
.
Tensor
((
M
,
N
),
dtype
),
C
:
T
.
Tensor
((
M
,
N
),
dtype
),
):
):
# Initialize Kernel Context
# Initialize Kernel Context
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
128
)
as
(
bx
,
by
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
128
)
as
(
bx
,
by
):
...
...
examples/seer_attention/block_sparse_attn_tilelang.py
View file @
29051439
...
@@ -10,10 +10,7 @@ def get_sparse_attn_mask_from_topk(x, topk, use_dense_for_last_block=False):
...
@@ -10,10 +10,7 @@ def get_sparse_attn_mask_from_topk(x, topk, use_dense_for_last_block=False):
bsz
,
num_head
,
downsample_len
,
_
=
x
.
shape
bsz
,
num_head
,
downsample_len
,
_
=
x
.
shape
# N_CTX = downsample_len * BLOCK
# N_CTX = downsample_len * BLOCK
sparse_index
=
torch
.
topk
(
x
,
topk
,
dim
=-
1
).
indices
sparse_index
=
torch
.
topk
(
x
,
topk
,
dim
=-
1
).
indices
dense_mask
=
torch
.
full
([
bsz
,
num_head
,
downsample_len
,
downsample_len
],
dense_mask
=
torch
.
full
([
bsz
,
num_head
,
downsample_len
,
downsample_len
],
False
,
dtype
=
torch
.
bool
,
device
=
x
.
device
)
False
,
dtype
=
torch
.
bool
,
device
=
x
.
device
)
dense_mask
.
scatter_
(
-
1
,
sparse_index
,
True
)
dense_mask
.
scatter_
(
-
1
,
sparse_index
,
True
)
if
use_dense_for_last_block
:
if
use_dense_for_last_block
:
dense_mask
[:,
:,
-
2
:,
:]
=
True
dense_mask
[:,
:,
-
2
:,
:]
=
True
...
@@ -30,15 +27,17 @@ def get_sparse_attn_mask_from_threshold(x, threshold, use_dense_for_last_block=F
...
@@ -30,15 +27,17 @@ def get_sparse_attn_mask_from_threshold(x, threshold, use_dense_for_last_block=F
@
tilelang
.
jit
(
@
tilelang
.
jit
(
out_idx
=
[
4
],
pass_configs
=
{
out_idx
=
[
4
],
pass_configs
=
{
tilelang
.
PassConfigKey
.
TL_ENABLE_FAST_MATH
:
True
,
tilelang
.
PassConfigKey
.
TL_ENABLE_FAST_MATH
:
True
,
})
},
)
def
blocksparse_flashattn
(
batch
,
heads
,
seq_q
,
seq_kv
,
dim
,
downsample_len
,
is_causal
):
def
blocksparse_flashattn
(
batch
,
heads
,
seq_q
,
seq_kv
,
dim
,
downsample_len
,
is_causal
):
block_M
=
64
block_M
=
64
block_N
=
64
block_N
=
64
num_stages
=
0
num_stages
=
0
threads
=
128
threads
=
128
scale
=
(
1.0
/
dim
)
**
0.5
*
1.44269504
# log2(e)
scale
=
(
1.0
/
dim
)
**
0.5
*
1.44269504
# log2(e)
q_shape
=
[
batch
,
heads
,
seq_q
,
dim
]
q_shape
=
[
batch
,
heads
,
seq_q
,
dim
]
kv_shape
=
[
batch
,
heads
,
seq_kv
,
dim
]
kv_shape
=
[
batch
,
heads
,
seq_kv
,
dim
]
block_mask_shape
=
[
batch
,
heads
,
downsample_len
,
downsample_len
]
block_mask_shape
=
[
batch
,
heads
,
downsample_len
,
downsample_len
]
...
@@ -48,16 +47,15 @@ def blocksparse_flashattn(batch, heads, seq_q, seq_kv, dim, downsample_len, is_c
...
@@ -48,16 +47,15 @@ def blocksparse_flashattn(batch, heads, seq_q, seq_kv, dim, downsample_len, is_c
block_mask_dtype
=
"int8"
block_mask_dtype
=
"int8"
def
kernel_func
(
block_M
,
block_N
,
num_stages
,
threads
):
def
kernel_func
(
block_M
,
block_N
,
num_stages
,
threads
):
@
T
.
macro
@
T
.
macro
def
Softmax
(
def
Softmax
(
acc_s
:
T
.
FragmentBuffer
([
block_M
,
block_N
],
accum_dtype
),
acc_s
:
T
.
FragmentBuffer
([
block_M
,
block_N
],
accum_dtype
),
acc_s_cast
:
T
.
FragmentBuffer
([
block_M
,
block_N
],
dtype
),
acc_s_cast
:
T
.
FragmentBuffer
([
block_M
,
block_N
],
dtype
),
scores_max
:
T
.
FragmentBuffer
([
block_M
],
accum_dtype
),
scores_max
:
T
.
FragmentBuffer
([
block_M
],
accum_dtype
),
scores_max_prev
:
T
.
FragmentBuffer
([
block_M
],
accum_dtype
),
scores_max_prev
:
T
.
FragmentBuffer
([
block_M
],
accum_dtype
),
scores_scale
:
T
.
FragmentBuffer
([
block_M
],
accum_dtype
),
scores_scale
:
T
.
FragmentBuffer
([
block_M
],
accum_dtype
),
scores_sum
:
T
.
FragmentBuffer
([
block_M
],
accum_dtype
),
scores_sum
:
T
.
FragmentBuffer
([
block_M
],
accum_dtype
),
logsum
:
T
.
FragmentBuffer
([
block_M
],
accum_dtype
),
logsum
:
T
.
FragmentBuffer
([
block_M
],
accum_dtype
),
):
):
T
.
copy
(
scores_max
,
scores_max_prev
)
T
.
copy
(
scores_max
,
scores_max_prev
)
T
.
fill
(
scores_max
,
-
T
.
infinity
(
accum_dtype
))
T
.
fill
(
scores_max
,
-
T
.
infinity
(
accum_dtype
))
...
@@ -83,19 +81,19 @@ def blocksparse_flashattn(batch, heads, seq_q, seq_kv, dim, downsample_len, is_c
...
@@ -83,19 +81,19 @@ def blocksparse_flashattn(batch, heads, seq_q, seq_kv, dim, downsample_len, is_c
@
T
.
macro
@
T
.
macro
def
Rescale
(
def
Rescale
(
acc_o
:
T
.
FragmentBuffer
([
block_M
,
dim
],
accum_dtype
),
acc_o
:
T
.
FragmentBuffer
([
block_M
,
dim
],
accum_dtype
),
scores_scale
:
T
.
FragmentBuffer
([
block_M
],
accum_dtype
),
scores_scale
:
T
.
FragmentBuffer
([
block_M
],
accum_dtype
),
):
):
for
i
,
j
in
T
.
Parallel
(
block_M
,
dim
):
for
i
,
j
in
T
.
Parallel
(
block_M
,
dim
):
acc_o
[
i
,
j
]
*=
scores_scale
[
i
]
acc_o
[
i
,
j
]
*=
scores_scale
[
i
]
@
T
.
prim_func
@
T
.
prim_func
def
main
(
def
main
(
Q
:
T
.
Tensor
(
q_shape
,
dtype
),
Q
:
T
.
Tensor
(
q_shape
,
dtype
),
K
:
T
.
Tensor
(
kv_shape
,
dtype
),
K
:
T
.
Tensor
(
kv_shape
,
dtype
),
V
:
T
.
Tensor
(
kv_shape
,
dtype
),
V
:
T
.
Tensor
(
kv_shape
,
dtype
),
BlockSparseMask
:
T
.
Tensor
(
block_mask_shape
,
block_mask_dtype
),
BlockSparseMask
:
T
.
Tensor
(
block_mask_shape
,
block_mask_dtype
),
Output
:
T
.
Tensor
(
q_shape
,
dtype
),
Output
:
T
.
Tensor
(
q_shape
,
dtype
),
):
):
with
T
.
Kernel
(
T
.
ceildiv
(
seq_q
,
block_M
),
heads
,
batch
,
threads
=
threads
)
as
(
bx
,
by
,
bz
):
with
T
.
Kernel
(
T
.
ceildiv
(
seq_q
,
block_M
),
heads
,
batch
,
threads
=
threads
)
as
(
bx
,
by
,
bz
):
Q_shared
=
T
.
alloc_shared
([
block_M
,
dim
],
dtype
)
Q_shared
=
T
.
alloc_shared
([
block_M
,
dim
],
dtype
)
...
@@ -112,7 +110,7 @@ def blocksparse_flashattn(batch, heads, seq_q, seq_kv, dim, downsample_len, is_c
...
@@ -112,7 +110,7 @@ def blocksparse_flashattn(batch, heads, seq_q, seq_kv, dim, downsample_len, is_c
logsum
=
T
.
alloc_fragment
([
block_M
],
accum_dtype
)
logsum
=
T
.
alloc_fragment
([
block_M
],
accum_dtype
)
block_mask
=
T
.
alloc_local
([
downsample_len
],
block_mask_dtype
)
block_mask
=
T
.
alloc_local
([
downsample_len
],
block_mask_dtype
)
T
.
copy
(
Q
[
bz
,
by
,
bx
*
block_M
:
(
bx
+
1
)
*
block_M
,
:],
Q_shared
)
T
.
copy
(
Q
[
bz
,
by
,
bx
*
block_M
:
(
bx
+
1
)
*
block_M
,
:],
Q_shared
)
T
.
fill
(
acc_o
,
0
)
T
.
fill
(
acc_o
,
0
)
T
.
fill
(
logsum
,
0
)
T
.
fill
(
logsum
,
0
)
T
.
fill
(
scores_max
,
-
T
.
infinity
(
accum_dtype
))
T
.
fill
(
scores_max
,
-
T
.
infinity
(
accum_dtype
))
...
@@ -124,33 +122,25 @@ def blocksparse_flashattn(batch, heads, seq_q, seq_kv, dim, downsample_len, is_c
...
@@ -124,33 +122,25 @@ def blocksparse_flashattn(batch, heads, seq_q, seq_kv, dim, downsample_len, is_c
for
k
in
T
.
Pipelined
(
loop_range
,
num_stages
=
num_stages
):
for
k
in
T
.
Pipelined
(
loop_range
,
num_stages
=
num_stages
):
if
block_mask
[
k
]
!=
0
:
if
block_mask
[
k
]
!=
0
:
T
.
copy
(
K
[
bz
,
by
,
k
*
block_N
:
(
k
+
1
)
*
block_N
,
:],
K_shared
)
T
.
copy
(
K
[
bz
,
by
,
k
*
block_N
:
(
k
+
1
)
*
block_N
,
:],
K_shared
)
if
is_causal
:
if
is_causal
:
past_len
=
seq_kv
-
seq_q
past_len
=
seq_kv
-
seq_q
for
i
,
j
in
T
.
Parallel
(
block_M
,
block_N
):
for
i
,
j
in
T
.
Parallel
(
block_M
,
block_N
):
acc_s
[
i
,
j
]
=
T
.
if_then_else
(
acc_s
[
i
,
j
]
=
T
.
if_then_else
(
bx
*
block_M
+
i
+
past_len
>=
k
*
block_N
+
j
,
0
,
-
T
.
infinity
(
acc_s
.
dtype
))
bx
*
block_M
+
i
+
past_len
>=
k
*
block_N
+
j
,
0
,
-
T
.
infinity
(
acc_s
.
dtype
))
else
:
else
:
T
.
clear
(
acc_s
)
T
.
clear
(
acc_s
)
T
.
gemm
(
T
.
gemm
(
Q_shared
,
K_shared
,
acc_s
,
transpose_B
=
True
,
policy
=
T
.
GemmWarpPolicy
.
FullRow
)
Q_shared
,
K_shared
,
Softmax
(
acc_s
,
acc_s_cast
,
scores_max
,
scores_max_prev
,
scores_scale
,
scores_sum
,
logsum
)
acc_s
,
transpose_B
=
True
,
policy
=
T
.
GemmWarpPolicy
.
FullRow
)
Softmax
(
acc_s
,
acc_s_cast
,
scores_max
,
scores_max_prev
,
scores_scale
,
scores_sum
,
logsum
)
Rescale
(
acc_o
,
scores_scale
)
Rescale
(
acc_o
,
scores_scale
)
T
.
copy
(
V
[
bz
,
by
,
k
*
block_N
:
(
k
+
1
)
*
block_N
,
:],
V_shared
)
T
.
copy
(
V
[
bz
,
by
,
k
*
block_N
:
(
k
+
1
)
*
block_N
,
:],
V_shared
)
T
.
gemm
(
acc_s_cast
,
V_shared
,
acc_o
,
policy
=
T
.
GemmWarpPolicy
.
FullRow
)
T
.
gemm
(
acc_s_cast
,
V_shared
,
acc_o
,
policy
=
T
.
GemmWarpPolicy
.
FullRow
)
for
i
,
j
in
T
.
Parallel
(
block_M
,
dim
):
for
i
,
j
in
T
.
Parallel
(
block_M
,
dim
):
acc_o
[
i
,
j
]
/=
logsum
[
i
]
acc_o
[
i
,
j
]
/=
logsum
[
i
]
T
.
copy
(
acc_o
,
O_shared
)
T
.
copy
(
acc_o
,
O_shared
)
T
.
copy
(
O_shared
,
Output
[
bz
,
by
,
bx
*
block_M
:
(
bx
+
1
)
*
block_M
,
:])
T
.
copy
(
O_shared
,
Output
[
bz
,
by
,
bx
*
block_M
:
(
bx
+
1
)
*
block_M
,
:])
return
main
return
main
...
@@ -165,44 +155,40 @@ def test_topk_sparse_attention():
...
@@ -165,44 +155,40 @@ def test_topk_sparse_attention():
torch
.
manual_seed
(
0
)
torch
.
manual_seed
(
0
)
# Create inputs
# Create inputs
q
=
torch
.
randn
(
BATCH
,
N_HEADS
,
SEQ_LEN
,
D_HEAD
,
device
=
'
cuda
'
,
dtype
=
torch
.
float16
)
q
=
torch
.
randn
(
BATCH
,
N_HEADS
,
SEQ_LEN
,
D_HEAD
,
device
=
"
cuda
"
,
dtype
=
torch
.
float16
)
k
=
torch
.
randn
(
BATCH
,
N_HEADS
,
SEQ_LEN
,
D_HEAD
,
device
=
'
cuda
'
,
dtype
=
torch
.
float16
)
k
=
torch
.
randn
(
BATCH
,
N_HEADS
,
SEQ_LEN
,
D_HEAD
,
device
=
"
cuda
"
,
dtype
=
torch
.
float16
)
v
=
torch
.
randn
(
BATCH
,
N_HEADS
,
SEQ_LEN
,
D_HEAD
,
device
=
'
cuda
'
,
dtype
=
torch
.
float16
)
v
=
torch
.
randn
(
BATCH
,
N_HEADS
,
SEQ_LEN
,
D_HEAD
,
device
=
"
cuda
"
,
dtype
=
torch
.
float16
)
sm_scale
=
1.0
/
(
D_HEAD
**
0.5
)
sm_scale
=
1.0
/
(
D_HEAD
**
0.5
)
# Create sparse mask (downsampled to block level)
# Create sparse mask (downsampled to block level)
downsample_factor
=
BLOCK
downsample_factor
=
BLOCK
downsample_len
=
math
.
ceil
(
SEQ_LEN
/
downsample_factor
)
downsample_len
=
math
.
ceil
(
SEQ_LEN
/
downsample_factor
)
x_ds
=
torch
.
randn
([
BATCH
,
N_HEADS
,
downsample_len
,
downsample_len
],
x_ds
=
torch
.
randn
([
BATCH
,
N_HEADS
,
downsample_len
,
downsample_len
],
device
=
"cuda"
,
dtype
=
torch
.
float16
)
device
=
'cuda'
,
dtype
=
torch
.
float16
)
x_ds
[:,
:,
:,
0
]
=
100
x_ds
[:,
:,
:,
0
]
=
100
block_mask
=
get_sparse_attn_mask_from_topk
(
x_ds
,
topk
=
TOPK
)
block_mask
=
get_sparse_attn_mask_from_topk
(
x_ds
,
topk
=
TOPK
)
# Run tilelang kernel
# Run tilelang kernel
kernel
=
blocksparse_flashattn
(
kernel
=
blocksparse_flashattn
(
BATCH
,
N_HEADS
,
SEQ_LEN
,
SEQ_LEN
,
D_HEAD
,
downsample_len
,
is_causal
=
True
)
BATCH
,
N_HEADS
,
SEQ_LEN
,
SEQ_LEN
,
D_HEAD
,
downsample_len
,
is_causal
=
True
)
tilelang_output
=
kernel
(
q
,
k
,
v
,
block_mask
.
to
(
torch
.
int8
))
tilelang_output
=
kernel
(
q
,
k
,
v
,
block_mask
.
to
(
torch
.
int8
))
# Compute reference
# Compute reference
# Expand block mask to full attention matrix
# Expand block mask to full attention matrix
full_mask
=
torch
.
kron
(
block_mask
.
float
(),
torch
.
ones
(
BLOCK
,
BLOCK
,
device
=
'
cuda
'
))
full_mask
=
torch
.
kron
(
block_mask
.
float
(),
torch
.
ones
(
BLOCK
,
BLOCK
,
device
=
"
cuda
"
))
full_mask
=
full_mask
[...,
:
SEQ_LEN
,
:
SEQ_LEN
].
bool
()
full_mask
=
full_mask
[...,
:
SEQ_LEN
,
:
SEQ_LEN
].
bool
()
full_mask
=
full_mask
&
torch
.
tril
(
torch
.
ones_like
(
full_mask
))
# Apply causal
full_mask
=
full_mask
&
torch
.
tril
(
torch
.
ones_like
(
full_mask
))
# Apply causal
# PyTorch reference implementation
# PyTorch reference implementation
attn
=
torch
.
einsum
(
'
bhsd,bhtd->bhst
'
,
q
,
k
)
*
sm_scale
attn
=
torch
.
einsum
(
"
bhsd,bhtd->bhst
"
,
q
,
k
)
*
sm_scale
attn
=
attn
.
masked_fill
(
~
full_mask
,
float
(
'
-inf
'
))
attn
=
attn
.
masked_fill
(
~
full_mask
,
float
(
"
-inf
"
))
attn
=
F
.
softmax
(
attn
,
dim
=-
1
)
attn
=
F
.
softmax
(
attn
,
dim
=-
1
)
ref_output
=
torch
.
einsum
(
'
bhst,bhtd->bhsd
'
,
attn
,
v
)
ref_output
=
torch
.
einsum
(
"
bhst,bhtd->bhsd
"
,
attn
,
v
)
print
(
"ref_output"
,
ref_output
)
print
(
"ref_output"
,
ref_output
)
print
(
"tilelang_output"
,
tilelang_output
)
print
(
"tilelang_output"
,
tilelang_output
)
# Verify accuracy
# Verify accuracy
assert
torch
.
allclose
(
tilelang_output
,
ref_output
,
atol
=
1e-2
,
rtol
=
1e-2
),
\
assert
torch
.
allclose
(
tilelang_output
,
ref_output
,
atol
=
1e-2
,
rtol
=
1e-2
),
"TileLang output doesn't match reference"
"TileLang output doesn't match reference"
print
(
"Pass topk sparse attention test with qlen == klen"
)
print
(
"Pass topk sparse attention test with qlen == klen"
)
...
@@ -215,42 +201,40 @@ def test_topk_sparse_attention_qlen_lt_klen():
...
@@ -215,42 +201,40 @@ def test_topk_sparse_attention_qlen_lt_klen():
torch
.
manual_seed
(
0
)
torch
.
manual_seed
(
0
)
# Create inputs.
# Create inputs.
q
=
torch
.
randn
(
BATCH
,
N_HEADS
,
Q_LEN
,
D_HEAD
,
device
=
'
cuda
'
,
dtype
=
torch
.
float16
)
q
=
torch
.
randn
(
BATCH
,
N_HEADS
,
Q_LEN
,
D_HEAD
,
device
=
"
cuda
"
,
dtype
=
torch
.
float16
)
k
=
torch
.
randn
(
BATCH
,
N_HEADS
,
K_LEN
,
D_HEAD
,
device
=
'
cuda
'
,
dtype
=
torch
.
float16
)
k
=
torch
.
randn
(
BATCH
,
N_HEADS
,
K_LEN
,
D_HEAD
,
device
=
"
cuda
"
,
dtype
=
torch
.
float16
)
v
=
torch
.
randn
(
BATCH
,
N_HEADS
,
K_LEN
,
D_HEAD
,
device
=
'
cuda
'
,
dtype
=
torch
.
float16
)
v
=
torch
.
randn
(
BATCH
,
N_HEADS
,
K_LEN
,
D_HEAD
,
device
=
"
cuda
"
,
dtype
=
torch
.
float16
)
sm_scale
=
1.0
/
(
D_HEAD
**
0.5
)
sm_scale
=
1.0
/
(
D_HEAD
**
0.5
)
downsample_factor
=
BLOCK
downsample_factor
=
BLOCK
downsample_len
=
math
.
ceil
(
K_LEN
/
downsample_factor
)
# number of blocks along one dimension
downsample_len
=
math
.
ceil
(
K_LEN
/
downsample_factor
)
# number of blocks along one dimension
x_ds
=
torch
.
randn
(
x_ds
=
torch
.
randn
(
BATCH
,
N_HEADS
,
downsample_len
,
downsample_len
,
device
=
"cuda"
,
dtype
=
torch
.
float16
)
BATCH
,
N_HEADS
,
downsample_len
,
downsample_len
,
device
=
'cuda'
,
dtype
=
torch
.
float16
)
# Force the first column to be high so that the first block is always selected.
# Force the first column to be high so that the first block is always selected.
x_ds
[:,
:,
:,
0
]
=
100
x_ds
[:,
:,
:,
0
]
=
100
block_mask
=
get_sparse_attn_mask_from_topk
(
x_ds
,
topk
=
TOPK
)
block_mask
=
get_sparse_attn_mask_from_topk
(
x_ds
,
topk
=
TOPK
)
kernel
=
blocksparse_flashattn
(
kernel
=
blocksparse_flashattn
(
BATCH
,
N_HEADS
,
Q_LEN
,
K_LEN
,
D_HEAD
,
downsample_len
,
is_causal
=
True
)
BATCH
,
N_HEADS
,
Q_LEN
,
K_LEN
,
D_HEAD
,
downsample_len
,
is_causal
=
True
)
print
(
kernel
.
get_kernel_source
())
print
(
kernel
.
get_kernel_source
())
tilelang_output
=
kernel
(
q
,
k
,
v
,
block_mask
.
to
(
torch
.
int8
))
tilelang_output
=
kernel
(
q
,
k
,
v
,
block_mask
.
to
(
torch
.
int8
))
past_len
=
K_LEN
-
Q_LEN
past_len
=
K_LEN
-
Q_LEN
attn
=
torch
.
einsum
(
'
bhsd,bhtd->bhst
'
,
q
,
k
)
*
sm_scale
attn
=
torch
.
einsum
(
"
bhsd,bhtd->bhst
"
,
q
,
k
)
*
sm_scale
full_mask_full
=
torch
.
kron
(
block_mask
.
float
(),
torch
.
ones
(
BLOCK
,
BLOCK
,
device
=
'
cuda
'
)).
bool
()
full_mask_full
=
torch
.
kron
(
block_mask
.
float
(),
torch
.
ones
(
BLOCK
,
BLOCK
,
device
=
"
cuda
"
)).
bool
()
full_mask_full
=
full_mask_full
[...,
:
K_LEN
,
:
K_LEN
]
full_mask_full
=
full_mask_full
[...,
:
K_LEN
,
:
K_LEN
]
effective_mask
=
full_mask_full
[...,
past_len
:
K_LEN
,
:]
# shape: (B, H, Q_LEN, K_LEN)
effective_mask
=
full_mask_full
[...,
past_len
:
K_LEN
,
:]
# shape: (B, H, Q_LEN, K_LEN)
i_global
=
torch
.
arange
(
past_len
,
K_LEN
,
device
=
k
.
device
).
unsqueeze
(
1
)
# shape: (Q_LEN, 1)
i_global
=
torch
.
arange
(
past_len
,
K_LEN
,
device
=
k
.
device
).
unsqueeze
(
1
)
# shape: (Q_LEN, 1)
j_global
=
torch
.
arange
(
K_LEN
,
device
=
k
.
device
).
unsqueeze
(
0
)
# shape: (1, K_LEN)
j_global
=
torch
.
arange
(
K_LEN
,
device
=
k
.
device
).
unsqueeze
(
0
)
# shape: (1, K_LEN)
causal_mask
=
(
j_global
<=
i_global
)
# shape: (Q_LEN, K_LEN)
causal_mask
=
j_global
<=
i_global
# shape: (Q_LEN, K_LEN)
final_mask
=
effective_mask
&
causal_mask
# shape: (B, H, Q_LEN, K_LEN)
final_mask
=
effective_mask
&
causal_mask
# shape: (B, H, Q_LEN, K_LEN)
attn
=
attn
.
masked_fill
(
~
final_mask
,
float
(
'
-inf
'
))
attn
=
attn
.
masked_fill
(
~
final_mask
,
float
(
"
-inf
"
))
attn
=
F
.
softmax
(
attn
,
dim
=-
1
)
attn
=
F
.
softmax
(
attn
,
dim
=-
1
)
ref_output
=
torch
.
einsum
(
'
bhst,bhtd->bhsd
'
,
attn
,
v
)
ref_output
=
torch
.
einsum
(
"
bhst,bhtd->bhsd
"
,
attn
,
v
)
print
(
"ref_output"
,
ref_output
)
print
(
"ref_output"
,
ref_output
)
print
(
"tilelang_output"
,
tilelang_output
)
print
(
"tilelang_output"
,
tilelang_output
)
...
...
examples/seer_attention/block_sparse_attn_triton.py
View file @
29051439
...
@@ -15,10 +15,7 @@ def get_sparse_attn_mask_from_topk(x, topk, use_dense_for_last_block=False):
...
@@ -15,10 +15,7 @@ def get_sparse_attn_mask_from_topk(x, topk, use_dense_for_last_block=False):
bsz
,
num_head
,
downsample_len
,
_
=
x
.
shape
bsz
,
num_head
,
downsample_len
,
_
=
x
.
shape
# N_CTX = downsample_len * BLOCK
# N_CTX = downsample_len * BLOCK
sparse_index
=
torch
.
topk
(
x
,
topk
,
dim
=-
1
).
indices
sparse_index
=
torch
.
topk
(
x
,
topk
,
dim
=-
1
).
indices
dense_mask
=
torch
.
full
([
bsz
,
num_head
,
downsample_len
,
downsample_len
],
dense_mask
=
torch
.
full
([
bsz
,
num_head
,
downsample_len
,
downsample_len
],
False
,
dtype
=
torch
.
bool
,
device
=
x
.
device
)
False
,
dtype
=
torch
.
bool
,
device
=
x
.
device
)
dense_mask
.
scatter_
(
-
1
,
sparse_index
,
True
)
dense_mask
.
scatter_
(
-
1
,
sparse_index
,
True
)
if
use_dense_for_last_block
:
if
use_dense_for_last_block
:
dense_mask
[:,
:,
-
2
:,
:]
=
True
dense_mask
[:,
:,
-
2
:,
:]
=
True
...
@@ -54,7 +51,6 @@ def _fwd_kernel_inner(
...
@@ -54,7 +51,6 @@ def _fwd_kernel_inner(
BLOCK_M
:
tl
.
constexpr
,
BLOCK_M
:
tl
.
constexpr
,
BLOCK_N
:
tl
.
constexpr
,
BLOCK_N
:
tl
.
constexpr
,
):
):
mask_val
=
tl
.
load
(
block_mask_ptr
+
k_block_col_idx
*
stride_bmask_n
)
mask_val
=
tl
.
load
(
block_mask_ptr
+
k_block_col_idx
*
stride_bmask_n
)
if
mask_val
==
True
:
if
mask_val
==
True
:
...
@@ -69,7 +65,7 @@ def _fwd_kernel_inner(
...
@@ -69,7 +65,7 @@ def _fwd_kernel_inner(
qk
*=
sm_scale
qk
*=
sm_scale
# the following is needed only when LAST_K_BLOCK or BLOCK_M < BLOCK_N
# the following is needed only when LAST_K_BLOCK or BLOCK_M < BLOCK_N
qk
+=
tl
.
where
(
offs_m
[:,
None
]
+
past_len
>=
(
start_n
+
offs_n
[
None
,
:]),
0
,
float
(
'
-inf
'
))
qk
+=
tl
.
where
(
offs_m
[:,
None
]
+
past_len
>=
(
start_n
+
offs_n
[
None
,
:]),
0
,
float
(
"
-inf
"
))
m_ij
=
tl
.
maximum
(
m_i
,
tl
.
max
(
qk
,
1
))
m_ij
=
tl
.
maximum
(
m_i
,
tl
.
max
(
qk
,
1
))
qk
-=
m_ij
[:,
None
]
qk
-=
m_ij
[:,
None
]
...
@@ -149,7 +145,7 @@ def _fwd_kernel(
...
@@ -149,7 +145,7 @@ def _fwd_kernel(
v_ptrs
=
V
+
off_v
v_ptrs
=
V
+
off_v
mask_ptrs
=
block_mask_ptr
+
start_m
*
stride_bmm
mask_ptrs
=
block_mask_ptr
+
start_m
*
stride_bmm
m_i
=
tl
.
zeros
([
BLOCK_M
],
dtype
=
tl
.
float32
)
-
float
(
'
inf
'
)
m_i
=
tl
.
zeros
([
BLOCK_M
],
dtype
=
tl
.
float32
)
-
float
(
"
inf
"
)
l_i
=
tl
.
zeros
([
BLOCK_M
],
dtype
=
tl
.
float32
)
l_i
=
tl
.
zeros
([
BLOCK_M
],
dtype
=
tl
.
float32
)
acc
=
tl
.
zeros
([
BLOCK_M
,
BLOCK_DMODEL
],
dtype
=
tl
.
float32
)
acc
=
tl
.
zeros
([
BLOCK_M
,
BLOCK_DMODEL
],
dtype
=
tl
.
float32
)
...
@@ -185,24 +181,12 @@ def _fwd_kernel(
...
@@ -185,24 +181,12 @@ def _fwd_kernel(
acc
=
acc
*
l_recip
acc
=
acc
*
l_recip
acc
=
acc
.
to
(
Out
.
dtype
.
element_ty
)
acc
=
acc
.
to
(
Out
.
dtype
.
element_ty
)
off_o
=
off_z
*
stride_oz
+
off_h
*
stride_oh
+
offs_m
[:,
None
]
*
stride_om
+
offs_d
[
off_o
=
off_z
*
stride_oz
+
off_h
*
stride_oh
+
offs_m
[:,
None
]
*
stride_om
+
offs_d
[
None
,
:]
*
stride_od
None
,
:]
*
stride_od
out_ptrs
=
Out
+
off_o
out_ptrs
=
Out
+
off_o
tl
.
store
(
out_ptrs
,
acc
,
mask
=
offs_m
[:,
None
]
<
N_CTX
)
tl
.
store
(
out_ptrs
,
acc
,
mask
=
offs_m
[:,
None
]
<
N_CTX
)
def
_forward
(
ctx
,
def
_forward
(
ctx
,
q
,
k
,
v
,
block_sparse_mask
,
sm_scale
,
BLOCK_M
=
64
,
BLOCK_N
=
64
,
num_warps
=
None
,
num_stages
=
1
,
out
=
None
):
q
,
k
,
v
,
block_sparse_mask
,
sm_scale
,
BLOCK_M
=
64
,
BLOCK_N
=
64
,
num_warps
=
None
,
num_stages
=
1
,
out
=
None
):
assert
q
.
shape
[
-
1
]
==
k
.
shape
[
-
1
]
==
v
.
shape
[
-
1
]
assert
q
.
shape
[
-
1
]
==
k
.
shape
[
-
1
]
==
v
.
shape
[
-
1
]
assert
k
.
shape
[
2
]
==
v
.
shape
[
2
]
assert
k
.
shape
[
2
]
==
v
.
shape
[
2
]
o
=
out
if
out
is
not
None
else
torch
.
empty_like
(
q
).
contiguous
()
o
=
out
if
out
is
not
None
else
torch
.
empty_like
(
q
).
contiguous
()
...
@@ -247,7 +231,6 @@ def _forward(ctx,
...
@@ -247,7 +231,6 @@ def _forward(ctx,
class
_sparse_attention
(
torch
.
autograd
.
Function
):
class
_sparse_attention
(
torch
.
autograd
.
Function
):
@
staticmethod
@
staticmethod
def
forward
(
ctx
,
q
,
k
,
v
,
block_sparse_dense
,
sm_scale
):
def
forward
(
ctx
,
q
,
k
,
v
,
block_sparse_dense
,
sm_scale
):
# shape constraints
# shape constraints
...
@@ -271,9 +254,9 @@ def test_topk_sparse_attention():
...
@@ -271,9 +254,9 @@ def test_topk_sparse_attention():
torch
.
manual_seed
(
0
)
torch
.
manual_seed
(
0
)
# Create inputs
# Create inputs
q
=
torch
.
randn
(
BATCH
,
N_HEADS
,
SEQ_LEN
,
D_HEAD
,
device
=
'
cuda
'
,
dtype
=
torch
.
bfloat16
)
q
=
torch
.
randn
(
BATCH
,
N_HEADS
,
SEQ_LEN
,
D_HEAD
,
device
=
"
cuda
"
,
dtype
=
torch
.
bfloat16
)
k
=
torch
.
randn
(
BATCH
,
N_HEADS
,
SEQ_LEN
,
D_HEAD
,
device
=
'
cuda
'
,
dtype
=
torch
.
bfloat16
)
k
=
torch
.
randn
(
BATCH
,
N_HEADS
,
SEQ_LEN
,
D_HEAD
,
device
=
"
cuda
"
,
dtype
=
torch
.
bfloat16
)
v
=
torch
.
randn
(
BATCH
,
N_HEADS
,
SEQ_LEN
,
D_HEAD
,
device
=
'
cuda
'
,
dtype
=
torch
.
bfloat16
)
v
=
torch
.
randn
(
BATCH
,
N_HEADS
,
SEQ_LEN
,
D_HEAD
,
device
=
"
cuda
"
,
dtype
=
torch
.
bfloat16
)
sm_scale
=
1.0
/
(
D_HEAD
**
0.5
)
sm_scale
=
1.0
/
(
D_HEAD
**
0.5
)
# Create sparse mask (downsampled to block level)
# Create sparse mask (downsampled to block level)
...
@@ -281,9 +264,7 @@ def test_topk_sparse_attention():
...
@@ -281,9 +264,7 @@ def test_topk_sparse_attention():
downsample_len
=
math
.
ceil
(
SEQ_LEN
/
downsample_factor
)
downsample_len
=
math
.
ceil
(
SEQ_LEN
/
downsample_factor
)
print
(
"downsample_len"
,
downsample_len
)
print
(
"downsample_len"
,
downsample_len
)
x_ds
=
torch
.
randn
([
BATCH
,
N_HEADS
,
downsample_len
,
downsample_len
],
x_ds
=
torch
.
randn
([
BATCH
,
N_HEADS
,
downsample_len
,
downsample_len
],
device
=
"cuda"
,
dtype
=
torch
.
bfloat16
)
device
=
'cuda'
,
dtype
=
torch
.
bfloat16
)
x_ds
[:,
:,
:,
0
]
=
100
x_ds
[:,
:,
:,
0
]
=
100
print
(
"x_ds.shape"
,
x_ds
.
shape
)
print
(
"x_ds.shape"
,
x_ds
.
shape
)
block_mask
=
get_sparse_attn_mask_from_topk
(
x_ds
,
topk
=
TOPK
)
block_mask
=
get_sparse_attn_mask_from_topk
(
x_ds
,
topk
=
TOPK
)
...
@@ -295,22 +276,21 @@ def test_topk_sparse_attention():
...
@@ -295,22 +276,21 @@ def test_topk_sparse_attention():
# Compute reference
# Compute reference
# Expand block mask to full attention matrix
# Expand block mask to full attention matrix
full_mask
=
torch
.
kron
(
block_mask
.
float
(),
torch
.
ones
(
BLOCK
,
BLOCK
,
device
=
'
cuda
'
))
full_mask
=
torch
.
kron
(
block_mask
.
float
(),
torch
.
ones
(
BLOCK
,
BLOCK
,
device
=
"
cuda
"
))
full_mask
=
full_mask
[...,
:
SEQ_LEN
,
:
SEQ_LEN
].
bool
()
full_mask
=
full_mask
[...,
:
SEQ_LEN
,
:
SEQ_LEN
].
bool
()
full_mask
=
full_mask
&
torch
.
tril
(
torch
.
ones_like
(
full_mask
))
# Apply causal
full_mask
=
full_mask
&
torch
.
tril
(
torch
.
ones_like
(
full_mask
))
# Apply causal
# PyTorch reference implementation
# PyTorch reference implementation
attn
=
torch
.
einsum
(
'
bhsd,bhtd->bhst
'
,
q
,
k
)
*
sm_scale
attn
=
torch
.
einsum
(
"
bhsd,bhtd->bhst
"
,
q
,
k
)
*
sm_scale
attn
=
attn
.
masked_fill
(
~
full_mask
,
float
(
'
-inf
'
))
attn
=
attn
.
masked_fill
(
~
full_mask
,
float
(
"
-inf
"
))
attn
=
F
.
softmax
(
attn
,
dim
=-
1
)
attn
=
F
.
softmax
(
attn
,
dim
=-
1
)
ref_output
=
torch
.
einsum
(
'
bhst,bhtd->bhsd
'
,
attn
,
v
)
ref_output
=
torch
.
einsum
(
"
bhst,bhtd->bhsd
"
,
attn
,
v
)
# print("ref_output", ref_output)
# print("ref_output", ref_output)
# print("triton_output", triton_output)
# print("triton_output", triton_output)
# Verify accuracy
# Verify accuracy
assert
torch
.
allclose
(
triton_output
,
ref_output
,
atol
=
1e-2
,
rtol
=
1e-2
),
\
assert
torch
.
allclose
(
triton_output
,
ref_output
,
atol
=
1e-2
,
rtol
=
1e-2
),
"Triton output doesn't match reference"
"Triton output doesn't match reference"
print
(
"Pass topk sparse attention test with qlen == klen"
)
print
(
"Pass topk sparse attention test with qlen == klen"
)
...
@@ -322,16 +302,15 @@ def test_topk_sparse_attention_qlt_kl():
...
@@ -322,16 +302,15 @@ def test_topk_sparse_attention_qlt_kl():
torch
.
manual_seed
(
0
)
torch
.
manual_seed
(
0
)
# Create inputs.
# Create inputs.
q
=
torch
.
randn
(
BATCH
,
N_HEADS
,
Q_LEN
,
D_HEAD
,
device
=
'
cuda
'
,
dtype
=
torch
.
bfloat16
)
q
=
torch
.
randn
(
BATCH
,
N_HEADS
,
Q_LEN
,
D_HEAD
,
device
=
"
cuda
"
,
dtype
=
torch
.
bfloat16
)
k
=
torch
.
randn
(
BATCH
,
N_HEADS
,
K_LEN
,
D_HEAD
,
device
=
'
cuda
'
,
dtype
=
torch
.
bfloat16
)
k
=
torch
.
randn
(
BATCH
,
N_HEADS
,
K_LEN
,
D_HEAD
,
device
=
"
cuda
"
,
dtype
=
torch
.
bfloat16
)
v
=
torch
.
randn
(
BATCH
,
N_HEADS
,
K_LEN
,
D_HEAD
,
device
=
'
cuda
'
,
dtype
=
torch
.
bfloat16
)
v
=
torch
.
randn
(
BATCH
,
N_HEADS
,
K_LEN
,
D_HEAD
,
device
=
"
cuda
"
,
dtype
=
torch
.
bfloat16
)
# softmax scale
# softmax scale
sm_scale
=
1.0
/
(
D_HEAD
**
0.5
)
sm_scale
=
1.0
/
(
D_HEAD
**
0.5
)
downsample_factor
=
BLOCK
downsample_factor
=
BLOCK
downsample_len
=
math
.
ceil
(
K_LEN
/
downsample_factor
)
# number of blocks along one dimension
downsample_len
=
math
.
ceil
(
K_LEN
/
downsample_factor
)
# number of blocks along one dimension
x_ds
=
torch
.
randn
(
x_ds
=
torch
.
randn
(
BATCH
,
N_HEADS
,
downsample_len
,
downsample_len
,
device
=
"cuda"
,
dtype
=
torch
.
bfloat16
)
BATCH
,
N_HEADS
,
downsample_len
,
downsample_len
,
device
=
'cuda'
,
dtype
=
torch
.
bfloat16
)
# Force the first column to be high so that the first block is always selected.
# Force the first column to be high so that the first block is always selected.
x_ds
[:,
:,
:,
0
]
=
100
x_ds
[:,
:,
:,
0
]
=
100
block_mask
=
get_sparse_attn_mask_from_topk
(
x_ds
,
topk
=
TOPK
)
block_mask
=
get_sparse_attn_mask_from_topk
(
x_ds
,
topk
=
TOPK
)
...
@@ -340,26 +319,25 @@ def test_topk_sparse_attention_qlt_kl():
...
@@ -340,26 +319,25 @@ def test_topk_sparse_attention_qlt_kl():
past_len
=
K_LEN
-
Q_LEN
past_len
=
K_LEN
-
Q_LEN
attn
=
torch
.
einsum
(
'
bhsd,bhtd->bhst
'
,
q
,
k
)
*
sm_scale
attn
=
torch
.
einsum
(
"
bhsd,bhtd->bhst
"
,
q
,
k
)
*
sm_scale
full_mask_full
=
torch
.
kron
(
block_mask
.
float
(),
torch
.
ones
(
BLOCK
,
BLOCK
,
device
=
'
cuda
'
)).
bool
()
full_mask_full
=
torch
.
kron
(
block_mask
.
float
(),
torch
.
ones
(
BLOCK
,
BLOCK
,
device
=
"
cuda
"
)).
bool
()
full_mask_full
=
full_mask_full
[...,
:
K_LEN
,
:
K_LEN
]
full_mask_full
=
full_mask_full
[...,
:
K_LEN
,
:
K_LEN
]
effective_mask
=
full_mask_full
[...,
past_len
:
K_LEN
,
:]
# shape: (B, H, Q_LEN, K_LEN)
effective_mask
=
full_mask_full
[...,
past_len
:
K_LEN
,
:]
# shape: (B, H, Q_LEN, K_LEN)
i_global
=
torch
.
arange
(
past_len
,
K_LEN
,
device
=
k
.
device
).
unsqueeze
(
1
)
# shape: (Q_LEN, 1)
i_global
=
torch
.
arange
(
past_len
,
K_LEN
,
device
=
k
.
device
).
unsqueeze
(
1
)
# shape: (Q_LEN, 1)
j_global
=
torch
.
arange
(
K_LEN
,
device
=
k
.
device
).
unsqueeze
(
0
)
# shape: (1, K_LEN)
j_global
=
torch
.
arange
(
K_LEN
,
device
=
k
.
device
).
unsqueeze
(
0
)
# shape: (1, K_LEN)
causal_mask
=
(
j_global
<=
i_global
)
# shape: (Q_LEN, K_LEN)
causal_mask
=
j_global
<=
i_global
# shape: (Q_LEN, K_LEN)
final_mask
=
effective_mask
&
causal_mask
# shape: (B, H, Q_LEN, K_LEN)
final_mask
=
effective_mask
&
causal_mask
# shape: (B, H, Q_LEN, K_LEN)
attn
=
attn
.
masked_fill
(
~
final_mask
,
float
(
'
-inf
'
))
attn
=
attn
.
masked_fill
(
~
final_mask
,
float
(
"
-inf
"
))
attn
=
F
.
softmax
(
attn
,
dim
=-
1
)
attn
=
F
.
softmax
(
attn
,
dim
=-
1
)
ref_output
=
torch
.
einsum
(
'
bhst,bhtd->bhsd
'
,
attn
,
v
)
ref_output
=
torch
.
einsum
(
"
bhst,bhtd->bhsd
"
,
attn
,
v
)
# Verify accuracy.
# Verify accuracy.
assert
torch
.
allclose
(
triton_output
,
ref_output
,
atol
=
1e-2
,
rtol
=
1e-2
),
\
assert
torch
.
allclose
(
triton_output
,
ref_output
,
atol
=
1e-2
,
rtol
=
1e-2
),
"Triton output doesn't match reference when qlen < klen"
"Triton output doesn't match reference when qlen < klen"
print
(
"Pass topk sparse attention test with qlen < klen"
)
print
(
"Pass topk sparse attention test with qlen < klen"
)
...
...
examples/sparse_tensorcore/tilelang_example_sparse_tensorcore.py
View file @
29051439
...
@@ -28,24 +28,22 @@ def matmul_sp(
...
@@ -28,24 +28,22 @@ def matmul_sp(
@
T
.
prim_func
@
T
.
prim_func
def
main
(
def
main
(
A_sparse
:
T
.
Tensor
(
A_sparse_shape
,
in_dtype
),
A_sparse
:
T
.
Tensor
(
A_sparse_shape
,
in_dtype
),
E
:
T
.
Tensor
((
M
,
K
//
8
),
'
uint8
'
),
E
:
T
.
Tensor
((
M
,
K
//
8
),
"
uint8
"
),
B
:
T
.
Tensor
(
B_shape
,
in_dtype
),
B
:
T
.
Tensor
(
B_shape
,
in_dtype
),
C
:
T
.
Tensor
((
M
,
N
),
out_dtype
),
C
:
T
.
Tensor
((
M
,
N
),
out_dtype
),
):
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
threads
)
as
(
bx
,
by
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
threads
)
as
(
bx
,
by
):
A_shared
=
T
.
alloc_shared
(
A_shared_shape
,
in_dtype
)
A_shared
=
T
.
alloc_shared
(
A_shared_shape
,
in_dtype
)
B_shared
=
T
.
alloc_shared
(
B_shared_shape
,
in_dtype
)
B_shared
=
T
.
alloc_shared
(
B_shared_shape
,
in_dtype
)
E_shared
=
T
.
alloc_shared
((
block_M
,
block_K
//
8
),
'
uint8
'
)
E_shared
=
T
.
alloc_shared
((
block_M
,
block_K
//
8
),
"
uint8
"
)
C_local
=
T
.
alloc_fragment
((
block_M
,
block_N
),
accum_dtype
)
C_local
=
T
.
alloc_fragment
((
block_M
,
block_N
),
accum_dtype
)
T
.
annotate_layout
({
T
.
annotate_layout
(
E
:
{
make_cutlass_metadata_layout
(
E
:
make_cutlass_metadata_layout
(
E
,
mma_dtype
=
"float16"
,
arch
=
"9.0"
,
block_k
=
block_K
),
E
,
mma_dtype
=
"float16"
,
arch
=
"9.0"
,
block_k
=
block_K
),
E_shared
:
make_cutlass_metadata_layout
(
E_shared
,
mma_dtype
=
"float16"
,
arch
=
"9.0"
,
block_k
=
block_K
),
E_shared
:
}
make_cutlass_metadata_layout
(
)
E_shared
,
mma_dtype
=
"float16"
,
arch
=
"9.0"
,
block_k
=
block_K
),
})
T
.
clear
(
C_local
)
T
.
clear
(
C_local
)
for
k
in
T
.
Pipelined
(
T
.
ceildiv
(
K
,
block_K
),
num_stages
=
num_stages
):
for
k
in
T
.
Pipelined
(
T
.
ceildiv
(
K
,
block_K
),
num_stages
=
num_stages
):
T
.
copy
(
E
[
by
*
block_M
,
k
*
block_K
//
8
],
E_shared
)
T
.
copy
(
E
[
by
*
block_M
,
k
*
block_K
//
8
],
E_shared
)
...
@@ -57,7 +55,7 @@ def matmul_sp(
...
@@ -57,7 +55,7 @@ def matmul_sp(
return
main
return
main
def
generate_2_to_4_sparse_tensor
(
shape
,
dtype
=
torch
.
float32
,
device
=
'
cpu
'
):
def
generate_2_to_4_sparse_tensor
(
shape
,
dtype
=
torch
.
float32
,
device
=
"
cpu
"
):
if
shape
[
-
1
]
%
4
!=
0
:
if
shape
[
-
1
]
%
4
!=
0
:
raise
ValueError
(
"Last dimension must be divisible by 4 for 2:4 sparsity."
)
raise
ValueError
(
"Last dimension must be divisible by 4 for 2:4 sparsity."
)
...
@@ -102,9 +100,9 @@ def run_gemm_sp(
...
@@ -102,9 +100,9 @@ def run_gemm_sp(
num_threads
,
num_threads
,
)
)
A
=
generate_2_to_4_sparse_tensor
((
M
,
K
),
dtype
=
torch
.
float16
,
device
=
'
cuda
'
)
A
=
generate_2_to_4_sparse_tensor
((
M
,
K
),
dtype
=
torch
.
float16
,
device
=
"
cuda
"
)
A_sparse
,
E
=
compress_sm90
(
A
,
block_k
=
block_K
,
transposed
=
False
)
A_sparse
,
E
=
compress_sm90
(
A
,
block_k
=
block_K
,
transposed
=
False
)
B
=
torch
.
randn
((
K
,
N
),
device
=
'
cuda
'
,
dtype
=
torch
.
float16
)
B
=
torch
.
randn
((
K
,
N
),
device
=
"
cuda
"
,
dtype
=
torch
.
float16
)
C_sp
=
kernel
(
A_sparse
,
E
,
B
).
half
()
C_sp
=
kernel
(
A_sparse
,
E
,
B
).
half
()
C
=
torch
.
matmul
(
A
,
B
)
C
=
torch
.
matmul
(
A
,
B
)
...
...
examples/topk/example_topk.py
View file @
29051439
...
@@ -26,9 +26,9 @@ def tl_topk(
...
@@ -26,9 +26,9 @@ def tl_topk(
@
T
.
prim_func
@
T
.
prim_func
def
topk_kernel
(
def
topk_kernel
(
logits
:
T
.
Tensor
([
M
,
N
],
dtype
),
logits
:
T
.
Tensor
([
M
,
N
],
dtype
),
topk_gates
:
T
.
Tensor
([
M
,
topk
],
dtype
),
topk_gates
:
T
.
Tensor
([
M
,
topk
],
dtype
),
topk_indices
:
T
.
Tensor
([
M
,
topk
],
"int32"
),
topk_indices
:
T
.
Tensor
([
M
,
topk
],
"int32"
),
):
):
with
T
.
Kernel
(
T
.
ceildiv
(
M
,
blk_m
),
threads
=
threads
)
as
bx
:
with
T
.
Kernel
(
T
.
ceildiv
(
M
,
blk_m
),
threads
=
threads
)
as
bx
:
logits_frag
=
T
.
alloc_fragment
([
blk_m
,
N
],
dtype
=
dtype
)
logits_frag
=
T
.
alloc_fragment
([
blk_m
,
N
],
dtype
=
dtype
)
...
@@ -43,15 +43,12 @@ def tl_topk(
...
@@ -43,15 +43,12 @@ def tl_topk(
T
.
reduce_max
(
logits_frag
,
max_val
,
dim
=
1
,
clear
=
True
)
T
.
reduce_max
(
logits_frag
,
max_val
,
dim
=
1
,
clear
=
True
)
for
i
,
j
in
T
.
Parallel
(
blk_m
,
N
):
for
i
,
j
in
T
.
Parallel
(
blk_m
,
N
):
expand_max_idx
[
i
,
j
]
=
T
.
if_then_else
(
max_val
[
i
]
==
logits_frag
[
i
,
j
],
j
,
expand_max_idx
[
i
,
j
]
=
T
.
if_then_else
(
max_val
[
i
]
==
logits_frag
[
i
,
j
],
j
,
expand_max_idx
[
i
,
j
])
expand_max_idx
[
i
,
j
])
T
.
reduce_max
(
expand_max_idx
,
max_idx
,
dim
=
1
,
clear
=
True
)
T
.
reduce_max
(
expand_max_idx
,
max_idx
,
dim
=
1
,
clear
=
True
)
for
i
,
j
in
T
.
Parallel
(
blk_m
,
N
):
for
i
,
j
in
T
.
Parallel
(
blk_m
,
N
):
logits_frag
[
i
,
j
]
=
T
.
if_then_else
(
max_val
[
i
]
==
logits_frag
[
i
,
j
],
-
10000.0
,
logits_frag
[
i
,
j
])
logits_frag
[
i
,
j
]
=
T
.
if_then_else
(
max_val
[
i
]
==
logits_frag
[
i
,
j
],
-
10000.0
,
logits_frag
[
i
,
j
])
for
i
in
T
.
Parallel
(
blk_m
):
for
i
in
T
.
Parallel
(
blk_m
):
topk_gates
[
bx
*
blk_m
+
i
,
k
]
=
max_val
[
i
]
topk_gates
[
bx
*
blk_m
+
i
,
k
]
=
max_val
[
i
]
...
@@ -61,7 +58,6 @@ def tl_topk(
...
@@ -61,7 +58,6 @@ def tl_topk(
def
ref_program
(
logits
,
top_k
):
def
ref_program
(
logits
,
top_k
):
top_k_gates
,
top_k_indices
=
logits
.
topk
(
top_k
,
dim
=
1
)
top_k_gates
,
top_k_indices
=
logits
.
topk
(
top_k
,
dim
=
1
)
return
top_k_gates
,
top_k_indices
.
to
(
torch
.
int32
)
return
top_k_gates
,
top_k_indices
.
to
(
torch
.
int32
)
...
...
examples/visual_layout_inference/visual_layout_inference.py
View file @
29051439
...
@@ -7,15 +7,15 @@ import tilelang.language as T
...
@@ -7,15 +7,15 @@ import tilelang.language as T
out_idx
=
[
-
1
],
out_idx
=
[
-
1
],
pass_configs
=
{
pass_configs
=
{
tilelang
.
PassConfigKey
.
TL_LAYOUT_VISUALIZATION_ENABLE
:
True
,
tilelang
.
PassConfigKey
.
TL_LAYOUT_VISUALIZATION_ENABLE
:
True
,
tilelang
.
PassConfigKey
.
TL_LAYOUT_VISUALIZATION_FORMATS
:
"svg"
tilelang
.
PassConfigKey
.
TL_LAYOUT_VISUALIZATION_FORMATS
:
"svg"
,
})
},
)
def
matmul
(
M
,
N
,
K
,
block_M
,
block_N
,
block_K
,
dtype
=
"float16"
,
accum_dtype
=
"float"
):
def
matmul
(
M
,
N
,
K
,
block_M
,
block_N
,
block_K
,
dtype
=
"float16"
,
accum_dtype
=
"float"
):
@
T
.
prim_func
@
T
.
prim_func
def
gemm
(
def
gemm
(
A
:
T
.
Tensor
((
M
,
K
),
dtype
),
A
:
T
.
Tensor
((
M
,
K
),
dtype
),
B
:
T
.
Tensor
((
K
,
N
),
dtype
),
B
:
T
.
Tensor
((
K
,
N
),
dtype
),
C
:
T
.
Tensor
((
M
,
N
),
dtype
),
C
:
T
.
Tensor
((
M
,
N
),
dtype
),
):
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
128
)
as
(
bx
,
by
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
128
)
as
(
bx
,
by
):
A_shared
=
T
.
alloc_shared
((
block_M
,
block_K
),
dtype
)
A_shared
=
T
.
alloc_shared
((
block_M
,
block_K
),
dtype
)
...
@@ -49,12 +49,12 @@ def main():
...
@@ -49,12 +49,12 @@ def main():
print
(
"All check passed."
)
print
(
"All check passed."
)
# print the layout visualization result and save figures to ./tmp.
# print the layout visualization result and save figures to ./tmp.
'''
"""
C_local inferenced layout:
C_local inferenced layout:
Shape: [32, 32] -> [8]
Shape: [32, 32] -> [8]
Thread: _j // 16 * 64 + _i // 16 * 32 + _i % 8 * 4 + _j % 8 // 2
Thread: _j // 16 * 64 + _i // 16 * 32 + _i % 8 * 4 + _j % 8 // 2
Index: [_j % 16 // 8 * 4 + _i % 16 // 8 * 2 + _j % 2]
Index: [_j % 16 // 8 * 4 + _i % 16 // 8 * 2 + _j % 2]
'''
"""
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
...
...
Prev
1
…
5
6
7
8
9
10
11
12
13
…
22
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment