Unverified Commit 29051439 authored by Lei Wang, committed by GitHub

[Lint] Phaseout Yapf format and embrace ruff format (#1417)

parent e84b24bc
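
Most of what follows is mechanical reformatting: ruff format normalizes string literals to double quotes, adds trailing commas and re-wraps long calls against a line-length limit instead of yapf's aligned continuations, and writes spaced slice bounds inside subscripts. A minimal sketch of the style shift, built from fragments that appear in the hadamard example below (the accompanying lint-configuration change itself is not shown in this excerpt):

    # Before (yapf era): single quotes, column-aligned wrapping
    elem_size = {'float32': 4, 'float16': 2, 'bfloat16': 2}['float16']

    # After (ruff format): double quotes, trailing commas, spaced slices such as a[i * n : (i + 1) * n]
    elem_size = {"float32": 4, "float16": 2, "bfloat16": 2}["float16"]
    print(elem_size)  # both variants behave identically at runtime
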
...@@ -17,7 +17,7 @@ def is_pow_of_2(n): ...@@ -17,7 +17,7 @@ def is_pow_of_2(n):
def hadamard(b, n, dtype): def hadamard(b, n, dtype):
assert is_pow_of_2(n), "n must be a power of 2" assert is_pow_of_2(n), "n must be a power of 2"
assert 2 <= n <= 32768, "n must be in [2, 32768]" assert 2 <= n <= 32768, "n must be in [2, 32768]"
elem_size = {'float32': 4, 'float16': 2, 'bfloat16': 2}[dtype] elem_size = {"float32": 4, "float16": 2, "bfloat16": 2}[dtype]
logN = int(math.log2(n)) logN = int(math.log2(n))
threads = [0, 1, 1, 1, 2, 4, 8, 16, 32, 32, 128, 256, 256, 256, 256, 256][logN] threads = [0, 1, 1, 1, 2, 4, 8, 16, 32, 32, 128, 256, 256, 256, 256, 256][logN]
...@@ -40,23 +40,21 @@ def hadamard(b, n, dtype): ...@@ -40,23 +40,21 @@ def hadamard(b, n, dtype):
# print(f'{exchange_round=}') # print(f'{exchange_round=}')
@T.macro @T.macro
-    def warp_shfl(local: T.Tensor((thread_elem,), dtype), buf: T.Tensor((thread_elem,), dtype),
-                  round: int):
+    def warp_shfl(local: T.Tensor((thread_elem,), dtype), buf: T.Tensor((thread_elem,), dtype), round: int):
tx = T.get_thread_binding(0) tx = T.get_thread_binding(0)
for i in T.serial(round): for i in T.serial(round):
tx_stride = 1 << i tx_stride = 1 << i
another_tx = tx ^ tx_stride another_tx = tx ^ tx_stride
-            sign = (
-                tx >> i
-            ) & 1  # get i-th lowest bit of tx, which determines the operation type for shared[tx, :]
+            sign = (tx >> i) & 1  # get i-th lowest bit of tx, which determines the operation type for shared[tx, :]
for j in T.Pipelined(thread_elem, num_stages=1): for j in T.Pipelined(thread_elem, num_stages=1):
buf[j] = T.tvm_warp_shuffle( buf[j] = T.tvm_warp_shuffle(
0xffffffff, # mask of all threads 0xFFFFFFFF, # mask of all threads
local[j], local[j],
another_tx % warp_size, another_tx % warp_size,
warp_size, warp_size,
warp_size) warp_size,
)
local[j] = T.if_then_else(sign == 0, local[j] + buf[j], buf[j] - local[j]) local[j] = T.if_then_else(sign == 0, local[j] + buf[j], buf[j] - local[j])
@T.prim_func @T.prim_func
...@@ -78,10 +76,8 @@ def hadamard(b, n, dtype): ...@@ -78,10 +76,8 @@ def hadamard(b, n, dtype):
for j in T.serial(chunknum): for j in T.serial(chunknum):
chunkbase = j * chunksize chunkbase = j * chunksize
for k in T.serial(chunksize // 2): for k in T.serial(chunksize // 2):
-                    local[chunkbase +
-                          k] = local[chunkbase + k] + local[chunkbase + k + chunksize // 2]
-                    local[chunkbase + k + chunksize //
-                          2] = local[chunkbase + k] - 2 * local[chunkbase + k + chunksize // 2]
+                    local[chunkbase + k] = local[chunkbase + k] + local[chunkbase + k + chunksize // 2]
+                    local[chunkbase + k + chunksize // 2] = local[chunkbase + k] - 2 * local[chunkbase + k + chunksize // 2]
# 3. Hadamard inside warp, n<=512 # 3. Hadamard inside warp, n<=512
# In warp level, we rely on warp shuffle to exchange data inside each warp, without using shared memory # In warp level, we rely on warp shuffle to exchange data inside each warp, without using shared memory
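As the comments above note, within a warp each lane keeps thread_elem values in registers and exchanges partial sums with the lane at tx ^ (1 << i) through T.tvm_warp_shuffle, choosing sum or difference from the i-th bit of its lane id. A plain-Python sketch of that XOR butterfly (an illustration of the exchange pattern only, not the TileLang kernel; the function name and the NumPy/SciPy check are mine):

    import numpy as np

    def hadamard_butterfly(x: np.ndarray) -> np.ndarray:
        """Walsh-Hadamard transform via the same XOR-partner exchange used by warp_shfl."""
        n = len(x)
        assert n & (n - 1) == 0, "length must be a power of 2"
        y = x.astype(np.float64).copy()
        stride = 1
        while stride < n:
            for tx in range(n):
                partner = tx ^ stride  # lane to exchange with
                if tx < partner:       # sign bit 0 keeps the sum, the partner keeps the difference
                    y[tx], y[partner] = y[tx] + y[partner], y[tx] - y[partner]
            stride <<= 1
        return y

    if __name__ == "__main__":
        import scipy.linalg
        v = np.random.randn(8)
        assert np.allclose(hadamard_butterfly(v), scipy.linalg.hadamard(8) @ v)
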
...@@ -131,28 +127,27 @@ def ref_program(x: torch.Tensor): ...@@ -131,28 +127,27 @@ def ref_program(x: torch.Tensor):
assert x.ndim == 2 assert x.ndim == 2
dim = x.shape[-1] dim = x.shape[-1]
assert is_pow_of_2(dim) assert is_pow_of_2(dim)
-    return F.linear(
-        x, torch.tensor(scipy.linalg.hadamard(dim, dtype=float), dtype=x.dtype, device=x.device))
+    return F.linear(x, torch.tensor(scipy.linalg.hadamard(dim, dtype=float), dtype=x.dtype, device=x.device))
def main(): def main():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--batch', type=int, default=64, help='Batch size') parser.add_argument("--batch", type=int, default=64, help="Batch size")
parser.add_argument('--dim', type=int, default=32768, help='Dimension') parser.add_argument("--dim", type=int, default=32768, help="Dimension")
args = parser.parse_args() args = parser.parse_args()
B, D = args.batch, args.dim B, D = args.batch, args.dim
x = torch.randn((B, D), device='cuda') x = torch.randn((B, D), device="cuda")
kernel = hadamard(B, D, 'float32') kernel = hadamard(B, D, "float32")
y = kernel(x) y = kernel(x)
y_ref = ref_program(x) y_ref = ref_program(x)
torch.testing.assert_close(y, y_ref, atol=1e-2, rtol=1e-2) torch.testing.assert_close(y, y_ref, atol=1e-2, rtol=1e-2)
print('All tests passed.') print("All tests passed.")
profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Auto) profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Auto)
latency = profiler.do_bench(warmup=100) latency = profiler.do_bench(warmup=100)
print("Tile-lang: {:.2f} ms".format(latency)) print("Tile-lang: {:.2f} ms".format(latency))
if __name__ == '__main__': if __name__ == "__main__":
main() main()
...@@ -9,6 +9,7 @@ ...@@ -9,6 +9,7 @@
"source": [ "source": [
"import sys\n", "import sys\n",
"from pathlib import Path\n", "from pathlib import Path\n",
"\n",
"sys.path.insert(0, str(Path.cwd().parent.parent.absolute()))\n", "sys.path.insert(0, str(Path.cwd().parent.parent.absolute()))\n",
"import tilelang\n", "import tilelang\n",
"import torch\n", "import torch\n",
...@@ -61,7 +62,7 @@ ...@@ -61,7 +62,7 @@
" out_dtype: T.dtype = T.float32,\n", " out_dtype: T.dtype = T.float32,\n",
" block_M: int = 128,\n", " block_M: int = 128,\n",
" block_N: int = 128,\n", " block_N: int = 128,\n",
" block_K: int = 32\n", " block_K: int = 32,\n",
"):\n", "):\n",
" M, K = A.shape\n", " M, K = A.shape\n",
" K, N = B.shape\n", " K, N = B.shape\n",
...@@ -94,8 +95,8 @@ ...@@ -94,8 +95,8 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"A = torch.randn(1024, 512, dtype=torch.float16, device='cuda')\n", "A = torch.randn(1024, 512, dtype=torch.float16, device=\"cuda\")\n",
"B = torch.randn(512, 256, dtype=torch.float16, device='cuda')\n", "B = torch.randn(512, 256, dtype=torch.float16, device=\"cuda\")\n",
"C = gemm(A, B)\n", "C = gemm(A, B)\n",
"\n", "\n",
"# check output is correct\n", "# check output is correct\n",
...@@ -118,8 +119,8 @@ ...@@ -118,8 +119,8 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"A = torch.randn(1024, 512, dtype=torch.float16, device='cuda')\n", "A = torch.randn(1024, 512, dtype=torch.float16, device=\"cuda\")\n",
"B = torch.randn(512, 1024, dtype=torch.float16, device='cuda')\n", "B = torch.randn(512, 1024, dtype=torch.float16, device=\"cuda\")\n",
"C = gemm(A, B, block_M=64, block_N=64)" "C = gemm(A, B, block_M=64, block_N=64)"
] ]
}, },
...@@ -218,8 +219,8 @@ ...@@ -218,8 +219,8 @@
"source": [ "source": [
"@tilelang.lazy_jit\n", "@tilelang.lazy_jit\n",
"def gemm_dyn_K(\n", "def gemm_dyn_K(\n",
" A: T.Tensor[[int, T.dyn['K']], T.float16], # noqa: F821\n", " A: T.Tensor[[int, T.dyn[\"K\"]], T.float16], # noqa: F821\n",
" B: T.Tensor[[T.dyn['K'], int], T.float16], # noqa: F821\n", " B: T.Tensor[[T.dyn[\"K\"], int], T.float16], # noqa: F821\n",
"):\n", "):\n",
" M, K = A.shape\n", " M, K = A.shape\n",
" K, N = B.shape\n", " K, N = B.shape\n",
...@@ -265,8 +266,8 @@ ...@@ -265,8 +266,8 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"A = torch.randn(1024, 512, dtype=torch.float16, device='cuda')\n", "A = torch.randn(1024, 512, dtype=torch.float16, device=\"cuda\")\n",
"B = torch.randn(512, 256, dtype=torch.float16, device='cuda')\n", "B = torch.randn(512, 256, dtype=torch.float16, device=\"cuda\")\n",
"C = gemm_dyn_K(A, B)\n", "C = gemm_dyn_K(A, B)\n",
"C_ref = (A @ B).float()\n", "C_ref = (A @ B).float()\n",
"torch.testing.assert_close(C, C_ref, rtol=1e-2, atol=1e-2)" "torch.testing.assert_close(C, C_ref, rtol=1e-2, atol=1e-2)"
...@@ -295,18 +296,17 @@ ...@@ -295,18 +296,17 @@
"source": [ "source": [
"from typing import Any\n", "from typing import Any\n",
"\n", "\n",
"\n",
"@tilelang.lazy_jit\n", "@tilelang.lazy_jit\n",
"def as_contingious(\n", "def as_contingious(A: T.StridedTensor[[T.dyn, T.dyn], [T.dyn, T.dyn], Any]):\n",
" A: T.StridedTensor[[T.dyn, T.dyn], [T.dyn, T.dyn], Any]\n",
"):\n",
" M, N = A.shape\n", " M, N = A.shape\n",
" B = T.empty((M, N), A.dtype)\n", " B = T.empty((M, N), A.dtype)\n",
" block_M = 128\n", " block_M = 128\n",
" block_N = 128\n", " block_N = 128\n",
" with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), threads=128) as (bx, by):\n", " with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), threads=128) as (bx, by):\n",
" T.copy(\n", " T.copy(\n",
" A[bx * block_M: (bx + 1) * block_M, by * block_N: (by + 1) * block_N],\n", " A[bx * block_M : (bx + 1) * block_M, by * block_N : (by + 1) * block_N],\n",
" B[bx * block_M: (bx + 1) * block_M, by * block_N: (by + 1) * block_N]\n", " B[bx * block_M : (bx + 1) * block_M, by * block_N : (by + 1) * block_N],\n",
" )\n", " )\n",
" return B" " return B"
] ]
...@@ -318,7 +318,7 @@ ...@@ -318,7 +318,7 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"A = torch.randn(1024, 1024, device='cuda')\n", "A = torch.randn(1024, 1024, device=\"cuda\")\n",
"B = as_contingious(A[::2, ::2])\n", "B = as_contingious(A[::2, ::2])\n",
"B_ref = A[::2, ::2].contiguous()\n", "B_ref = A[::2, ::2].contiguous()\n",
"torch.testing.assert_close(B, B_ref)" "torch.testing.assert_close(B, B_ref)"
...@@ -370,8 +370,8 @@ ...@@ -370,8 +370,8 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"A = torch.randn(1024, 512, dtype=torch.float16, device='cuda')\n", "A = torch.randn(1024, 512, dtype=torch.float16, device=\"cuda\")\n",
"B = torch.randn(512, 256, dtype=torch.float16, device='cuda')\n", "B = torch.randn(512, 256, dtype=torch.float16, device=\"cuda\")\n",
"C = gemm_ptr(A, B, 1024, 256, 512)\n", "C = gemm_ptr(A, B, 1024, 256, 512)\n",
"C_ref = (A @ B).float()\n", "C_ref = (A @ B).float()\n",
"torch.testing.assert_close(C, C_ref, rtol=1e-2, atol=1e-2)" "torch.testing.assert_close(C, C_ref, rtol=1e-2, atol=1e-2)"
...@@ -416,8 +416,8 @@ ...@@ -416,8 +416,8 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"A = torch.randn(1024, 512, dtype=torch.float16, device='cuda')\n", "A = torch.randn(1024, 512, dtype=torch.float16, device=\"cuda\")\n",
"B = torch.randn(512, 256, dtype=torch.float16, device='cuda')\n", "B = torch.randn(512, 256, dtype=torch.float16, device=\"cuda\")\n",
"C = gemm_ptr_dyn(A, B, 1024, 256, 512)\n", "C = gemm_ptr_dyn(A, B, 1024, 256, 512)\n",
"C_ref = (A @ B).float()\n", "C_ref = (A @ B).float()\n",
"torch.testing.assert_close(C, C_ref, rtol=1e-2, atol=1e-2)" "torch.testing.assert_close(C, C_ref, rtol=1e-2, atol=1e-2)"
...@@ -496,18 +496,20 @@ ...@@ -496,18 +496,20 @@
"source": [ "source": [
"from itertools import product\n", "from itertools import product\n",
"\n", "\n",
"\n",
"def get_configs():\n", "def get_configs():\n",
" return [\n", " return [\n",
" {\n", " {\n",
" 'A': T.Tensor((1024, 1024), T.float32),\n", " \"A\": T.Tensor((1024, 1024), T.float32),\n",
" 'B': T.Tensor((1024, 1024), T.float32),\n", " \"B\": T.Tensor((1024, 1024), T.float32),\n",
" 'block_M': block_M,\n", " \"block_M\": block_M,\n",
" 'block_N': block_N,\n", " \"block_N\": block_N,\n",
" 'block_K': block_K,\n", " \"block_K\": block_K,\n",
" }\n", " }\n",
" for block_M, block_N, block_K in product([32, 64], repeat=3)\n", " for block_M, block_N, block_K in product([32, 64], repeat=3)\n",
" ]\n", " ]\n",
"\n", "\n",
"\n",
"gemm.par_compile(get_configs())" "gemm.par_compile(get_configs())"
] ]
}, },
...@@ -579,7 +581,8 @@ ...@@ -579,7 +581,8 @@
"source": [ "source": [
"@T.macro\n", "@T.macro\n",
"def macro_with_ref(x: T.Ref):\n", "def macro_with_ref(x: T.Ref):\n",
" x = 1 # noqa: F841\n", " x = 1 # noqa: F841\n",
"\n",
"\n", "\n",
"@T.prim_func\n", "@T.prim_func\n",
"def foo(x: T.Tensor((2,))):\n", "def foo(x: T.Tensor((2,))):\n",
...@@ -591,6 +594,7 @@ ...@@ -591,6 +594,7 @@
" idx = T.alloc_var(T.int32, 0)\n", " idx = T.alloc_var(T.int32, 0)\n",
" macro_with_ref(x[idx])\n", " macro_with_ref(x[idx])\n",
"\n", "\n",
"\n",
"foo" "foo"
] ]
}, },
...@@ -616,7 +620,7 @@ ...@@ -616,7 +620,7 @@
" A: T.Tensor[[T.dyn], Any],\n", " A: T.Tensor[[T.dyn], Any],\n",
" fn,\n", " fn,\n",
"):\n", "):\n",
" N, = A.shape\n", " (N,) = A.shape\n",
" B = T.empty((N,), dtype=A.dtype)\n", " B = T.empty((N,), dtype=A.dtype)\n",
" block_N = 128\n", " block_N = 128\n",
" with T.Kernel(T.ceildiv(N, block_N), threads=128) as bx:\n", " with T.Kernel(T.ceildiv(N, block_N), threads=128) as bx:\n",
...@@ -624,6 +628,8 @@ ...@@ -624,6 +628,8 @@
" idx = bx * block_N + i\n", " idx = bx * block_N + i\n",
" B[idx] = fn(A[idx])\n", " B[idx] = fn(A[idx])\n",
" return B\n", " return B\n",
"\n",
"\n",
"@T.macro\n", "@T.macro\n",
"def add_one(x):\n", "def add_one(x):\n",
" return x + 1" " return x + 1"
...@@ -636,7 +642,7 @@ ...@@ -636,7 +642,7 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"A = torch.randn(1024, device='cuda')\n", "A = torch.randn(1024, device=\"cuda\")\n",
"B = element_wise(A, add_one)\n", "B = element_wise(A, add_one)\n",
"B_ref = A + 1\n", "B_ref = A + 1\n",
"torch.testing.assert_close(B, B_ref)" "torch.testing.assert_close(B, B_ref)"
...@@ -670,10 +676,11 @@ ...@@ -670,10 +676,11 @@
" var = var * 3 + 1\n", " var = var * 3 + 1\n",
" n31(x * 3 + 1, var)\n", " n31(x * 3 + 1, var)\n",
"\n", "\n",
"\n",
"@tilelang.lazy_jit\n", "@tilelang.lazy_jit\n",
"def foo(A: T.Tensor[[1], T.int32], n: int):\n", "def foo(A: T.Tensor[[1], T.int32], n: int):\n",
" with T.Kernel(1) as _:\n", " with T.Kernel(1) as _:\n",
" n31(n, A[0])\n" " n31(n, A[0])"
] ]
}, },
{ {
...@@ -694,7 +701,7 @@ ...@@ -694,7 +701,7 @@
} }
], ],
"source": [ "source": [
"A = torch.tensor([100], dtype=torch.int32, device='cuda')\n", "A = torch.tensor([100], dtype=torch.int32, device=\"cuda\")\n",
"foo(A, 5)\n", "foo(A, 5)\n",
"A" "A"
] ]
...@@ -745,12 +752,15 @@ ...@@ -745,12 +752,15 @@
"def sincos(x):\n", "def sincos(x):\n",
" return T.sin(x), T.cos(x)\n", " return T.sin(x), T.cos(x)\n",
"\n", "\n",
"\n",
"@T.prim_func\n", "@T.prim_func\n",
"def foo():\n", "def foo():\n",
" with T.Kernel(32) as x:\n", " with T.Kernel(32) as x:\n",
" s, c = sincos(x)\n", " s, c = sincos(x)\n",
" a = s + c # noqa: F841\n", " a = s + c # noqa: F841\n",
" b = s - c # noqa: F841\n", " b = s - c # noqa: F841\n",
"\n",
"\n",
"foo" "foo"
] ]
} }
......
...@@ -13,20 +13,20 @@ from typing import Optional, Tuple ...@@ -13,20 +13,20 @@ from typing import Optional, Tuple
pass_configs={ pass_configs={
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
-    })
+    }
+)
def tl_fused_chunk_bwd_kernel( def tl_fused_chunk_bwd_kernel(
B, B,
S, S,
H, H,
DK, DK,
DV, DV,
dtype: str = 'float16', dtype: str = "float16",
scale: float = None, scale: float = None,
) -> torch.Tensor: ) -> torch.Tensor:
if scale is None: if scale is None:
scale = DK**-0.5 scale = DK**-0.5
accum_dtype = 'float' accum_dtype = "float"
chunk_size = 64 chunk_size = 64
BK = BV = 64 # Set to 128 can be faster, but has some numerical differences with FLA BK = BV = 64 # Set to 128 can be faster, but has some numerical differences with FLA
...@@ -37,13 +37,13 @@ def tl_fused_chunk_bwd_kernel( ...@@ -37,13 +37,13 @@ def tl_fused_chunk_bwd_kernel(
@T.prim_func @T.prim_func
def fused_chunk_linear_attn_bwd( def fused_chunk_linear_attn_bwd(
Q: T.Tensor([B, S, H, DK], dtype), # type: ignore Q: T.Tensor([B, S, H, DK], dtype), # type: ignore
K: T.Tensor([B, S, H, DK], dtype), # type: ignore K: T.Tensor([B, S, H, DK], dtype), # type: ignore
V: T.Tensor([B, S, H, DV], dtype), # type: ignore V: T.Tensor([B, S, H, DV], dtype), # type: ignore
dO: T.Tensor([B, S, H, DV], dtype), # type: ignore dO: T.Tensor([B, S, H, DV], dtype), # type: ignore
dQ: T.Tensor([B, S, H, DK], accum_dtype), # type: ignore dQ: T.Tensor([B, S, H, DK], accum_dtype), # type: ignore
dK: T.Tensor([B, S, H, DK], accum_dtype), # type: ignore dK: T.Tensor([B, S, H, DK], accum_dtype), # type: ignore
dV: T.Tensor([B, S, H, DV], accum_dtype), # type: ignore dV: T.Tensor([B, S, H, DV], accum_dtype), # type: ignore
): ):
with T.Kernel(NV, NK, B * H) as (i_v, i_k, i_bh): with T.Kernel(NV, NK, B * H) as (i_v, i_k, i_bh):
i_b = i_bh // H i_b = i_bh // H
...@@ -66,11 +66,13 @@ def tl_fused_chunk_bwd_kernel( ...@@ -66,11 +66,13 @@ def tl_fused_chunk_bwd_kernel(
dh = T.alloc_fragment([BK, BV], accum_dtype) dh = T.alloc_fragment([BK, BV], accum_dtype)
dh_shared = T.alloc_shared([BK, BV], dtype) dh_shared = T.alloc_shared([BK, BV], dtype)
-            T.annotate_layout({
-                dq_shared: tilelang.layout.make_swizzled_layout(dq_shared),
-                dk_shared: tilelang.layout.make_swizzled_layout(dk_shared),
-                dv_shared: tilelang.layout.make_swizzled_layout(dv_shared)
-            })
+            T.annotate_layout(
+                {
+                    dq_shared: tilelang.layout.make_swizzled_layout(dq_shared),
+                    dk_shared: tilelang.layout.make_swizzled_layout(dk_shared),
+                    dv_shared: tilelang.layout.make_swizzled_layout(dv_shared),
+                }
+            )
T.use_swizzle(10) T.use_swizzle(10)
T.clear(h) T.clear(h)
...@@ -78,10 +80,9 @@ def tl_fused_chunk_bwd_kernel( ...@@ -78,10 +80,9 @@ def tl_fused_chunk_bwd_kernel(
# Calculate dQ # Calculate dQ
for i in T.Pipelined(0, NT): for i in T.Pipelined(0, NT):
T.copy(K[i_b, i * chunk_size:(i + 1) * chunk_size, i_h, i_k * BK:(i_k + 1) * BK], k) T.copy(K[i_b, i * chunk_size : (i + 1) * chunk_size, i_h, i_k * BK : (i_k + 1) * BK], k)
T.copy(V[i_b, i * chunk_size:(i + 1) * chunk_size, i_h, i_v * BV:(i_v + 1) * BV], v) T.copy(V[i_b, i * chunk_size : (i + 1) * chunk_size, i_h, i_v * BV : (i_v + 1) * BV], v)
-                T.copy(dO[i_b, i * chunk_size:(i + 1) * chunk_size, i_h, i_v * BV:(i_v + 1) * BV],
-                       do)
+                T.copy(dO[i_b, i * chunk_size : (i + 1) * chunk_size, i_h, i_v * BV : (i_v + 1) * BV], do)
T.gemm(do, v, ds, transpose_B=True, clear_accum=True) T.gemm(do, v, ds, transpose_B=True, clear_accum=True)
for row, col in T.Parallel(chunk_size, chunk_size): for row, col in T.Parallel(chunk_size, chunk_size):
...@@ -94,29 +95,19 @@ def tl_fused_chunk_bwd_kernel( ...@@ -94,29 +95,19 @@ def tl_fused_chunk_bwd_kernel(
for row, col in T.Parallel(chunk_size, BK): for row, col in T.Parallel(chunk_size, BK):
dq[row, col] *= scale dq[row, col] *= scale
T.copy(dq, dq_shared) T.copy(dq, dq_shared)
-                T.atomic_add(
-                    dQ[i_b, i * chunk_size:(i + 1) * chunk_size, i_h, i_k * BK:(i_k + 1) * BK],
-                    dq_shared)
+                T.atomic_add(dQ[i_b, i * chunk_size : (i + 1) * chunk_size, i_h, i_k * BK : (i_k + 1) * BK], dq_shared)
# Calculate dK, dV (reversely) # Calculate dK, dV (reversely)
for i in T.Pipelined(1, NT + 1): for i in T.Pipelined(1, NT + 1):
start = NT - i start = NT - i
for row, col in T.Parallel(chunk_size, BK): for row, col in T.Parallel(chunk_size, BK):
q[row, col] = Q[i_b, start * chunk_size + row, i_h, i_k * BK + col] * scale q[row, col] = Q[i_b, start * chunk_size + row, i_h, i_k * BK + col] * scale
-                T.copy(
-                    K[i_b, start * chunk_size:(start + 1) * chunk_size, i_h,
-                      i_k * BK:(i_k + 1) * BK], k)
-                T.copy(
-                    V[i_b, start * chunk_size:(start + 1) * chunk_size, i_h,
-                      i_v * BV:(i_v + 1) * BV], v)
-                T.copy(
-                    dO[i_b, start * chunk_size:(start + 1) * chunk_size, i_h,
-                       i_v * BV:(i_v + 1) * BV], do)
+                T.copy(K[i_b, start * chunk_size : (start + 1) * chunk_size, i_h, i_k * BK : (i_k + 1) * BK], k)
+                T.copy(V[i_b, start * chunk_size : (start + 1) * chunk_size, i_h, i_v * BV : (i_v + 1) * BV], v)
+                T.copy(dO[i_b, start * chunk_size : (start + 1) * chunk_size, i_h, i_v * BV : (i_v + 1) * BV], do)
# Calculate dk # Calculate dk
-                T.gemm(
-                    v, do, ds, transpose_B=True, clear_accum=True
-                )  # ds here actually means `s`, but we simply reuse the buffer `ds`
+                T.gemm(v, do, ds, transpose_B=True, clear_accum=True)  # ds here actually means `s`, but we simply reuse the buffer `ds`
for row, col in T.Parallel(chunk_size, chunk_size): for row, col in T.Parallel(chunk_size, chunk_size):
ds_shared[row, col] = T.if_then_else(row <= col, ds[row, col], 0) ds_shared[row, col] = T.if_then_else(row <= col, ds[row, col], 0)
T.gemm(ds_shared, q, dk, clear_accum=True) T.gemm(ds_shared, q, dk, clear_accum=True)
...@@ -134,13 +125,9 @@ def tl_fused_chunk_bwd_kernel( ...@@ -134,13 +125,9 @@ def tl_fused_chunk_bwd_kernel(
T.gemm(q, do, dh, transpose_A=True) T.gemm(q, do, dh, transpose_A=True)
T.copy(dk, dk_shared) T.copy(dk, dk_shared)
-                T.atomic_add(
-                    dK[i_b, start * chunk_size:(start + 1) * chunk_size, i_h,
-                       i_k * BK:(i_k + 1) * BK], dk_shared)
+                T.atomic_add(dK[i_b, start * chunk_size : (start + 1) * chunk_size, i_h, i_k * BK : (i_k + 1) * BK], dk_shared)
T.copy(dv, dv_shared) T.copy(dv, dv_shared)
-                T.atomic_add(
-                    dV[i_b, start * chunk_size:(start + 1) * chunk_size, i_h,
-                       i_v * BV:(i_v + 1) * BV], dv_shared)
+                T.atomic_add(dV[i_b, start * chunk_size : (start + 1) * chunk_size, i_h, i_v * BV : (i_v + 1) * BV], dv_shared)
return fused_chunk_linear_attn_bwd return fused_chunk_linear_attn_bwd
...@@ -155,34 +142,31 @@ def tl_fused_chunk_bwd(Q, K, V, dO): ...@@ -155,34 +142,31 @@ def tl_fused_chunk_bwd(Q, K, V, dO):
return dQ.to(torch.float16), dK.to(torch.float16), dV.to(torch.float16) return dQ.to(torch.float16), dK.to(torch.float16), dV.to(torch.float16)
-def ref_program(q: torch.Tensor,
-                k: torch.Tensor,
-                v: torch.Tensor,
-                scale: Optional[float] = None) -> Tuple[torch.Tensor, torch.Tensor]:
+def ref_program(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, scale: Optional[float] = None) -> Tuple[torch.Tensor, torch.Tensor]:
q, k, v = q.float(), k.float(), v.float() q, k, v = q.float(), k.float(), v.float()
if scale is None: if scale is None:
scale = q.shape[-1]**-0.5 scale = q.shape[-1] ** -0.5
chunk_size = 64 chunk_size = 64
q = rearrange(q, 'b (n c) h d -> b h n c d', c=chunk_size) * scale q = rearrange(q, "b (n c) h d -> b h n c d", c=chunk_size) * scale
k = rearrange(k, 'b (n c) h d -> b h n c d', c=chunk_size) k = rearrange(k, "b (n c) h d -> b h n c d", c=chunk_size)
v = rearrange(v, 'b (n c) h d -> b h n c d', c=chunk_size) v = rearrange(v, "b (n c) h d -> b h n c d", c=chunk_size)
kv = k.transpose(-1, -2) @ v kv = k.transpose(-1, -2) @ v
kv = kv.cumsum(2) kv = kv.cumsum(2)
h = kv[:, :, -1, :, :] h = kv[:, :, -1, :, :]
kv = torch.cat([torch.zeros_like(kv[:, :, :1]), kv[:, :, :-1]], dim=2) kv = torch.cat([torch.zeros_like(kv[:, :, :1]), kv[:, :, :-1]], dim=2)
inter = q @ kv inter = q @ kv
-    intra = ((q @ k.transpose(-1, -2)).masked_fill_(
-        torch.triu(torch.ones(chunk_size, chunk_size, dtype=bool, device=q.device), diagonal=1),
-        0)) @ v
+    intra = (
+        (q @ k.transpose(-1, -2)).masked_fill_(torch.triu(torch.ones(chunk_size, chunk_size, dtype=bool, device=q.device), diagonal=1), 0)
+    ) @ v
o = inter + intra o = inter + intra
return rearrange(o, 'b h n c d -> b (n c) h d'), h return rearrange(o, "b h n c d -> b (n c) h d"), h
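
What ref_program implements is the standard chunked form of (non-normalized) linear attention: a running state of K^T V is accumulated chunk by chunk, and each chunk's output is an inter-chunk term against the previous state plus a causally masked intra-chunk term. Restating the code above (M is the lower-triangular mask over positions within a chunk, and Q already carries the scale):

    S_0 = 0,        S_i = S_{i-1} + K_{[i]}^T V_{[i]}
    O_{[i]} = Q_{[i]} S_{i-1} + ((Q_{[i]} K_{[i]}^T) ⊙ M) V_{[i]}

The second return value h is the final state S_N, matching output_final_state=True in the Triton baseline used below.
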
def main(B=1, S=1024, H=16, D=128): def main(B=1, S=1024, H=16, D=128):
q = torch.randn((B, S, H, D), device='cuda', dtype=torch.float16, requires_grad=True) q = torch.randn((B, S, H, D), device="cuda", dtype=torch.float16, requires_grad=True)
k = torch.randn((B, S, H, D), device='cuda', dtype=torch.float16, requires_grad=True) k = torch.randn((B, S, H, D), device="cuda", dtype=torch.float16, requires_grad=True)
v = torch.randn((B, S, H, D), device='cuda', dtype=torch.float16, requires_grad=True) v = torch.randn((B, S, H, D), device="cuda", dtype=torch.float16, requires_grad=True)
do = torch.randn((B, S, H, D), device='cuda', dtype=torch.float16) do = torch.randn((B, S, H, D), device="cuda", dtype=torch.float16)
# qk norm is necessary for linear attn # qk norm is necessary for linear attn
q = l2norm_fwd(q)[0].requires_grad_(True) q = l2norm_fwd(q)[0].requires_grad_(True)
...@@ -193,30 +177,27 @@ def main(B=1, S=1024, H=16, D=128): ...@@ -193,30 +177,27 @@ def main(B=1, S=1024, H=16, D=128):
o_ref, _ = ref_program(q, k, v) o_ref, _ = ref_program(q, k, v)
o_ref.backward(do, retain_graph=True) o_ref.backward(do, retain_graph=True)
-    assert torch.allclose(
-        dq, q.grad, atol=1e-2, rtol=1e-2), f'dq max err: {(dq - q.grad).abs().max()}'
-    assert torch.allclose(
-        dk, k.grad, atol=1e-2, rtol=1e-2), f'dk max err: {(dk - k.grad).abs().max()}'
-    assert torch.allclose(
-        dv, v.grad, atol=1e-2, rtol=1e-2), f'dv max err: {(dv - v.grad).abs().max()}'
-    print('Passed all tests!✅')
+    assert torch.allclose(dq, q.grad, atol=1e-2, rtol=1e-2), f"dq max err: {(dq - q.grad).abs().max()}"
+    assert torch.allclose(dk, k.grad, atol=1e-2, rtol=1e-2), f"dk max err: {(dk - k.grad).abs().max()}"
+    assert torch.allclose(dv, v.grad, atol=1e-2, rtol=1e-2), f"dv max err: {(dv - v.grad).abs().max()}"
+    print("Passed all tests!✅")
# Benchmark # Benchmark
q.grad = k.grad = v.grad = None q.grad = k.grad = v.grad = None
o_ref, _ = fused_chunk_linear_attn(q, k, v, output_final_state=True, normalize=False) o_ref, _ = fused_chunk_linear_attn(q, k, v, output_final_state=True, normalize=False)
t1 = do_bench(lambda: o_ref.backward(do, retain_graph=True), backend='cupti') t1 = do_bench(lambda: o_ref.backward(do, retain_graph=True), backend="cupti")
t2 = do_bench(lambda: tl_fused_chunk_bwd(q, k, v, do), backend='cupti') t2 = do_bench(lambda: tl_fused_chunk_bwd(q, k, v, do), backend="cupti")
print(f'Triton latency: {t1:.3f} ms') print(f"Triton latency: {t1:.3f} ms")
print(f'TileLang latency: {t2:.3f} ms') print(f"TileLang latency: {t2:.3f} ms")
print(f'Speedup: {t1/t2:.3f}x') print(f"Speedup: {t1 / t2:.3f}x")
if __name__ == '__main__': if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--B', type=int, default=8, help='Batch size') parser.add_argument("--B", type=int, default=8, help="Batch size")
parser.add_argument('--S', type=int, default=1024, help='Seq len') parser.add_argument("--S", type=int, default=1024, help="Seq len")
parser.add_argument('--H', type=int, default=32, help='Num heads') parser.add_argument("--H", type=int, default=32, help="Num heads")
parser.add_argument('--D', type=int, default=128, help='Head dim') parser.add_argument("--D", type=int, default=128, help="Head dim")
args = parser.parse_args() args = parser.parse_args()
main(args.B, args.S, args.H, args.D) main(args.B, args.S, args.H, args.D)
...@@ -14,20 +14,20 @@ from typing import Optional, Tuple ...@@ -14,20 +14,20 @@ from typing import Optional, Tuple
pass_configs={ pass_configs={
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
-    })
+    },
+)
def tl_fused_chunk_fwd_kernel( def tl_fused_chunk_fwd_kernel(
B, B,
S, S,
H, H,
DK, DK,
DV, DV,
dtype: str = 'float16', dtype: str = "float16",
scale: float = None, scale: float = None,
) -> torch.Tensor: ) -> torch.Tensor:
if scale is None: if scale is None:
scale = DK**-0.5 scale = DK**-0.5
accum_dtype = 'float' accum_dtype = "float"
chunk_size = 64 chunk_size = 64
BK = BV = 64 # Set to 128 can be faster, but has some numerical differences with FLA BK = BV = 64 # Set to 128 can be faster, but has some numerical differences with FLA
...@@ -38,11 +38,12 @@ def tl_fused_chunk_fwd_kernel( ...@@ -38,11 +38,12 @@ def tl_fused_chunk_fwd_kernel(
@T.prim_func @T.prim_func
def fused_chunk_linear_attn_fwd( def fused_chunk_linear_attn_fwd(
Q: T.Tensor([B, S, H, DK], dtype), # type: ignore Q: T.Tensor([B, S, H, DK], dtype), # type: ignore
K: T.Tensor([B, S, H, DK], dtype), # type: ignore K: T.Tensor([B, S, H, DK], dtype), # type: ignore
V: T.Tensor([B, S, H, DV], dtype), # type: ignore V: T.Tensor([B, S, H, DV], dtype), # type: ignore
O: T.Tensor([B, S, H, DV], accum_dtype), # type: ignore O: T.Tensor([B, S, H, DV], accum_dtype), # type: ignore
-            final_state: T.Tensor([B, H, DK, DV], accum_dtype)):  # type: ignore
+        final_state: T.Tensor([B, H, DK, DV], accum_dtype),
+    ):  # type: ignore
with T.Kernel(NV, NK, B * H) as (i_v, i_k, i_bh): with T.Kernel(NV, NK, B * H) as (i_v, i_k, i_bh):
i_b = i_bh // H i_b = i_bh // H
i_h = i_bh % H i_h = i_bh % H
...@@ -65,8 +66,8 @@ def tl_fused_chunk_fwd_kernel( ...@@ -65,8 +66,8 @@ def tl_fused_chunk_fwd_kernel(
for i in T.Pipelined(0, NT): for i in T.Pipelined(0, NT):
for row, col in T.Parallel(chunk_size, BK): for row, col in T.Parallel(chunk_size, BK):
q[row, col] = Q[i_b, i * chunk_size + row, i_h, i_k * BK + col] * scale q[row, col] = Q[i_b, i * chunk_size + row, i_h, i_k * BK + col] * scale
T.copy(K[i_b, i * chunk_size:(i + 1) * chunk_size, i_h, i_k * BK:(i_k + 1) * BK], k) T.copy(K[i_b, i * chunk_size : (i + 1) * chunk_size, i_h, i_k * BK : (i_k + 1) * BK], k)
T.copy(V[i_b, i * chunk_size:(i + 1) * chunk_size, i_h, i_v * BV:(i_v + 1) * BV], v) T.copy(V[i_b, i * chunk_size : (i + 1) * chunk_size, i_h, i_v * BV : (i_v + 1) * BV], v)
T.gemm(q, k, s, clear_accum=True, transpose_B=True) T.gemm(q, k, s, clear_accum=True, transpose_B=True)
for row, col in T.Parallel(chunk_size, chunk_size): for row, col in T.Parallel(chunk_size, chunk_size):
...@@ -77,12 +78,10 @@ def tl_fused_chunk_fwd_kernel( ...@@ -77,12 +78,10 @@ def tl_fused_chunk_fwd_kernel(
T.gemm(k, v, h, transpose_A=True) T.gemm(k, v, h, transpose_A=True)
T.gemm(q, h_shared, o) T.gemm(q, h_shared, o)
T.copy(o, o_shared) T.copy(o, o_shared)
-                T.atomic_add(
-                    O[i_b, i * chunk_size:(i + 1) * chunk_size, i_h, i_v * BV:(i_v + 1) * BV],
-                    o_shared)
+                T.atomic_add(O[i_b, i * chunk_size : (i + 1) * chunk_size, i_h, i_v * BV : (i_v + 1) * BV], o_shared)
# Output final state # Output final state
T.copy(h, final_state[i_b, i_h, i_k * BK:(i_k + 1) * BK, i_v * BV:(i_v + 1) * BV]) T.copy(h, final_state[i_b, i_h, i_k * BK : (i_k + 1) * BK, i_v * BV : (i_v + 1) * BV])
return fused_chunk_linear_attn_fwd return fused_chunk_linear_attn_fwd
...@@ -91,38 +90,35 @@ def tl_fused_chunk_fwd(q, k, v): ...@@ -91,38 +90,35 @@ def tl_fused_chunk_fwd(q, k, v):
B, S, H, D = q.shape B, S, H, D = q.shape
kernel = tl_fused_chunk_fwd_kernel(B, S, H, D, D) kernel = tl_fused_chunk_fwd_kernel(B, S, H, D, D)
print(kernel.get_kernel_source()) print(kernel.get_kernel_source())
o = torch.zeros((B, S, H, D), device='cuda', dtype=torch.float32) o = torch.zeros((B, S, H, D), device="cuda", dtype=torch.float32)
h = kernel(q, k, v, o) h = kernel(q, k, v, o)
return o, h return o, h
-def ref_program(q: torch.Tensor,
-                k: torch.Tensor,
-                v: torch.Tensor,
-                scale: Optional[float] = None) -> Tuple[torch.Tensor, torch.Tensor]:
+def ref_program(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, scale: Optional[float] = None) -> Tuple[torch.Tensor, torch.Tensor]:
q, k, v = q.float(), k.float(), v.float() q, k, v = q.float(), k.float(), v.float()
if scale is None: if scale is None:
scale = q.shape[-1]**-0.5 scale = q.shape[-1] ** -0.5
chunk_size = 64 chunk_size = 64
q = rearrange(q, 'b (n c) h d -> b h n c d', c=chunk_size) * scale q = rearrange(q, "b (n c) h d -> b h n c d", c=chunk_size) * scale
k = rearrange(k, 'b (n c) h d -> b h n c d', c=chunk_size) k = rearrange(k, "b (n c) h d -> b h n c d", c=chunk_size)
v = rearrange(v, 'b (n c) h d -> b h n c d', c=chunk_size) v = rearrange(v, "b (n c) h d -> b h n c d", c=chunk_size)
kv = k.transpose(-1, -2) @ v kv = k.transpose(-1, -2) @ v
kv = kv.cumsum(2) kv = kv.cumsum(2)
h = kv[:, :, -1, :, :] h = kv[:, :, -1, :, :]
kv = torch.cat([torch.zeros_like(kv[:, :, :1]), kv[:, :, :-1]], dim=2) kv = torch.cat([torch.zeros_like(kv[:, :, :1]), kv[:, :, :-1]], dim=2)
inter = q @ kv inter = q @ kv
-    intra = ((q @ k.transpose(-1, -2)).masked_fill_(
-        torch.triu(torch.ones(chunk_size, chunk_size, dtype=bool, device=q.device), diagonal=1),
-        0)) @ v
+    intra = (
+        (q @ k.transpose(-1, -2)).masked_fill_(torch.triu(torch.ones(chunk_size, chunk_size, dtype=bool, device=q.device), diagonal=1), 0)
+    ) @ v
o = inter + intra o = inter + intra
return rearrange(o, 'b h n c d -> b (n c) h d'), h return rearrange(o, "b h n c d -> b (n c) h d"), h
def main(B=1, S=512, H=16, D=128): def main(B=1, S=512, H=16, D=128):
q = torch.randn((B, S, H, D), device='cuda', dtype=torch.float16) q = torch.randn((B, S, H, D), device="cuda", dtype=torch.float16)
k = torch.randn((B, S, H, D), device='cuda', dtype=torch.float16) k = torch.randn((B, S, H, D), device="cuda", dtype=torch.float16)
v = torch.randn((B, S, H, D), device='cuda', dtype=torch.float16) v = torch.randn((B, S, H, D), device="cuda", dtype=torch.float16)
# qk norm is necessary for linear attn # qk norm is necessary for linear attn
q, _ = l2norm_fwd(q) q, _ = l2norm_fwd(q)
...@@ -131,25 +127,23 @@ def main(B=1, S=512, H=16, D=128): ...@@ -131,25 +127,23 @@ def main(B=1, S=512, H=16, D=128):
o, h = tl_fused_chunk_fwd(q, k, v) o, h = tl_fused_chunk_fwd(q, k, v)
o_ref, h_ref = ref_program(q, k, v) o_ref, h_ref = ref_program(q, k, v)
assert torch.allclose(o, o_ref, atol=1e-2, rtol=1e-2), f'o max err: {(o - o_ref).abs().max()}' assert torch.allclose(o, o_ref, atol=1e-2, rtol=1e-2), f"o max err: {(o - o_ref).abs().max()}"
assert torch.allclose(h, h_ref, atol=1e-2, rtol=1e-2), f'h max err: {(h - h_ref).abs().max()}' assert torch.allclose(h, h_ref, atol=1e-2, rtol=1e-2), f"h max err: {(h - h_ref).abs().max()}"
print('Passed all tests!✅') print("Passed all tests!✅")
-    t1 = do_bench(
-        lambda: fused_chunk_linear_attn(q, k, v, output_final_state=True, normalize=False),
-        backend='cupti')
-    t2 = do_bench(lambda: tl_fused_chunk_fwd(q, k, v), backend='cupti')
-    print(f'Triton latency: {t1:.3f} ms')
-    print(f'TileLang latency: {t2:.3f} ms')
-    print(f'Speedup: {t1/t2:.3f}x')
+    t1 = do_bench(lambda: fused_chunk_linear_attn(q, k, v, output_final_state=True, normalize=False), backend="cupti")
+    t2 = do_bench(lambda: tl_fused_chunk_fwd(q, k, v), backend="cupti")
+    print(f"Triton latency: {t1:.3f} ms")
+    print(f"TileLang latency: {t2:.3f} ms")
+    print(f"Speedup: {t1 / t2:.3f}x")
if __name__ == '__main__': if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--B', type=int, default=8, help='Batch size') parser.add_argument("--B", type=int, default=8, help="Batch size")
parser.add_argument('--S', type=int, default=1024, help='Seq len') parser.add_argument("--S", type=int, default=1024, help="Seq len")
parser.add_argument('--H', type=int, default=32, help='Num heads') parser.add_argument("--H", type=int, default=32, help="Num heads")
parser.add_argument('--D', type=int, default=128, help='Head dim') parser.add_argument("--D", type=int, default=128, help="Head dim")
args = parser.parse_args() args = parser.parse_args()
main(args.B, args.S, args.H, args.D) main(args.B, args.S, args.H, args.D)
...@@ -9,6 +9,7 @@ import itertools ...@@ -9,6 +9,7 @@ import itertools
def chunk_scan_triton(cb, x, dt, dA_cumsum, C, states, D): def chunk_scan_triton(cb, x, dt, dA_cumsum, C, states, D):
from mamba_ssm.ops.triton.ssd_chunk_scan import _chunk_scan_fwd from mamba_ssm.ops.triton.ssd_chunk_scan import _chunk_scan_fwd
out, _ = _chunk_scan_fwd(cb, x, dt, dA_cumsum, C, states, D) out, _ = _chunk_scan_fwd(cb, x, dt, dA_cumsum, C, states, D)
return out return out
...@@ -43,14 +44,15 @@ def ref_program(cb, x, dt, dA_cumsum, C, prev_states, D): ...@@ -43,14 +44,15 @@ def ref_program(cb, x, dt, dA_cumsum, C, prev_states, D):
dt_segment_sum = dA_cumsum[:, :, :, :, None] - dA_cumsum[:, :, :, None, :] dt_segment_sum = dA_cumsum[:, :, :, :, None] - dA_cumsum[:, :, :, None, :]
decay = torch.exp(dt_segment_sum) decay = torch.exp(dt_segment_sum)
scores_decay = cb * rearrange(decay, "b h c l s -> b c h l s") scores_decay = cb * rearrange(decay, "b h c l s -> b c h l s")
-    causal_mask = torch.tril(
-        torch.ones(chunk_size, chunk_size, device=x.device, dtype=bool), diagonal=0)
+    causal_mask = torch.tril(torch.ones(chunk_size, chunk_size, device=x.device, dtype=bool), diagonal=0)
scores_decay = scores_decay.masked_fill(~causal_mask, 0) scores_decay = scores_decay.masked_fill(~causal_mask, 0)
-    out = torch.einsum('bchls,bhcs,bcshp->bclhp', scores_decay.to(x.dtype), dt.to(x.dtype),
-                       rearrange(x, "b (c s) h p -> b c s h p", c=nchunks))
+    out = torch.einsum(
+        "bchls,bhcs,bcshp->bclhp", scores_decay.to(x.dtype), dt.to(x.dtype), rearrange(x, "b (c s) h p -> b c s h p", c=nchunks)
+    )
state_decay_out = torch.exp(rearrange(dA_cumsum, "b h c l -> b c l h 1")) state_decay_out = torch.exp(rearrange(dA_cumsum, "b h c l -> b c l h 1"))
-    out_prev = torch.einsum('bclhn,bchpn->bclhp', rearrange(
-        C, "b (c l) h n -> b c l h n", c=nchunks), prev_states.to(C.dtype)) * state_decay_out
+    out_prev = (
+        torch.einsum("bclhn,bchpn->bclhp", rearrange(C, "b (c l) h n -> b c l h n", c=nchunks), prev_states.to(C.dtype)) * state_decay_out
+    )
out = out + out_prev out = out + out_prev
out = rearrange(out, "b c l h p -> b (c l) h p") out = rearrange(out, "b c l h p -> b (c l) h p")
if D is not None: if D is not None:
...@@ -61,12 +63,7 @@ def ref_program(cb, x, dt, dA_cumsum, C, prev_states, D): ...@@ -61,12 +63,7 @@ def ref_program(cb, x, dt, dA_cumsum, C, prev_states, D):
def get_configs(): def get_configs():
-    iter_params = dict(
-        block_M=[64, 128, 256],
-        block_N=[32, 64],
-        block_K=[64, 128, 256],
-        block_Dstate=[128],
-        num_stages=[1, 2, 3, 4, 5])
+    iter_params = dict(block_M=[64, 128, 256], block_N=[32, 64], block_K=[64, 128, 256], block_Dstate=[128], num_stages=[1, 2, 3, 4, 5])
return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())] return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())]
...@@ -77,19 +74,21 @@ def get_configs(): ...@@ -77,19 +74,21 @@ def get_configs():
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
}, },
) )
-def chunk_scan_fwd(batch,
-                   seqlen,
-                   chunk_size,
-                   ngroups,
-                   nheads,
-                   headdim,
-                   dstate,
-                   block_M=64,
-                   block_N=64,
-                   block_K=64,
-                   block_Dstate=128,
-                   num_stages=2,
-                   threads=128):
+def chunk_scan_fwd(
+    batch,
+    seqlen,
+    chunk_size,
+    ngroups,
+    nheads,
+    headdim,
+    dstate,
+    block_M=64,
+    block_N=64,
+    block_K=64,
+    block_Dstate=128,
+    num_stages=2,
+    threads=128,
+):
dtype = "float16" dtype = "float16"
accum_dtype = "float" accum_dtype = "float"
nchunks = T.ceildiv(seqlen, chunk_size) nchunks = T.ceildiv(seqlen, chunk_size)
...@@ -97,20 +96,20 @@ def chunk_scan_fwd(batch, ...@@ -97,20 +96,20 @@ def chunk_scan_fwd(batch,
@T.prim_func @T.prim_func
def main( def main(
cb: T.Tensor((batch, nchunks, ngroups, chunk_size, chunk_size), dtype), # type: ignore cb: T.Tensor((batch, nchunks, ngroups, chunk_size, chunk_size), dtype), # type: ignore
x: T.Tensor((batch, seqlen, nheads, headdim), dtype), # type: ignore x: T.Tensor((batch, seqlen, nheads, headdim), dtype), # type: ignore
dt: T.Tensor((batch, nheads, nchunks, chunk_size), dtype), # type: ignore dt: T.Tensor((batch, nheads, nchunks, chunk_size), dtype), # type: ignore
dA_cumsum: T.Tensor((batch, nheads, nchunks, chunk_size), dtype), # type: ignore dA_cumsum: T.Tensor((batch, nheads, nchunks, chunk_size), dtype), # type: ignore
C: T.Tensor((batch, seqlen, ngroups, dstate), dtype), # type: ignore C: T.Tensor((batch, seqlen, ngroups, dstate), dtype), # type: ignore
prev_states: T.Tensor((batch, nchunks, nheads, headdim, dstate), dtype), # type: ignore prev_states: T.Tensor((batch, nchunks, nheads, headdim, dstate), dtype), # type: ignore
D: T.Tensor((nheads), dtype), # type: ignore D: T.Tensor((nheads), dtype), # type: ignore
Output: T.Tensor((batch, seqlen, nheads, headdim), dtype) # type: ignore Output: T.Tensor((batch, seqlen, nheads, headdim), dtype), # type: ignore
): ):
-        with T.Kernel(
-                nheads,
-                T.ceildiv(chunk_size, block_M) * T.ceildiv(headdim, block_N),
-                batch * nchunks,
-                threads=threads) as (bz, bx, by):
+        with T.Kernel(nheads, T.ceildiv(chunk_size, block_M) * T.ceildiv(headdim, block_N), batch * nchunks, threads=threads) as (
+            bz,
+            bx,
+            by,
+        ):
acc_o = T.alloc_fragment((block_M, block_N), accum_dtype) acc_o = T.alloc_fragment((block_M, block_N), accum_dtype)
acc_o_shared = T.alloc_shared((block_M, block_N), dtype) acc_o_shared = T.alloc_shared((block_M, block_N), dtype)
cb_shared = T.alloc_shared((block_M, block_K), dtype, scope="shared.dyn") cb_shared = T.alloc_shared((block_M, block_K), dtype, scope="shared.dyn")
...@@ -136,27 +135,32 @@ def chunk_scan_fwd(batch, ...@@ -136,27 +135,32 @@ def chunk_scan_fwd(batch,
m_idx = bx // T.ceildiv(headdim, block_N) m_idx = bx // T.ceildiv(headdim, block_N)
n_idx = bx % T.ceildiv(headdim, block_N) n_idx = bx % T.ceildiv(headdim, block_N)
-            T.annotate_layout({
-                acc_o_shared: tilelang.layout.make_swizzled_layout(acc_o_shared),
-                cb_shared: tilelang.layout.make_swizzled_layout(cb_shared),
-                x_residual_shared: tilelang.layout.make_swizzled_layout(x_residual_shared)
-            })
+            T.annotate_layout(
+                {
+                    acc_o_shared: tilelang.layout.make_swizzled_layout(acc_o_shared),
+                    cb_shared: tilelang.layout.make_swizzled_layout(cb_shared),
+                    x_residual_shared: tilelang.layout.make_swizzled_layout(x_residual_shared),
+                }
+            )
T.no_set_max_nreg() T.no_set_max_nreg()
-            T.copy(dA_cumsum[batch_idx, bz, chunk_idx, m_idx * block_M:(m_idx + 1) * block_M],
-                   dA_cs_m_shared)
+            T.copy(dA_cumsum[batch_idx, bz, chunk_idx, m_idx * block_M : (m_idx + 1) * block_M], dA_cs_m_shared)
T.copy(dA_cs_m_shared, dA_cs_m_local) T.copy(dA_cs_m_shared, dA_cs_m_local)
T.clear(acc_o) T.clear(acc_o)
for i in T.Parallel(block_M): for i in T.Parallel(block_M):
scale_m_local[i] = T.exp2(dA_cs_m_local[i] * p) scale_m_local[i] = T.exp2(dA_cs_m_local[i] * p)
-            T.copy(
-                C[batch_idx, chunk_idx * chunk_size + m_idx * block_M:chunk_idx * chunk_size +
-                  (m_idx + 1) * block_M, bz // (nheads // ngroups), 0:block_Dstate], C_shared)
-            T.copy(
-                prev_states[batch_idx, chunk_idx, bz, n_idx * block_N:(n_idx + 1) * block_N,
-                            0:block_Dstate], prev_state_shared)
+            T.copy(
+                C[
+                    batch_idx,
+                    chunk_idx * chunk_size + m_idx * block_M : chunk_idx * chunk_size + (m_idx + 1) * block_M,
+                    bz // (nheads // ngroups),
+                    0:block_Dstate,
+                ],
+                C_shared,
+            )
+            T.copy(prev_states[batch_idx, chunk_idx, bz, n_idx * block_N : (n_idx + 1) * block_N, 0:block_Dstate], prev_state_shared)
T.gemm(C_shared, prev_state_shared, acc_o, transpose_B=True) T.gemm(C_shared, prev_state_shared, acc_o, transpose_B=True)
for i, j in T.Parallel(block_M, block_N): for i, j in T.Parallel(block_M, block_N):
acc_o[i, j] *= scale_m_local[i] acc_o[i, j] *= scale_m_local[i]
...@@ -165,34 +169,47 @@ def chunk_scan_fwd(batch, ...@@ -165,34 +169,47 @@ def chunk_scan_fwd(batch,
for k in T.Pipelined(loop_range, num_stages=num_stages): for k in T.Pipelined(loop_range, num_stages=num_stages):
-                T.copy(
-                    cb[batch_idx, chunk_idx, bz // (nheads // ngroups),
-                       m_idx * block_M:(m_idx + 1) * block_M, k * block_K:(k + 1) * block_K],
-                    cb_shared)
+                T.copy(
+                    cb[
+                        batch_idx,
+                        chunk_idx,
+                        bz // (nheads // ngroups),
+                        m_idx * block_M : (m_idx + 1) * block_M,
+                        k * block_K : (k + 1) * block_K,
+                    ],
+                    cb_shared,
+                )
T.copy(cb_shared, cb_local) T.copy(cb_shared, cb_local)
-                T.copy(dA_cumsum[batch_idx, bz, chunk_idx, k * block_K:(k + 1) * block_K],
-                       dA_cs_k_shared)
+                T.copy(dA_cumsum[batch_idx, bz, chunk_idx, k * block_K : (k + 1) * block_K], dA_cs_k_shared)
T.copy(dA_cs_k_shared, dA_cs_k_local) T.copy(dA_cs_k_shared, dA_cs_k_local)
for i, j in T.Parallel(block_M, block_K): for i, j in T.Parallel(block_M, block_K):
-                    cb_local[i,
-                             j] = cb_local[i,
-                                           j] * T.exp2(dA_cs_m_local[i] * p - dA_cs_k_local[j] * p)
-                T.copy(dt[batch_idx, bz, chunk_idx, k * block_K:(k + 1) * block_K], dt_shared)
+                    cb_local[i, j] = cb_local[i, j] * T.exp2(dA_cs_m_local[i] * p - dA_cs_k_local[j] * p)
+                T.copy(dt[batch_idx, bz, chunk_idx, k * block_K : (k + 1) * block_K], dt_shared)
T.copy(dt_shared, dt_local) T.copy(dt_shared, dt_local)
for i, j in T.Parallel(block_M, block_K): for i, j in T.Parallel(block_M, block_K):
cb_local[i, j] *= dt_local[j] cb_local[i, j] *= dt_local[j]
for i, j in T.Parallel(block_M, block_K): for i, j in T.Parallel(block_M, block_K):
-                    cb_local[i, j] = T.if_then_else(m_idx * block_M + i >= k * block_K + j,
-                                                    cb_local[i, j], 0)
+                    cb_local[i, j] = T.if_then_else(m_idx * block_M + i >= k * block_K + j, cb_local[i, j], 0)
-                T.copy(
-                    x[batch_idx, chunk_idx * chunk_size + k * block_K:chunk_idx * chunk_size +
-                      (k + 1) * block_K, bz, n_idx * block_N:(n_idx + 1) * block_N], x_shared)
+                T.copy(
+                    x[
+                        batch_idx,
+                        chunk_idx * chunk_size + k * block_K : chunk_idx * chunk_size + (k + 1) * block_K,
+                        bz,
+                        n_idx * block_N : (n_idx + 1) * block_N,
+                    ],
+                    x_shared,
+                )
T.gemm(cb_local, x_shared, acc_o) T.gemm(cb_local, x_shared, acc_o)
D_local[0] = D[bz] D_local[0] = D[bz]
-            T.copy(
-                x[batch_idx, chunk_idx * chunk_size + m_idx * block_M:chunk_idx * chunk_size +
-                  (m_idx + 1) * block_M, bz, n_idx * block_N:(n_idx + 1) * block_N],
-                x_residual_shared)
+            T.copy(
+                x[
+                    batch_idx,
+                    chunk_idx * chunk_size + m_idx * block_M : chunk_idx * chunk_size + (m_idx + 1) * block_M,
+                    bz,
+                    n_idx * block_N : (n_idx + 1) * block_N,
+                ],
+                x_residual_shared,
+            )
T.copy(x_residual_shared, x_residual_local) T.copy(x_residual_shared, x_residual_local)
for i, j in T.Parallel(block_M, block_N): for i, j in T.Parallel(block_M, block_N):
acc_o[i, j] += x_residual_local[i, j] * D_local[0] acc_o[i, j] += x_residual_local[i, j] * D_local[0]
...@@ -200,27 +217,40 @@ def chunk_scan_fwd(batch, ...@@ -200,27 +217,40 @@ def chunk_scan_fwd(batch,
T.copy(acc_o, acc_o_shared) T.copy(acc_o, acc_o_shared)
-            T.copy(
-                acc_o_shared,
-                Output[batch_idx, chunk_idx * chunk_size + m_idx * block_M:chunk_idx * chunk_size +
-                       (m_idx + 1) * block_M, bz, n_idx * block_N:(n_idx + 1) * block_N])
+            T.copy(
+                acc_o_shared,
+                Output[
+                    batch_idx,
+                    chunk_idx * chunk_size + m_idx * block_M : chunk_idx * chunk_size + (m_idx + 1) * block_M,
+                    bz,
+                    n_idx * block_N : (n_idx + 1) * block_N,
+                ],
+            )
return main return main
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--batch', type=int, default=8, help='batch size') parser.add_argument("--batch", type=int, default=8, help="batch size")
parser.add_argument('--heads', type=int, default=80, help='heads') parser.add_argument("--heads", type=int, default=80, help="heads")
parser.add_argument('--groups', type=int, default=1, help='groups') parser.add_argument("--groups", type=int, default=1, help="groups")
parser.add_argument('--seq_len', type=int, default=4096, help='sequence length') parser.add_argument("--seq_len", type=int, default=4096, help="sequence length")
parser.add_argument('--chunk_size', type=int, default=256, help='chunk size') parser.add_argument("--chunk_size", type=int, default=256, help="chunk size")
parser.add_argument('--dim', type=int, default=64, help='dim') parser.add_argument("--dim", type=int, default=64, help="dim")
parser.add_argument('--dstate', type=int, default=128, help='dstate') parser.add_argument("--dstate", type=int, default=128, help="dstate")
parser.add_argument('--tune', action='store_true', help='tune configs') parser.add_argument("--tune", action="store_true", help="tune configs")
args = parser.parse_args() args = parser.parse_args()
-    batch, heads, groups, seq_len, chunk_size, dim, dstate = args.batch, args.heads, args.groups, args.seq_len, args.chunk_size, args.dim, args.dstate
+    batch, heads, groups, seq_len, chunk_size, dim, dstate = (
+        args.batch,
+        args.heads,
+        args.groups,
+        args.seq_len,
+        args.chunk_size,
+        args.dim,
+        args.dstate,
+    )
total_flops = 2 * batch * seq_len * chunk_size * heads * dim * 0.5 + 2 * batch * seq_len * heads * dim * dstate total_flops = 2 * batch * seq_len * chunk_size * heads * dim * 0.5 + 2 * batch * seq_len * heads * dim * dstate
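
For reference, the FLOP estimate above is the sum of the kernel's two GEMM families, spelled out here as an interpretation of that expression rather than anything stated elsewhere in the diff: the intra-chunk cb-by-x products only contribute the causal half of each chunk_size x chunk_size tile, hence the 0.5 factor, while the C-by-prev_states contraction over dstate is dense:

    total_flops = 2 * batch * seq_len * chunk_size * heads * dim * 0.5   # causal intra-chunk GEMMs
                + 2 * batch * seq_len * heads * dim * dstate             # dense state contraction
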
if (not args.tune): if not args.tune:
kernel = chunk_scan_fwd( kernel = chunk_scan_fwd(
batch, batch,
seq_len, seq_len,
...@@ -234,7 +264,8 @@ if __name__ == "__main__": ...@@ -234,7 +264,8 @@ if __name__ == "__main__":
block_K=64, block_K=64,
block_Dstate=128, block_Dstate=128,
num_stages=2, num_stages=2,
threads=128) threads=128,
)
profiler = kernel.get_profiler(tilelang.TensorSupplyType.Normal) profiler = kernel.get_profiler(tilelang.TensorSupplyType.Normal)
profiler.assert_allclose(ref_program, rtol=0.01, atol=0.01) profiler.assert_allclose(ref_program, rtol=0.01, atol=0.01)
print("All checks pass.") print("All checks pass.")
......
...@@ -10,6 +10,7 @@ import itertools ...@@ -10,6 +10,7 @@ import itertools
def chunk_state_triton(B, x, dt, dA_cumsum): def chunk_state_triton(B, x, dt, dA_cumsum):
from mamba_ssm.ops.triton.ssd_chunk_state import _chunk_state_fwd from mamba_ssm.ops.triton.ssd_chunk_state import _chunk_state_fwd
return _chunk_state_fwd(B, x, dt, dA_cumsum, states_in_fp32=False) return _chunk_state_fwd(B, x, dt, dA_cumsum, states_in_fp32=False)
...@@ -41,46 +42,33 @@ def ref_program(B, x, dt, dA_cumsum): ...@@ -41,46 +42,33 @@ def ref_program(B, x, dt, dA_cumsum):
x = rearrange(x, "b (c l) h p -> b c l h p", l=chunk_size) x = rearrange(x, "b (c l) h p -> b c l h p", l=chunk_size)
B = rearrange(B, "b (c l) ... -> b c l ...", l=chunk_size) B = rearrange(B, "b (c l) ... -> b c l ...", l=chunk_size)
decay_states = torch.exp((dA_cumsum[:, :, :, -1:] - dA_cumsum)) decay_states = torch.exp((dA_cumsum[:, :, :, -1:] - dA_cumsum))
return torch.einsum("bclhn,bhcl,bhcl,bclhp->bchpn", B.to(x.dtype), decay_states.to(x.dtype), return torch.einsum("bclhn,bhcl,bhcl,bclhp->bchpn", B.to(x.dtype), decay_states.to(x.dtype), dt.to(x.dtype), x)
dt.to(x.dtype), x)
def get_configs(): def get_configs():
iter_params = dict( iter_params = dict(block_M=[64, 128], block_N=[32, 64, 128], block_K=[32, 64], num_stages=[1, 2, 3, 4, 5])
block_M=[64, 128], block_N=[32, 64, 128], block_K=[32, 64], num_stages=[1, 2, 3, 4, 5])
return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())] return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())]
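A quick sketch of what get_configs yields (relying only on standard dict insertion order), handy for estimating the autotuning budget:

configs = get_configs()
print(len(configs))    # 60 = 2 * 3 * 2 * 5 candidate configurations
print(configs[0])      # {'block_M': 64, 'block_N': 32, 'block_K': 32, 'num_stages': 1}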
@autotune(configs=get_configs(), warmup=10, rep=10) @autotune(configs=get_configs(), warmup=10, rep=10)
@tilelang.jit(out_idx=[4]) @tilelang.jit(out_idx=[4])
def chunk_state_fwd(batch, def chunk_state_fwd(
seqlen, batch, seqlen, chunk_size, ngroups, nheads, headdim, dstate, block_M=64, block_N=64, block_K=64, num_stages=2, threads=128
chunk_size, ):
ngroups,
nheads,
headdim,
dstate,
block_M=64,
block_N=64,
block_K=64,
num_stages=2,
threads=128):
dtype = "float16" dtype = "float16"
accum_dtype = "float" accum_dtype = "float"
nchunks = T.ceildiv(seqlen, chunk_size) nchunks = T.ceildiv(seqlen, chunk_size)
p = 1.44269504 p = 1.44269504
@T.prim_func @T.prim_func
def main(B: T.Tensor((batch, seqlen, ngroups, dstate), dtype), x: T.Tensor( def main(
(batch, seqlen, nheads, headdim), dtype), dt: T.Tensor( B: T.Tensor((batch, seqlen, ngroups, dstate), dtype),
(batch, nheads, nchunks, chunk_size), dtype), dA_cumsum: T.Tensor( x: T.Tensor((batch, seqlen, nheads, headdim), dtype),
(batch, nheads, nchunks, chunk_size), dtype), Output: T.Tensor( dt: T.Tensor((batch, nheads, nchunks, chunk_size), dtype),
(batch, nchunks, nheads, headdim, dstate), dtype)): dA_cumsum: T.Tensor((batch, nheads, nchunks, chunk_size), dtype),
with T.Kernel( Output: T.Tensor((batch, nchunks, nheads, headdim, dstate), dtype),
nheads, ):
T.ceildiv(headdim, block_M) * T.ceildiv(dstate, block_N), with T.Kernel(nheads, T.ceildiv(headdim, block_M) * T.ceildiv(dstate, block_N), batch * nchunks, threads=threads) as (bz, bx, by):
batch * nchunks,
threads=threads) as (bz, bx, by):
x_shared = T.alloc_shared((block_K, block_M), dtype) x_shared = T.alloc_shared((block_K, block_M), dtype)
x_local = T.alloc_fragment((block_K, block_M), dtype) x_local = T.alloc_fragment((block_K, block_M), dtype)
xt_local = T.alloc_fragment((block_M, block_K), dtype) xt_local = T.alloc_fragment((block_M, block_K), dtype)
...@@ -101,20 +89,24 @@ def chunk_state_fwd(batch, ...@@ -101,20 +89,24 @@ def chunk_state_fwd(batch,
m_idx = bx // T.ceildiv(dstate, block_N) m_idx = bx // T.ceildiv(dstate, block_N)
n_idx = bx % T.ceildiv(dstate, block_N) n_idx = bx % T.ceildiv(dstate, block_N)
T.annotate_layout({ T.annotate_layout(
x_shared: tilelang.layout.make_swizzled_layout(x_shared), {x_shared: tilelang.layout.make_swizzled_layout(x_shared), acc_o_shared: tilelang.layout.make_swizzled_layout(acc_o_shared)}
acc_o_shared: tilelang.layout.make_swizzled_layout(acc_o_shared) )
})
dA_cs_last[0] = dA_cumsum[batch_idx, bz, chunk_idx, chunk_size - 1] dA_cs_last[0] = dA_cumsum[batch_idx, bz, chunk_idx, chunk_size - 1]
T.clear(acc_o) T.clear(acc_o)
for k in T.Pipelined(loop_range, num_stages=num_stages): for k in T.Pipelined(loop_range, num_stages=num_stages):
T.copy( T.copy(
x[batch_idx, chunk_idx * chunk_size + k * block_K:chunk_idx * chunk_size + x[
(k + 1) * block_K, bz, m_idx * block_M:(m_idx + 1) * block_M], x_shared) batch_idx,
T.copy(dA_cumsum[batch_idx, bz, chunk_idx, k * block_K:(k + 1) * block_K], chunk_idx * chunk_size + k * block_K : chunk_idx * chunk_size + (k + 1) * block_K,
dA_cumsum_shared) bz,
T.copy(dt[batch_idx, bz, chunk_idx, k * block_K:(k + 1) * block_K], dt_shared) m_idx * block_M : (m_idx + 1) * block_M,
],
x_shared,
)
T.copy(dA_cumsum[batch_idx, bz, chunk_idx, k * block_K : (k + 1) * block_K], dA_cumsum_shared)
T.copy(dt[batch_idx, bz, chunk_idx, k * block_K : (k + 1) * block_K], dt_shared)
T.copy(dA_cumsum_shared, dA_cumsum_local) T.copy(dA_cumsum_shared, dA_cumsum_local)
T.copy(dt_shared, dt_local) T.copy(dt_shared, dt_local)
for i in T.Parallel(block_K): for i in T.Parallel(block_K):
...@@ -123,47 +115,50 @@ def chunk_state_fwd(batch, ...@@ -123,47 +115,50 @@ def chunk_state_fwd(batch,
for i, j in T.Parallel(block_M, block_K): for i, j in T.Parallel(block_M, block_K):
xt_local[i, j] = x_local[j, i] * scale[j] xt_local[i, j] = x_local[j, i] * scale[j]
T.copy( T.copy(
B[batch_idx, chunk_idx * chunk_size + k * block_K:chunk_idx * chunk_size + B[
(k + 1) * block_K, bz // (nheads // ngroups), batch_idx,
n_idx * block_N:(n_idx + 1) * block_N], B_shared) chunk_idx * chunk_size + k * block_K : chunk_idx * chunk_size + (k + 1) * block_K,
bz // (nheads // ngroups),
n_idx * block_N : (n_idx + 1) * block_N,
],
B_shared,
)
T.gemm(xt_local, B_shared, acc_o) T.gemm(xt_local, B_shared, acc_o)
T.copy(acc_o, acc_o_shared) T.copy(acc_o, acc_o_shared)
T.copy( T.copy(
acc_o_shared, acc_o_shared,
Output[batch_idx, chunk_idx, bz, m_idx * block_M:(m_idx + 1) * block_M, Output[batch_idx, chunk_idx, bz, m_idx * block_M : (m_idx + 1) * block_M, n_idx * block_N : (n_idx + 1) * block_N],
n_idx * block_N:(n_idx + 1) * block_N]) )
return main return main
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--batch', type=int, default=8, help='batch size') parser.add_argument("--batch", type=int, default=8, help="batch size")
parser.add_argument('--heads', type=int, default=80, help='heads') parser.add_argument("--heads", type=int, default=80, help="heads")
parser.add_argument('--groups', type=int, default=1, help='groups') parser.add_argument("--groups", type=int, default=1, help="groups")
parser.add_argument('--seq_len', type=int, default=4096, help='sequence length') parser.add_argument("--seq_len", type=int, default=4096, help="sequence length")
parser.add_argument('--chunk_size', type=int, default=256, help='chunk size') parser.add_argument("--chunk_size", type=int, default=256, help="chunk size")
parser.add_argument('--dim', type=int, default=64, help='dim') parser.add_argument("--dim", type=int, default=64, help="dim")
parser.add_argument('--dstate', type=int, default=128, help='dstate') parser.add_argument("--dstate", type=int, default=128, help="dstate")
parser.add_argument('--tune', action='store_true', help='tune configs') parser.add_argument("--tune", action="store_true", help="tune configs")
args = parser.parse_args() args = parser.parse_args()
batch, heads, groups, seq_len, chunk_size, dim, dstate = args.batch, args.heads, args.groups, args.seq_len, args.chunk_size, args.dim, args.dstate batch, heads, groups, seq_len, chunk_size, dim, dstate = (
args.batch,
args.heads,
args.groups,
args.seq_len,
args.chunk_size,
args.dim,
args.dstate,
)
total_flops = 2 * batch * seq_len * heads * dim * dstate total_flops = 2 * batch * seq_len * heads * dim * dstate
if (not args.tune): if not args.tune:
kernel = chunk_state_fwd( kernel = chunk_state_fwd(
batch, batch, seq_len, chunk_size, groups, heads, dim, dstate, block_M=64, block_N=128, block_K=64, num_stages=4, threads=128
seq_len, )
chunk_size,
groups,
heads,
dim,
dstate,
block_M=64,
block_N=128,
block_K=64,
num_stages=4,
threads=128)
profiler = kernel.get_profiler(tilelang.TensorSupplyType.Normal) profiler = kernel.get_profiler(tilelang.TensorSupplyType.Normal)
profiler.assert_allclose(ref_program, rtol=0.01, atol=0.01) profiler.assert_allclose(ref_program, rtol=0.01, atol=0.01)
print("All checks pass.") print("All checks pass.")
......
...@@ -13,13 +13,12 @@ def chunk_retention_fwd_kernel( ...@@ -13,13 +13,12 @@ def chunk_retention_fwd_kernel(
H, H,
DK, DK,
DV, DV,
dtype: str = 'float16', dtype: str = "float16",
scale: float = None, scale: float = None,
) -> torch.Tensor: ) -> torch.Tensor:
if scale is None: if scale is None:
scale = DK**-0.5 scale = DK**-0.5
accum_dtype = 'float' accum_dtype = "float"
chunk_size = 64 chunk_size = 64
BK = BV = 64  # Setting this to 128 can be faster, but has some numerical differences with FLA BK = BV = 64  # Setting this to 128 can be faster, but has some numerical differences with FLA
...@@ -30,16 +29,16 @@ def chunk_retention_fwd_kernel( ...@@ -30,16 +29,16 @@ def chunk_retention_fwd_kernel(
@T.prim_func @T.prim_func
def chunk_retention_fwd( def chunk_retention_fwd(
Q: T.Tensor([B, S, H, DK], dtype), # type: ignore Q: T.Tensor([B, S, H, DK], dtype), # type: ignore
K: T.Tensor([B, S, H, DK], dtype), # type: ignore K: T.Tensor([B, S, H, DK], dtype), # type: ignore
V: T.Tensor([B, S, H, DV], dtype), # type: ignore V: T.Tensor([B, S, H, DV], dtype), # type: ignore
O: T.Tensor([NK, B, S, H, DV], dtype), # type: ignore O: T.Tensor([NK, B, S, H, DV], dtype), # type: ignore
): ):
with T.Kernel(NV, NK, B * H) as (i_v, i_k, i_bh): with T.Kernel(NV, NK, B * H) as (i_v, i_k, i_bh):
i_b = i_bh // H i_b = i_bh // H
i_h = i_bh % H i_h = i_bh % H
log_decay = T.alloc_var('float32') log_decay = T.alloc_var("float32")
log_decay = T.log2(1 - T.exp2(-5. - 1. * i_h)) # Head-specific log decay log_decay = T.log2(1 - T.exp2(-5.0 - 1.0 * i_h)) # Head-specific log decay
q = T.alloc_shared([chunk_size, BK], dtype) q = T.alloc_shared([chunk_size, BK], dtype)
k = T.alloc_shared([chunk_size, BK], dtype) k = T.alloc_shared([chunk_size, BK], dtype)
...@@ -56,14 +55,12 @@ def chunk_retention_fwd_kernel( ...@@ -56,14 +55,12 @@ def chunk_retention_fwd_kernel(
for i in T.Pipelined(0, NT): for i in T.Pipelined(0, NT):
for row, col in T.Parallel(chunk_size, BK): for row, col in T.Parallel(chunk_size, BK):
q[row, col] = Q[i_b, i * chunk_size + row, i_h, i_k * BK + col] * scale q[row, col] = Q[i_b, i * chunk_size + row, i_h, i_k * BK + col] * scale
T.copy(K[i_b, i * chunk_size:(i + 1) * chunk_size, i_h, i_k * BK:(i_k + 1) * BK], k) T.copy(K[i_b, i * chunk_size : (i + 1) * chunk_size, i_h, i_k * BK : (i_k + 1) * BK], k)
T.copy(V[i_b, i * chunk_size:(i + 1) * chunk_size, i_h, i_v * BV:(i_v + 1) * BV], v) T.copy(V[i_b, i * chunk_size : (i + 1) * chunk_size, i_h, i_v * BV : (i_v + 1) * BV], v)
T.gemm(q, k, s, clear_accum=True, transpose_B=True) T.gemm(q, k, s, clear_accum=True, transpose_B=True)
for row, col in T.Parallel(chunk_size, chunk_size): for row, col in T.Parallel(chunk_size, chunk_size):
s_shared[row, s_shared[row, col] = T.if_then_else(row >= col, s[row, col] * T.exp2((row - col) * log_decay), 0)
col] = T.if_then_else(row >= col, s[row, col] * T.exp2(
(row - col) * log_decay), 0)
T.copy(h, h_shared) T.copy(h, h_shared)
T.gemm(q, h_shared, o, clear_accum=True) T.gemm(q, h_shared, o, clear_accum=True)
...@@ -75,9 +72,7 @@ def chunk_retention_fwd_kernel( ...@@ -75,9 +72,7 @@ def chunk_retention_fwd_kernel(
v[row, col] = v[row, col] * T.exp2((chunk_size - row - 1) * log_decay) v[row, col] = v[row, col] * T.exp2((chunk_size - row - 1) * log_decay)
for row, col in T.Parallel(BK, BV): for row, col in T.Parallel(BK, BV):
h[row, col] = T.exp2(chunk_size * log_decay) * h[row, col] h[row, col] = T.exp2(chunk_size * log_decay) * h[row, col]
T.copy( T.copy(o, O[i_k, i_b, i * chunk_size : (i + 1) * chunk_size, i_h, i_v * BV : (i_v + 1) * BV])
o, O[i_k, i_b, i * chunk_size:(i + 1) * chunk_size, i_h,
i_v * BV:(i_v + 1) * BV])
T.gemm(k, v, h, transpose_A=True) T.gemm(k, v, h, transpose_A=True)
return chunk_retention_fwd return chunk_retention_fwd
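The head-specific decay above is gamma_h = 1 - 2^(-5 - i_h), stored as log_decay = log2(gamma_h), so T.exp2((row - col) * log_decay) equals gamma_h^(row - col) (the causal intra-chunk retention weight) and T.exp2(chunk_size * log_decay) equals gamma_h^chunk_size (the cross-chunk state decay). A small numpy sketch of the intra-chunk mask, with i_h and chunk_size as hypothetical values:

import numpy as np
i_h, chunk_size = 3, 64                        # hypothetical head index and chunk size
gamma = 1.0 - 2.0 ** (-5 - i_h)                # head-specific decay rate
row = np.arange(chunk_size)[:, None]
col = np.arange(chunk_size)[None, :]
decay_mask = np.where(row >= col, gamma ** (row - col), 0.0)   # factor applied to s[row, col]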
...@@ -89,24 +84,24 @@ def postprocess(o): ...@@ -89,24 +84,24 @@ def postprocess(o):
def main(): def main():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--B', type=int, default=8, help='Batch size') parser.add_argument("--B", type=int, default=8, help="Batch size")
parser.add_argument('--S', type=int, default=4096, help='Seq len') parser.add_argument("--S", type=int, default=4096, help="Seq len")
parser.add_argument('--H', type=int, default=32, help='Num heads') parser.add_argument("--H", type=int, default=32, help="Num heads")
parser.add_argument('--D', type=int, default=128, help='Head dim') parser.add_argument("--D", type=int, default=128, help="Head dim")
args = parser.parse_args() args = parser.parse_args()
B, S, H, D = args.B, args.S, args.H, args.D B, S, H, D = args.B, args.S, args.H, args.D
total_flops = 2.0 * B * S * S * H * D # causal total_flops = 2.0 * B * S * S * H * D # causal
q = torch.randn((B, S, H, D), device='cuda', dtype=torch.float16) q = torch.randn((B, S, H, D), device="cuda", dtype=torch.float16)
k = torch.randn((B, S, H, D), device='cuda', dtype=torch.float16) k = torch.randn((B, S, H, D), device="cuda", dtype=torch.float16)
v = torch.randn((B, S, H, D), device='cuda', dtype=torch.float16) v = torch.randn((B, S, H, D), device="cuda", dtype=torch.float16)
kernel = chunk_retention_fwd_kernel(B, S, H, D, D) kernel = chunk_retention_fwd_kernel(B, S, H, D, D)
t = do_bench(lambda: postprocess(kernel(q, k, v)), warmup=25, rep=100) t = do_bench(lambda: postprocess(kernel(q, k, v)), warmup=25, rep=100)
print(f'Tilelang latency: {t:.3f} ms') print(f"Tilelang latency: {t:.3f} ms")
print(f'Tilelang TFLOPs: {total_flops/t * 1e-9}') print(f"Tilelang TFLOPs: {total_flops / t * 1e-9}")
if __name__ == '__main__': if __name__ == "__main__":
main() main()
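Unit check for the prints above: do_bench reports t in milliseconds (as the latency print assumes), so total_flops / (t * 1e-3 s) / 1e12 = total_flops / t * 1e-9, i.e. the second print is indeed TFLOP/s.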
...@@ -15,12 +15,11 @@ from tilelang.profiler import do_bench ...@@ -15,12 +15,11 @@ from tilelang.profiler import do_bench
@tilelang.jit(out_idx=[3]) @tilelang.jit(out_idx=[3])
def _tl_vs_sparse_flashattn(batch, heads, seq_len, dim, vertical_size, slash_size): def _tl_vs_sparse_flashattn(batch, heads, seq_len, dim, vertical_size, slash_size):
block_M = 64 block_M = 64
block_N = 64 block_N = 64
num_stages = 2 num_stages = 2
threads = 128 threads = 128
scale = (1.0 / dim)**0.5 * 1.44269504 scale = (1.0 / dim) ** 0.5 * 1.44269504
shape = [batch, heads, seq_len, dim] shape = [batch, heads, seq_len, dim]
seq_blocks = (seq_len + block_M - 1) // block_M seq_blocks = (seq_len + block_M - 1) // block_M
...@@ -30,15 +29,13 @@ def _tl_vs_sparse_flashattn(batch, heads, seq_len, dim, vertical_size, slash_siz ...@@ -30,15 +29,13 @@ def _tl_vs_sparse_flashattn(batch, heads, seq_len, dim, vertical_size, slash_siz
offset_shape = count_shape + [slash_size] offset_shape = count_shape + [slash_size]
index_shape = count_shape + [vertical_size] index_shape = count_shape + [vertical_size]
vertical_size_round, slash_size_round = tilelang.next_power_of_2( vertical_size_round, slash_size_round = tilelang.next_power_of_2(vertical_size), tilelang.next_power_of_2(slash_size)
vertical_size), tilelang.next_power_of_2(slash_size)
dtype = "float16" dtype = "float16"
accum_dtype = "float" accum_dtype = "float"
int_dtype = "int32" int_dtype = "int32"
def kernel_func(block_M, block_N, num_stages, threads): def kernel_func(block_M, block_N, num_stages, threads):
@T.macro @T.macro
def Prefetch( def Prefetch(
K: T.Tensor(shape, dtype), K: T.Tensor(shape, dtype),
...@@ -53,32 +50,30 @@ def _tl_vs_sparse_flashattn(batch, heads, seq_len, dim, vertical_size, slash_siz ...@@ -53,32 +50,30 @@ def _tl_vs_sparse_flashattn(batch, heads, seq_len, dim, vertical_size, slash_siz
): ):
with T.attr("default", "async_scope", 1): with T.attr("default", "async_scope", 1):
for i, j in T.Parallel(block_N, dim): for i, j in T.Parallel(block_N, dim):
K_shared[i, j] = T.if_then_else(k + i < column_count, K_shared[i, j] = T.if_then_else(k + i < column_count, K[bz, by, column_index[k + i], j], 0)
K[bz, by, column_index[k + i], j], 0)
with T.attr("default", "async_scope", 1): with T.attr("default", "async_scope", 1):
for i, j in T.Parallel(block_N, dim): for i, j in T.Parallel(block_N, dim):
V_shared[i, j] = T.if_then_else(k + i < column_count, V_shared[i, j] = T.if_then_else(k + i < column_count, V[bz, by, column_index[k + i], j], 0)
V[bz, by, column_index[k + i], j], 0)
T.ptx_commit_group() T.ptx_commit_group()
@T.macro @T.macro
def Compute( def Compute(
acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype),
acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype),
acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),
scores_max: T.FragmentBuffer([block_M], accum_dtype), scores_max: T.FragmentBuffer([block_M], accum_dtype),
scores_max_prev: T.FragmentBuffer([block_M], accum_dtype), scores_max_prev: T.FragmentBuffer([block_M], accum_dtype),
k: T.int32, k: T.int32,
column_count: T.int32, column_count: T.int32,
Q_shared: T.SharedBuffer([block_M, dim], dtype), Q_shared: T.SharedBuffer([block_M, dim], dtype),
K_shared: T.SharedBuffer([block_N, dim], dtype), K_shared: T.SharedBuffer([block_N, dim], dtype),
V_shared: T.SharedBuffer([block_N, dim], dtype), V_shared: T.SharedBuffer([block_N, dim], dtype),
scores_scale: T.FragmentBuffer([block_M], accum_dtype), scores_scale: T.FragmentBuffer([block_M], accum_dtype),
scores_sum: T.FragmentBuffer([block_M], accum_dtype), scores_sum: T.FragmentBuffer([block_M], accum_dtype),
logsum: T.FragmentBuffer([block_M], accum_dtype), logsum: T.FragmentBuffer([block_M], accum_dtype),
count: T.int32, count: T.int32,
): ):
T.ptx_wait_group(count) T.ptx_wait_group(count)
for i, j in T.Parallel(block_M, block_N): for i, j in T.Parallel(block_M, block_N):
...@@ -108,17 +103,16 @@ def _tl_vs_sparse_flashattn(batch, heads, seq_len, dim, vertical_size, slash_siz ...@@ -108,17 +103,16 @@ def _tl_vs_sparse_flashattn(batch, heads, seq_len, dim, vertical_size, slash_siz
@T.prim_func @T.prim_func
def vs_sparse_flashattn_ws( def vs_sparse_flashattn_ws(
Q: T.Tensor(shape, dtype), Q: T.Tensor(shape, dtype),
K: T.Tensor(shape, dtype), K: T.Tensor(shape, dtype),
V: T.Tensor(shape, dtype), V: T.Tensor(shape, dtype),
Output: T.Tensor(shape, dtype), Output: T.Tensor(shape, dtype),
BlockCount: T.Tensor(count_shape, int_dtype), BlockCount: T.Tensor(count_shape, int_dtype),
BlockOffset: T.Tensor(offset_shape, int_dtype), BlockOffset: T.Tensor(offset_shape, int_dtype),
ColumnCount: T.Tensor(count_shape, int_dtype), ColumnCount: T.Tensor(count_shape, int_dtype),
ColumnIndex: T.Tensor(index_shape, int_dtype), ColumnIndex: T.Tensor(index_shape, int_dtype),
): ):
with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=256) as (bc, by, bz): with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=256) as (bc, by, bz):
bx = T.ceildiv(seq_len, block_M) - 1 - bc bx = T.ceildiv(seq_len, block_M) - 1 - bc
Q_shared = T.alloc_shared([block_M, dim], dtype) Q_shared = T.alloc_shared([block_M, dim], dtype)
K_shared = T.alloc_shared([2, block_N, dim], dtype) K_shared = T.alloc_shared([2, block_N, dim], dtype)
...@@ -143,9 +137,11 @@ def _tl_vs_sparse_flashattn(batch, heads, seq_len, dim, vertical_size, slash_siz ...@@ -143,9 +137,11 @@ def _tl_vs_sparse_flashattn(batch, heads, seq_len, dim, vertical_size, slash_siz
T.create_list_of_mbarrier([128] * 9) T.create_list_of_mbarrier([128] * 9)
T.annotate_layout({ T.annotate_layout(
O_shared: tilelang.layout.make_swizzled_layout(O_shared), {
}) O_shared: tilelang.layout.make_swizzled_layout(O_shared),
}
)
block_count[0] = BlockCount[bz, by, bx] block_count[0] = BlockCount[bz, by, bx]
column_count[0] = ColumnCount[bz, by, bx] column_count[0] = ColumnCount[bz, by, bx]
...@@ -162,15 +158,15 @@ def _tl_vs_sparse_flashattn(batch, heads, seq_len, dim, vertical_size, slash_siz ...@@ -162,15 +158,15 @@ def _tl_vs_sparse_flashattn(batch, heads, seq_len, dim, vertical_size, slash_siz
if tid >= 128: if tid >= 128:
T.annotate_producer_reg_dealloc() T.annotate_producer_reg_dealloc()
T.copy(Q[bz, by, bx * block_M:(bx + 1) * block_M, :], Q_shared) T.copy(Q[bz, by, bx * block_M : (bx + 1) * block_M, :], Q_shared)
T.mbarrier_arrive(mbarrier=8) T.mbarrier_arrive(mbarrier=8)
for bi in T.serial(block_count[0]): for bi in T.serial(block_count[0]):
k = block_offset[bi] k = block_offset[bi]
T.mbarrier_wait_parity(mbarrier=bi % 2 + 4, parity=(((bi & 3) >> 1) ^ 1)) T.mbarrier_wait_parity(mbarrier=bi % 2 + 4, parity=(((bi & 3) >> 1) ^ 1))
T.copy(K[bz, by, k:k + block_N, :], K_shared[bi % 2, :, :]) T.copy(K[bz, by, k : k + block_N, :], K_shared[bi % 2, :, :])
T.mbarrier_arrive(mbarrier=bi % 2) T.mbarrier_arrive(mbarrier=bi % 2)
T.mbarrier_wait_parity(mbarrier=bi % 2 + 6, parity=(((bi & 3) >> 1) ^ 1)) T.mbarrier_wait_parity(mbarrier=bi % 2 + 6, parity=(((bi & 3) >> 1) ^ 1))
T.copy(V[bz, by, k:k + block_N, :], V_shared[bi % 2, :, :]) T.copy(V[bz, by, k : k + block_N, :], V_shared[bi % 2, :, :])
T.mbarrier_arrive(mbarrier=bi % 2 + 2) T.mbarrier_arrive(mbarrier=bi % 2 + 2)
else: else:
T.annotate_consumer_reg_alloc() T.annotate_consumer_reg_alloc()
...@@ -181,16 +177,10 @@ def _tl_vs_sparse_flashattn(batch, heads, seq_len, dim, vertical_size, slash_siz ...@@ -181,16 +177,10 @@ def _tl_vs_sparse_flashattn(batch, heads, seq_len, dim, vertical_size, slash_siz
for bi in T.serial(block_count[0]): for bi in T.serial(block_count[0]):
k = block_offset[bi] k = block_offset[bi]
for i, j in T.Parallel(block_M, block_N): for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else(bx * block_M + i >= k + j, 0, acc_s[i, j] = T.if_then_else(bx * block_M + i >= k + j, 0, -T.infinity(acc_s.dtype))
-T.infinity(acc_s.dtype))
T.mbarrier_wait_parity(mbarrier=bi % 2, parity=((bi & 3) >> 1)) T.mbarrier_wait_parity(mbarrier=bi % 2, parity=((bi & 3) >> 1))
T.gemm( T.gemm(Q_shared, K_shared[bi % 2, :, :], acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
Q_shared,
K_shared[bi % 2, :, :],
acc_s,
transpose_B=True,
policy=T.GemmWarpPolicy.FullRow)
T.mbarrier_arrive(mbarrier=bi % 2 + 4) T.mbarrier_arrive(mbarrier=bi % 2 + 4)
T.copy(scores_max, scores_max_prev) T.copy(scores_max, scores_max_prev)
...@@ -200,20 +190,15 @@ def _tl_vs_sparse_flashattn(batch, heads, seq_len, dim, vertical_size, slash_siz ...@@ -200,20 +190,15 @@ def _tl_vs_sparse_flashattn(batch, heads, seq_len, dim, vertical_size, slash_siz
scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
for i in T.Parallel(block_M): for i in T.Parallel(block_M):
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
scores_max[i] * scale)
for i, j in T.Parallel(block_M, block_N): for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
for i, j in T.Parallel(block_M, dim): for i, j in T.Parallel(block_M, dim):
acc_o[i, j] = acc_o[i, j] * scores_scale[i] acc_o[i, j] = acc_o[i, j] * scores_scale[i]
T.copy(acc_s, acc_s_cast) T.copy(acc_s, acc_s_cast)
T.mbarrier_wait_parity(mbarrier=bi % 2 + 2, parity=(((bi & 3) >> 1))) T.mbarrier_wait_parity(mbarrier=bi % 2 + 2, parity=((bi & 3) >> 1))
T.gemm( T.gemm(acc_s_cast, V_shared[bi % 2, :, :], acc_o, policy=T.GemmWarpPolicy.FullRow)
acc_s_cast,
V_shared[bi % 2, :, :],
acc_o,
policy=T.GemmWarpPolicy.FullRow)
T.mbarrier_arrive(mbarrier=bi % 2 + 6) T.mbarrier_arrive(mbarrier=bi % 2 + 6)
...@@ -223,38 +208,85 @@ def _tl_vs_sparse_flashattn(batch, heads, seq_len, dim, vertical_size, slash_siz ...@@ -223,38 +208,85 @@ def _tl_vs_sparse_flashattn(batch, heads, seq_len, dim, vertical_size, slash_siz
logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
if column_count[0] != 0: if column_count[0] != 0:
Prefetch(K, V, K_shared_1, V_shared_1, column_index, column_count[0], 0, bz, Prefetch(K, V, K_shared_1, V_shared_1, column_index, column_count[0], 0, bz, by)
by)
for bi in T.serial(T.ceildiv(column_count[0], block_N) - 1): for bi in T.serial(T.ceildiv(column_count[0], block_N) - 1):
k = bi * block_N k = bi * block_N
if bi % 2 == 0: if bi % 2 == 0:
Prefetch(K, V, K_shared_2, V_shared_2, column_index, Prefetch(K, V, K_shared_2, V_shared_2, column_index, column_count[0], k + block_N, bz, by)
column_count[0], k + block_N, bz, by)
Compute(
Compute(acc_s, acc_s_cast, acc_o, scores_max, scores_max_prev, k, acc_s,
column_count[0], Q_shared, K_shared_1, V_shared_1, acc_s_cast,
scores_scale, scores_sum, logsum, 1) acc_o,
scores_max,
scores_max_prev,
k,
column_count[0],
Q_shared,
K_shared_1,
V_shared_1,
scores_scale,
scores_sum,
logsum,
1,
)
else: else:
Prefetch(K, V, K_shared_1, V_shared_1, column_index, Prefetch(K, V, K_shared_1, V_shared_1, column_index, column_count[0], k + block_N, bz, by)
column_count[0], k + block_N, bz, by)
Compute(
Compute(acc_s, acc_s_cast, acc_o, scores_max, scores_max_prev, k, acc_s,
column_count[0], Q_shared, K_shared_2, V_shared_2, acc_s_cast,
scores_scale, scores_sum, logsum, 1) acc_o,
scores_max,
scores_max_prev,
k,
column_count[0],
Q_shared,
K_shared_2,
V_shared_2,
scores_scale,
scores_sum,
logsum,
1,
)
if T.ceildiv(column_count[0], block_N) % 2 == 0: if T.ceildiv(column_count[0], block_N) % 2 == 0:
Compute(acc_s, acc_s_cast, acc_o, scores_max, scores_max_prev, Compute(
T.ceildiv(column_count[0], block_N) * block_N - block_N, acc_s,
column_count[0], Q_shared, K_shared_2, V_shared_2, scores_scale, acc_s_cast,
scores_sum, logsum, 0) acc_o,
scores_max,
scores_max_prev,
T.ceildiv(column_count[0], block_N) * block_N - block_N,
column_count[0],
Q_shared,
K_shared_2,
V_shared_2,
scores_scale,
scores_sum,
logsum,
0,
)
else: else:
Compute(acc_s, acc_s_cast, acc_o, scores_max, scores_max_prev, Compute(
T.ceildiv(column_count[0], block_N) * block_N - block_N, acc_s,
column_count[0], Q_shared, K_shared_1, V_shared_1, scores_scale, acc_s_cast,
scores_sum, logsum, 0) acc_o,
scores_max,
scores_max_prev,
T.ceildiv(column_count[0], block_N) * block_N - block_N,
column_count[0],
Q_shared,
K_shared_1,
V_shared_1,
scores_scale,
scores_sum,
logsum,
0,
)
for i, j in T.Parallel(block_M, dim): for i, j in T.Parallel(block_M, dim):
acc_o[i, j] /= logsum[i] acc_o[i, j] /= logsum[i]
T.copy(acc_o, O_shared) T.copy(acc_o, O_shared)
T.copy(O_shared, Output[bz, by, bx * block_M:(bx + 1) * block_M, :]) T.copy(O_shared, Output[bz, by, bx * block_M : (bx + 1) * block_M, :])
return vs_sparse_flashattn_ws return vs_sparse_flashattn_ws
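The Compute macro above is the standard streaming (online) softmax: when a new tile raises the running max, scores_scale rescales both the output accumulator and the running normalizer logsum. A minimal numpy sketch of the same recurrence (not the kernel's code; the shapes and the log2(e) pre-scaling are assumptions):

import numpy as np
np.random.seed(0)
scale = np.log2(np.e)                           # assumed: logits pre-scaled so exp2 can be used
scores = np.random.randn(4, 256)                # 4 query rows vs. 256 keys, visited in tiles of 64
v = np.random.randn(256, 8)
m = np.full(4, -np.inf)                         # running max    (scores_max)
l = np.zeros(4)                                 # running sum    (logsum)
acc = np.zeros((4, 8))                          # running output (acc_o)
for k in range(0, 256, 64):
    s = scores[:, k:k + 64]
    m_new = np.maximum(m, s.max(axis=1))
    alpha = np.exp2(m * scale - m_new * scale)             # scores_scale
    p = np.exp2(s * scale - m_new[:, None] * scale)        # exponentiated tile (acc_s)
    acc = acc * alpha[:, None] + p @ v[k:k + 64]
    l = l * alpha + p.sum(axis=1)
    m = m_new
out = acc / l[:, None]
w = np.exp2(scores * scale - (scores * scale).max(axis=1, keepdims=True))
ref = (w / w.sum(axis=1, keepdims=True)) @ v               # full softmax-attention for comparison
assert np.allclose(out, ref)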
...@@ -470,11 +502,8 @@ def vertical_slash_sparse_attention( ...@@ -470,11 +502,8 @@ def vertical_slash_sparse_attention(
import os import os
current_dir = os.path.dirname(os.path.abspath(__file__)) current_dir = os.path.dirname(os.path.abspath(__file__))
sources = [ sources = [os.path.join(current_dir, "ops", "kernels.cpp"), os.path.join(current_dir, "ops", "vertical_slash_index.cu")]
os.path.join(current_dir, 'ops', 'kernels.cpp'), ops = load(name="convert", sources=sources, verbose=False)
os.path.join(current_dir, 'ops', 'vertical_slash_index.cu')
]
ops = load(name='convert', sources=sources, verbose=False)
convert_vertical_slash_indexes = ops.convert_vertical_slash_indexes convert_vertical_slash_indexes = ops.convert_vertical_slash_indexes
batch_size, num_heads, context_size, head_dim = query.shape batch_size, num_heads, context_size, head_dim = query.shape
pad = (block_size_M - context_size) & (block_size_M - 1) pad = (block_size_M - context_size) & (block_size_M - 1)
...@@ -485,15 +514,13 @@ def vertical_slash_sparse_attention( ...@@ -485,15 +514,13 @@ def vertical_slash_sparse_attention(
value = torch.nn.functional.pad(value, [0, 0, 0, pad, 0, 0, 0, 0]) value = torch.nn.functional.pad(value, [0, 0, 0, pad, 0, 0, 0, 0])
if head_dim not in [16, 32, 64, 128, 256, 512]: if head_dim not in [16, 32, 64, 128, 256, 512]:
target_dim = 2**math.ceil(math.log2(head_dim)) - head_dim target_dim = 2 ** math.ceil(math.log2(head_dim)) - head_dim
query = torch.nn.functional.pad(query, [0, target_dim, 0, 0, 0, 0, 0, 0]) query = torch.nn.functional.pad(query, [0, target_dim, 0, 0, 0, 0, 0, 0])
key = torch.nn.functional.pad(key, [0, target_dim, 0, 0, 0, 0, 0, 0]) key = torch.nn.functional.pad(key, [0, target_dim, 0, 0, 0, 0, 0, 0])
value = torch.nn.functional.pad(value, [0, target_dim, 0, 0, 0, 0, 0, 0]) value = torch.nn.functional.pad(value, [0, target_dim, 0, 0, 0, 0, 0, 0])
v_idx = v_idx.to(torch.int32).reshape((batch_size, num_heads, -1)).sort( v_idx = v_idx.to(torch.int32).reshape((batch_size, num_heads, -1)).sort(dim=-1, descending=False)[0]
dim=-1, descending=False)[0] s_idx = s_idx.to(torch.int32).reshape((batch_size, num_heads, -1)).sort(dim=-1, descending=True)[0]
s_idx = s_idx.to(torch.int32).reshape((batch_size, num_heads, -1)).sort(
dim=-1, descending=True)[0]
seqlens = torch.tensor([context_size] * query.shape[0], dtype=torch.int32, device=query.device) seqlens = torch.tensor([context_size] * query.shape[0], dtype=torch.int32, device=query.device)
sm_scale = head_dim**-0.5 sm_scale = head_dim**-0.5
...@@ -506,8 +533,7 @@ def vertical_slash_sparse_attention( ...@@ -506,8 +533,7 @@ def vertical_slash_sparse_attention(
block_size_N, block_size_N,
) )
tl_kernel = _tl_vs_sparse_flashattn(batch_size, num_heads, context_size, head_dim, tl_kernel = _tl_vs_sparse_flashattn(batch_size, num_heads, context_size, head_dim, v_idx.shape[2], s_idx.shape[2])
v_idx.shape[2], s_idx.shape[2])
def run(is_triton: bool = True): def run(is_triton: bool = True):
if is_triton: if is_triton:
...@@ -525,8 +551,7 @@ def vertical_slash_sparse_attention( ...@@ -525,8 +551,7 @@ def vertical_slash_sparse_attention(
block_size_N, block_size_N,
) )
else: else:
out = tl_kernel(query, key, value, block_count, block_offset, column_count, out = tl_kernel(query, key, value, block_count, block_offset, column_count, column_index)
column_index)
return out[..., :context_size, :head_dim] return out[..., :context_size, :head_dim]
return run return run
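A note on the padding arithmetic near the top of this function: for a power-of-two block_size_M, (block_size_M - context_size) & (block_size_M - 1) equals (-context_size) mod block_size_M, i.e. the amount needed to round context_size up to a multiple of block_size_M. A quick check with hypothetical sizes:

block_size_M, context_size = 64, 1000
pad = (block_size_M - context_size) & (block_size_M - 1)
assert pad == 24 and (context_size + pad) % block_size_M == 0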
...@@ -536,8 +561,7 @@ def sum_all_diagonal_matrix(mat: torch.tensor): ...@@ -536,8 +561,7 @@ def sum_all_diagonal_matrix(mat: torch.tensor):
b, h, n, m = mat.shape b, h, n, m = mat.shape
zero_mat = torch.zeros((b, h, n, n)).to(mat.device) # Zero matrix used for padding zero_mat = torch.zeros((b, h, n, n)).to(mat.device) # Zero matrix used for padding
mat_padded = torch.cat((zero_mat, mat, zero_mat), -1) # pads the matrix on left and right mat_padded = torch.cat((zero_mat, mat, zero_mat), -1) # pads the matrix on left and right
mat_strided = mat_padded.as_strided( mat_strided = mat_padded.as_strided((1, 1, n, n + m), (1, n * (2 * n + m), 2 * n + m + 1, 1)) # Change the strides
(1, 1, n, n + m), (1, n * (2 * n + m), 2 * n + m + 1, 1)) # Change the strides
sum_diags = torch.sum(mat_strided, 2) # Sums the resulting matrix's columns sum_diags = torch.sum(mat_strided, 2) # Sums the resulting matrix's columns
return sum_diags[:, :, 1:] return sum_diags[:, :, 1:]
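For a single (n, m) score matrix, the padding + as_strided trick above amounts to summing each of the n + m - 1 diagonals (the leading, empty diagonal is dropped by the final [..., 1:]). A slower but more readable reference, as a sketch:

import torch

def sum_all_diagonal_ref(mat2d: torch.Tensor) -> torch.Tensor:
    # mat2d: (n, m); returns the n + m - 1 diagonal sums, offsets -(n - 1) .. (m - 1)
    n, m = mat2d.shape
    return torch.stack([mat2d.diagonal(offset=d).sum() for d in range(-(n - 1), m)])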
...@@ -559,24 +583,23 @@ def main(argv=None): ...@@ -559,24 +583,23 @@ def main(argv=None):
vertical_size, slash_size = args.vertical_size, args.slash_size vertical_size, slash_size = args.vertical_size, args.slash_size
torch.manual_seed(0) torch.manual_seed(0)
q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16) q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16)
k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16) k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16)
v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16) v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16)
q_len = SEQ_LEN q_len = SEQ_LEN
vertical_size, slash_size = min(q_len, vertical_size), min(q_len, slash_size) vertical_size, slash_size = min(q_len, vertical_size), min(q_len, slash_size)
last_q = 64 last_q = 64
qk = torch.einsum('bhmk, bhnk -> bhmn', q[:, :, -last_q:, :], k) qk = torch.einsum("bhmk, bhnk -> bhmn", q[:, :, -last_q:, :], k)
arange = torch.arange(last_q, device="cuda") arange = torch.arange(last_q, device="cuda")
qk[:, :, :, -last_q:] = torch.where(arange[None, None, :, None] >= arange[None, None, None, :], qk[:, :, :, -last_q:] = torch.where(arange[None, None, :, None] >= arange[None, None, None, :], qk[:, :, :, -last_q:], -torch.inf)
qk[:, :, :, -last_q:], -torch.inf)
qk = torch.nn.functional.softmax(qk, dim=-1, dtype=torch.float32) qk = torch.nn.functional.softmax(qk, dim=-1, dtype=torch.float32)
vertical = qk.sum(-2, keepdim=True) vertical = qk.sum(-2, keepdim=True)
vertical[..., :30] = torch.inf vertical[..., :30] = torch.inf
vertical_topk = torch.topk(vertical, vertical_size, -1).indices vertical_topk = torch.topk(vertical, vertical_size, -1).indices
slash = sum_all_diagonal_matrix(qk)[..., :-last_q + 1] slash = sum_all_diagonal_matrix(qk)[..., : -last_q + 1]
slash[..., -30:] = torch.inf slash[..., -30:] = torch.inf
slash = (q_len - 1) - torch.topk(slash, slash_size, -1).indices slash = (q_len - 1) - torch.topk(slash, slash_size, -1).indices
......
...@@ -45,7 +45,7 @@ def rms_norm(M, N, blk_m): ...@@ -45,7 +45,7 @@ def rms_norm(M, N, blk_m):
A_local = T.alloc_fragment((blk_m, N), dtype) A_local = T.alloc_fragment((blk_m, N), dtype)
A_powsum = T.alloc_fragment((blk_m,), dtype) A_powsum = T.alloc_fragment((blk_m,), dtype)
T.copy(A[bx * blk_m:(bx + 1) * blk_m, :], A_shared) T.copy(A[bx * blk_m : (bx + 1) * blk_m, :], A_shared)
T.copy(A_shared, A_local) T.copy(A_shared, A_local)
for i, j in T.Parallel(blk_m, N): for i, j in T.Parallel(blk_m, N):
A_pow_local[i, j] = A_local[i, j] * A_local[i, j] A_pow_local[i, j] = A_local[i, j] * A_local[i, j]
...@@ -54,7 +54,7 @@ def rms_norm(M, N, blk_m): ...@@ -54,7 +54,7 @@ def rms_norm(M, N, blk_m):
A_powsum[i] = T.rsqrt(A_powsum[i] / N + 1e-12) A_powsum[i] = T.rsqrt(A_powsum[i] / N + 1e-12)
for i, j in T.Parallel(blk_m, N): for i, j in T.Parallel(blk_m, N):
A_local[i, j] *= A_powsum[i] A_local[i, j] *= A_powsum[i]
T.copy(A_local, B[bx * blk_m:(bx + 1) * blk_m, :]) T.copy(A_local, B[bx * blk_m : (bx + 1) * blk_m, :])
return main return main
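A hedged PyTorch reference for the row-wise RMS normalization computed above (square, mean over the row, rsqrt with the same 1e-12 epsilon, then scale):

import torch

def rms_norm_ref(A: torch.Tensor) -> torch.Tensor:
    # per row: B = A * rsqrt(mean(A**2) + 1e-12), matching A_powsum in the kernel
    return A * torch.rsqrt(A.pow(2).mean(dim=-1, keepdim=True) + 1e-12)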
......
...@@ -45,7 +45,7 @@ def rms_norm(M, N, blk_m): ...@@ -45,7 +45,7 @@ def rms_norm(M, N, blk_m):
A_local = T.alloc_fragment((blk_m, N), dtype) A_local = T.alloc_fragment((blk_m, N), dtype)
A_powsum = T.alloc_fragment((blk_m,), dtype) A_powsum = T.alloc_fragment((blk_m,), dtype)
T.copy(A[bx * blk_m:(bx + 1) * blk_m, :], A_shared) T.copy(A[bx * blk_m : (bx + 1) * blk_m, :], A_shared)
T.copy(A_shared, A_local) T.copy(A_shared, A_local)
for i, j in T.Parallel(blk_m, N): for i, j in T.Parallel(blk_m, N):
A_pow_local[i, j] = A_local[i, j] * A_local[i, j] A_pow_local[i, j] = A_local[i, j] * A_local[i, j]
...@@ -54,7 +54,7 @@ def rms_norm(M, N, blk_m): ...@@ -54,7 +54,7 @@ def rms_norm(M, N, blk_m):
A_powsum[i] = T.rsqrt(A_powsum[i] / N + 1e-12) A_powsum[i] = T.rsqrt(A_powsum[i] / N + 1e-12)
for i, j in T.Parallel(blk_m, N): for i, j in T.Parallel(blk_m, N):
A_local[i, j] *= A_powsum[i] A_local[i, j] *= A_powsum[i]
T.copy(A_local, B[bx * blk_m:(bx + 1) * blk_m, :]) T.copy(A_local, B[bx * blk_m : (bx + 1) * blk_m, :])
return main return main
......
...@@ -20,8 +20,8 @@ def softmax_kernel( ...@@ -20,8 +20,8 @@ def softmax_kernel(
@T.prim_func @T.prim_func
def main( def main(
X: T.Tensor([M, N], dtype), X: T.Tensor([M, N], dtype),
Y: T.Tensor([M, N], dtype), Y: T.Tensor([M, N], dtype),
): ):
with T.Kernel(M, threads=128) as (i_m): with T.Kernel(M, threads=128) as (i_m):
x = T.alloc_fragment([BN], dtype) x = T.alloc_fragment([BN], dtype)
...@@ -33,7 +33,7 @@ def softmax_kernel( ...@@ -33,7 +33,7 @@ def softmax_kernel(
T.fill(lse, -T.infinity(accum_dtype)) T.fill(lse, -T.infinity(accum_dtype))
for i_n in T.Pipelined(0, NN): for i_n in T.Pipelined(0, NN):
T.copy(X[i_m, i_n * BN:(i_n + 1) * BN], x) T.copy(X[i_m, i_n * BN : (i_n + 1) * BN], x)
T.reduce_max(x, max_x, dim=0, clear=True) T.reduce_max(x, max_x, dim=0, clear=True)
...@@ -45,12 +45,12 @@ def softmax_kernel( ...@@ -45,12 +45,12 @@ def softmax_kernel(
lse[0] = max_x[0] * scale + T.log2(T.exp2(lse[0] - max_x[0] * scale) + sum_exp_x[0]) lse[0] = max_x[0] * scale + T.log2(T.exp2(lse[0] - max_x[0] * scale) + sum_exp_x[0])
for i_n in T.Pipelined(0, NN): for i_n in T.Pipelined(0, NN):
T.copy(X[i_m, i_n * BN:(i_n + 1) * BN], x) T.copy(X[i_m, i_n * BN : (i_n + 1) * BN], x)
for j in T.Parallel(BN): for j in T.Parallel(BN):
y[j] = T.exp2(x[j] * scale - lse[0]) y[j] = T.exp2(x[j] * scale - lse[0])
T.copy(y, Y[i_m, i_n * BN:(i_n + 1) * BN]) T.copy(y, Y[i_m, i_n * BN : (i_n + 1) * BN])
return main return main
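The two passes above first accumulate a running base-2 log-sum-exp (lse) over BN-wide tiles of the row, then re-read X to emit exp2(x * scale - lse). A small numpy sketch of the streaming identity this relies on, assuming scale = log2(e):

import numpy as np
scale = np.log2(np.e)
x = np.random.randn(1024)
lse = -np.inf
for tile in x.reshape(8, 128):                     # 128-element tiles, like BN
    m = tile.max()
    s = np.exp2(tile * scale - m * scale).sum()    # sum_exp_x for this tile
    lse = m * scale + np.log2(np.exp2(lse - m * scale) + s)
assert np.isclose(lse, np.log2(np.exp(x).sum()))   # lse == log2(sum(exp(x)))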
...@@ -69,4 +69,4 @@ t1 = do_bench(lambda: X.softmax(dim=1), warmup=25, rep=100) ...@@ -69,4 +69,4 @@ t1 = do_bench(lambda: X.softmax(dim=1), warmup=25, rep=100)
t2 = do_bench(lambda: kernel(X), warmup=25, rep=100) t2 = do_bench(lambda: kernel(X), warmup=25, rep=100)
print(f"torch latency: {t1:.3f} ms") print(f"torch latency: {t1:.3f} ms")
print(f"TileLang latency: {t2:.3f} ms") print(f"TileLang latency: {t2:.3f} ms")
print(f"Speedup: {t1/t2:.3f}x") print(f"Speedup: {t1 / t2:.3f}x")
...@@ -11,10 +11,9 @@ from tilelang.intrinsics.mfma_layout import ( ...@@ -11,10 +11,9 @@ from tilelang.intrinsics.mfma_layout import (
) )
def make_mfma_load_base_layout(dtype: str = "float16", def make_mfma_load_base_layout(
matrix: Literal["A", "B"] = "A", dtype: str = "float16", matrix: Literal["A", "B"] = "A", k_dim: int = 16, transposed: bool = False
k_dim: int = 16, ) -> T.Fragment:
transposed: bool = False) -> T.Fragment:
""" """
Create a layout function for storing MFMA results into a fragment buffer. Create a layout function for storing MFMA results into a fragment buffer.
This layout is used in conjunction with `inverse_mfma_store_layout` to This layout is used in conjunction with `inverse_mfma_store_layout` to
...@@ -72,12 +71,10 @@ def make_mfma_load_base_layout(dtype: str = "float16", ...@@ -72,12 +71,10 @@ def make_mfma_load_base_layout(dtype: str = "float16",
# so the b matrix expects a transposed basic layout # so the b matrix expects a transposed basic layout
transform_func: Callable = None transform_func: Callable = None
if matrix == "A": if matrix == "A":
transform_func = transform_func_sr_a if is_sr_axis_order else lambda i, j: transform_func_sr_a( transform_func = transform_func_sr_a if is_sr_axis_order else lambda i, j: transform_func_sr_a(j, i)
j, i)
micro_size_s, micro_size_r = micro_size_x, micro_size_k micro_size_s, micro_size_r = micro_size_x, micro_size_k
elif matrix == "B": elif matrix == "B":
transform_func = transform_func_sr_b if is_sr_axis_order else lambda i, j: transform_func_sr_b( transform_func = transform_func_sr_b if is_sr_axis_order else lambda i, j: transform_func_sr_b(j, i)
j, i)
micro_size_s, micro_size_r = micro_size_k, micro_size_y micro_size_s, micro_size_r = micro_size_k, micro_size_y
else: else:
raise ValueError(f"Unsupported matrix {matrix}") raise ValueError(f"Unsupported matrix {matrix}")
...@@ -120,14 +117,11 @@ print(base_layout) ...@@ -120,14 +117,11 @@ print(base_layout)
plot_layout(base_layout, name="base_layout") plot_layout(base_layout, name="base_layout")
# warp layout 32x32 # warp layout 32x32
warp_layout = base_layout.repeat([warp_rows, warp_cols], warp_layout = base_layout.repeat([warp_rows, warp_cols], repeat_on_thread=False, lower_dim_first=False)
repeat_on_thread=False,
lower_dim_first=False)
print(warp_layout) print(warp_layout)
plot_layout(warp_layout, name="warp_layout") plot_layout(warp_layout, name="warp_layout")
# block layout 64x32 # block layout 64x32
block_layout = warp_layout.repeat([block_rows, 1], repeat_on_thread=True, block_layout = warp_layout.repeat([block_rows, 1], repeat_on_thread=True, lower_dim_first=True).replicate(block_cols)
lower_dim_first=True).replicate(block_cols)
print(block_layout) print(block_layout)
plot_layout(block_layout, name="block_layout") plot_layout(block_layout, name="block_layout")
...@@ -5,9 +5,7 @@ from tvm.tir import IndexMap ...@@ -5,9 +5,7 @@ from tvm.tir import IndexMap
from tilelang.intrinsics.utils import get_mma_micro_size from tilelang.intrinsics.utils import get_mma_micro_size
def make_mma_load_base_layout(dtype: str = "float16", def make_mma_load_base_layout(dtype: str = "float16", matrix: Literal["A", "B"] = "A", transposed: bool = False) -> T.Fragment:
matrix: Literal["A", "B"] = "A",
transposed: bool = False) -> T.Fragment:
""" """
Create a layout function for storing MMA results into a fragment buffer. Create a layout function for storing MMA results into a fragment buffer.
This layout is used in conjunction with `inverse_mma_store_layout` to This layout is used in conjunction with `inverse_mma_store_layout` to
...@@ -36,6 +34,7 @@ def make_mma_load_base_layout(dtype: str = "float16", ...@@ -36,6 +34,7 @@ def make_mma_load_base_layout(dtype: str = "float16",
shared_16x16_to_mma_32x8_layout_sr_b, shared_16x16_to_mma_32x8_layout_sr_b,
shared_16x32_to_mma_32x16_layout_sr_b, shared_16x32_to_mma_32x16_layout_sr_b,
) )
assert matrix in ["A", "B"], "matrix should be either A or B" assert matrix in ["A", "B"], "matrix should be either A or B"
dtype_bits = DataType(dtype).bits dtype_bits = DataType(dtype).bits
# s represents spatial axis # s represents spatial axis
...@@ -67,12 +66,10 @@ def make_mma_load_base_layout(dtype: str = "float16", ...@@ -67,12 +66,10 @@ def make_mma_load_base_layout(dtype: str = "float16",
# so the b matrix expects a transposed basic layout # so the b matrix expects a transposed basic layout
transform_func: Callable = None transform_func: Callable = None
if matrix == "A": if matrix == "A":
transform_func = transform_func_sr_a if is_sr_axis_order else lambda i, j: transform_func_sr_a( transform_func = transform_func_sr_a if is_sr_axis_order else lambda i, j: transform_func_sr_a(j, i)
j, i)
micro_size_s, micro_size_r = micro_size_x, micro_size_k micro_size_s, micro_size_r = micro_size_x, micro_size_k
elif matrix == "B": elif matrix == "B":
transform_func = transform_func_sr_b if is_sr_axis_order else lambda i, j: transform_func_sr_b( transform_func = transform_func_sr_b if is_sr_axis_order else lambda i, j: transform_func_sr_b(j, i)
j, i)
micro_size_s, micro_size_r = micro_size_k, micro_size_y micro_size_s, micro_size_r = micro_size_k, micro_size_y
else: else:
raise ValueError(f"Unsupported matrix {matrix}") raise ValueError(f"Unsupported matrix {matrix}")
......
...@@ -7,12 +7,11 @@ import tilelang.language as T ...@@ -7,12 +7,11 @@ import tilelang.language as T
# if not specified, it will be inferred from the input tensors during compile time # if not specified, it will be inferred from the input tensors during compile time
@tilelang.jit @tilelang.jit
def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"): def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
@T.prim_func @T.prim_func
def matmul_relu_kernel( def matmul_relu_kernel(
A: T.Tensor((M, K), dtype), A: T.Tensor((M, K), dtype),
B: T.Tensor((K, N), dtype), B: T.Tensor((K, N), dtype),
C: T.Tensor((M, N), dtype), C: T.Tensor((M, N), dtype),
): ):
# Initialize Kernel Context # Initialize Kernel Context
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
......
...@@ -10,10 +10,7 @@ def get_sparse_attn_mask_from_topk(x, topk, use_dense_for_last_block=False): ...@@ -10,10 +10,7 @@ def get_sparse_attn_mask_from_topk(x, topk, use_dense_for_last_block=False):
bsz, num_head, downsample_len, _ = x.shape bsz, num_head, downsample_len, _ = x.shape
# N_CTX = downsample_len * BLOCK # N_CTX = downsample_len * BLOCK
sparse_index = torch.topk(x, topk, dim=-1).indices sparse_index = torch.topk(x, topk, dim=-1).indices
dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len], dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len], False, dtype=torch.bool, device=x.device)
False,
dtype=torch.bool,
device=x.device)
dense_mask.scatter_(-1, sparse_index, True) dense_mask.scatter_(-1, sparse_index, True)
if use_dense_for_last_block: if use_dense_for_last_block:
dense_mask[:, :, -2:, :] = True dense_mask[:, :, -2:, :] = True
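A tiny usage sketch for the helper above (assuming the function returns dense_mask and that torch is imported as in this file): with topk=2, every row of the boolean mask ends up with exactly two True entries.

x = torch.randn(1, 2, 8, 8)                        # hypothetical (batch, heads, q_blocks, k_blocks) block scores
mask = get_sparse_attn_mask_from_topk(x, topk=2)
assert mask.dtype == torch.bool and bool(mask.sum(-1).eq(2).all())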
...@@ -30,15 +27,17 @@ def get_sparse_attn_mask_from_threshold(x, threshold, use_dense_for_last_block=F ...@@ -30,15 +27,17 @@ def get_sparse_attn_mask_from_threshold(x, threshold, use_dense_for_last_block=F
@tilelang.jit( @tilelang.jit(
out_idx=[4], pass_configs={ out_idx=[4],
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
}) },
)
def blocksparse_flashattn(batch, heads, seq_q, seq_kv, dim, downsample_len, is_causal): def blocksparse_flashattn(batch, heads, seq_q, seq_kv, dim, downsample_len, is_causal):
block_M = 64 block_M = 64
block_N = 64 block_N = 64
num_stages = 0 num_stages = 0
threads = 128 threads = 128
scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e)
q_shape = [batch, heads, seq_q, dim] q_shape = [batch, heads, seq_q, dim]
kv_shape = [batch, heads, seq_kv, dim] kv_shape = [batch, heads, seq_kv, dim]
block_mask_shape = [batch, heads, downsample_len, downsample_len] block_mask_shape = [batch, heads, downsample_len, downsample_len]
...@@ -48,16 +47,15 @@ def blocksparse_flashattn(batch, heads, seq_q, seq_kv, dim, downsample_len, is_c ...@@ -48,16 +47,15 @@ def blocksparse_flashattn(batch, heads, seq_q, seq_kv, dim, downsample_len, is_c
block_mask_dtype = "int8" block_mask_dtype = "int8"
def kernel_func(block_M, block_N, num_stages, threads): def kernel_func(block_M, block_N, num_stages, threads):
@T.macro @T.macro
def Softmax( def Softmax(
acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype),
acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype),
scores_max: T.FragmentBuffer([block_M], accum_dtype), scores_max: T.FragmentBuffer([block_M], accum_dtype),
scores_max_prev: T.FragmentBuffer([block_M], accum_dtype), scores_max_prev: T.FragmentBuffer([block_M], accum_dtype),
scores_scale: T.FragmentBuffer([block_M], accum_dtype), scores_scale: T.FragmentBuffer([block_M], accum_dtype),
scores_sum: T.FragmentBuffer([block_M], accum_dtype), scores_sum: T.FragmentBuffer([block_M], accum_dtype),
logsum: T.FragmentBuffer([block_M], accum_dtype), logsum: T.FragmentBuffer([block_M], accum_dtype),
): ):
T.copy(scores_max, scores_max_prev) T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype)) T.fill(scores_max, -T.infinity(accum_dtype))
...@@ -83,19 +81,19 @@ def blocksparse_flashattn(batch, heads, seq_q, seq_kv, dim, downsample_len, is_c ...@@ -83,19 +81,19 @@ def blocksparse_flashattn(batch, heads, seq_q, seq_kv, dim, downsample_len, is_c
@T.macro @T.macro
def Rescale( def Rescale(
acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),
scores_scale: T.FragmentBuffer([block_M], accum_dtype), scores_scale: T.FragmentBuffer([block_M], accum_dtype),
): ):
for i, j in T.Parallel(block_M, dim): for i, j in T.Parallel(block_M, dim):
acc_o[i, j] *= scores_scale[i] acc_o[i, j] *= scores_scale[i]
@T.prim_func @T.prim_func
def main( def main(
Q: T.Tensor(q_shape, dtype), Q: T.Tensor(q_shape, dtype),
K: T.Tensor(kv_shape, dtype), K: T.Tensor(kv_shape, dtype),
V: T.Tensor(kv_shape, dtype), V: T.Tensor(kv_shape, dtype),
BlockSparseMask: T.Tensor(block_mask_shape, block_mask_dtype), BlockSparseMask: T.Tensor(block_mask_shape, block_mask_dtype),
Output: T.Tensor(q_shape, dtype), Output: T.Tensor(q_shape, dtype),
): ):
with T.Kernel(T.ceildiv(seq_q, block_M), heads, batch, threads=threads) as (bx, by, bz): with T.Kernel(T.ceildiv(seq_q, block_M), heads, batch, threads=threads) as (bx, by, bz):
Q_shared = T.alloc_shared([block_M, dim], dtype) Q_shared = T.alloc_shared([block_M, dim], dtype)
...@@ -112,7 +110,7 @@ def blocksparse_flashattn(batch, heads, seq_q, seq_kv, dim, downsample_len, is_c ...@@ -112,7 +110,7 @@ def blocksparse_flashattn(batch, heads, seq_q, seq_kv, dim, downsample_len, is_c
logsum = T.alloc_fragment([block_M], accum_dtype) logsum = T.alloc_fragment([block_M], accum_dtype)
block_mask = T.alloc_local([downsample_len], block_mask_dtype) block_mask = T.alloc_local([downsample_len], block_mask_dtype)
T.copy(Q[bz, by, bx * block_M:(bx + 1) * block_M, :], Q_shared) T.copy(Q[bz, by, bx * block_M : (bx + 1) * block_M, :], Q_shared)
T.fill(acc_o, 0) T.fill(acc_o, 0)
T.fill(logsum, 0) T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype)) T.fill(scores_max, -T.infinity(accum_dtype))
...@@ -124,33 +122,25 @@ def blocksparse_flashattn(batch, heads, seq_q, seq_kv, dim, downsample_len, is_c ...@@ -124,33 +122,25 @@ def blocksparse_flashattn(batch, heads, seq_q, seq_kv, dim, downsample_len, is_c
for k in T.Pipelined(loop_range, num_stages=num_stages): for k in T.Pipelined(loop_range, num_stages=num_stages):
if block_mask[k] != 0: if block_mask[k] != 0:
T.copy(K[bz, by, k * block_N:(k + 1) * block_N, :], K_shared) T.copy(K[bz, by, k * block_N : (k + 1) * block_N, :], K_shared)
if is_causal: if is_causal:
past_len = seq_kv - seq_q past_len = seq_kv - seq_q
for i, j in T.Parallel(block_M, block_N): for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else( acc_s[i, j] = T.if_then_else(bx * block_M + i + past_len >= k * block_N + j, 0, -T.infinity(acc_s.dtype))
bx * block_M + i + past_len >= k * block_N + j, 0,
-T.infinity(acc_s.dtype))
else: else:
T.clear(acc_s) T.clear(acc_s)
T.gemm( T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
Q_shared,
K_shared, Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, logsum)
acc_s,
transpose_B=True,
policy=T.GemmWarpPolicy.FullRow)
Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale,
scores_sum, logsum)
Rescale(acc_o, scores_scale) Rescale(acc_o, scores_scale)
T.copy(V[bz, by, k * block_N:(k + 1) * block_N, :], V_shared) T.copy(V[bz, by, k * block_N : (k + 1) * block_N, :], V_shared)
T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
for i, j in T.Parallel(block_M, dim): for i, j in T.Parallel(block_M, dim):
acc_o[i, j] /= logsum[i] acc_o[i, j] /= logsum[i]
T.copy(acc_o, O_shared) T.copy(acc_o, O_shared)
T.copy(O_shared, Output[bz, by, bx * block_M:(bx + 1) * block_M, :]) T.copy(O_shared, Output[bz, by, bx * block_M : (bx + 1) * block_M, :])
return main return main
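On the scale factor defined above: scale = (1.0 / dim) ** 0.5 * 1.44269504 folds log2(e) into the usual 1/sqrt(dim) softmax scaling, so the kernel can use exp2 throughout, since 2^(x * scale) == exp(x / sqrt(dim)). A quick check with a hypothetical head dimension:

import math
dim = 64
scale = (1.0 / dim) ** 0.5 * 1.44269504            # 1/sqrt(dim) * log2(e)
x = 0.73
assert abs(2 ** (x * scale) - math.exp(x / math.sqrt(dim))) < 1e-6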
...@@ -165,44 +155,40 @@ def test_topk_sparse_attention(): ...@@ -165,44 +155,40 @@ def test_topk_sparse_attention():
torch.manual_seed(0) torch.manual_seed(0)
# Create inputs # Create inputs
q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16) q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16)
k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16) k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16)
v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16) v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16)
sm_scale = 1.0 / (D_HEAD**0.5) sm_scale = 1.0 / (D_HEAD**0.5)
# Create sparse mask (downsampled to block level) # Create sparse mask (downsampled to block level)
downsample_factor = BLOCK downsample_factor = BLOCK
downsample_len = math.ceil(SEQ_LEN / downsample_factor) downsample_len = math.ceil(SEQ_LEN / downsample_factor)
x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len], x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len], device="cuda", dtype=torch.float16)
device='cuda',
dtype=torch.float16)
x_ds[:, :, :, 0] = 100 x_ds[:, :, :, 0] = 100
block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK) block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK)
# Run tilelang kernel # Run tilelang kernel
kernel = blocksparse_flashattn( kernel = blocksparse_flashattn(BATCH, N_HEADS, SEQ_LEN, SEQ_LEN, D_HEAD, downsample_len, is_causal=True)
BATCH, N_HEADS, SEQ_LEN, SEQ_LEN, D_HEAD, downsample_len, is_causal=True)
tilelang_output = kernel(q, k, v, block_mask.to(torch.int8)) tilelang_output = kernel(q, k, v, block_mask.to(torch.int8))
# Compute reference # Compute reference
# Expand block mask to full attention matrix # Expand block mask to full attention matrix
full_mask = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK, device='cuda')) full_mask = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK, device="cuda"))
full_mask = full_mask[..., :SEQ_LEN, :SEQ_LEN].bool() full_mask = full_mask[..., :SEQ_LEN, :SEQ_LEN].bool()
full_mask = full_mask & torch.tril(torch.ones_like(full_mask)) # Apply causal full_mask = full_mask & torch.tril(torch.ones_like(full_mask)) # Apply causal
# PyTorch reference implementation # PyTorch reference implementation
attn = torch.einsum('bhsd,bhtd->bhst', q, k) * sm_scale attn = torch.einsum("bhsd,bhtd->bhst", q, k) * sm_scale
attn = attn.masked_fill(~full_mask, float('-inf')) attn = attn.masked_fill(~full_mask, float("-inf"))
attn = F.softmax(attn, dim=-1) attn = F.softmax(attn, dim=-1)
ref_output = torch.einsum('bhst,bhtd->bhsd', attn, v) ref_output = torch.einsum("bhst,bhtd->bhsd", attn, v)
print("ref_output", ref_output) print("ref_output", ref_output)
print("tilelang_output", tilelang_output) print("tilelang_output", tilelang_output)
# Verify accuracy # Verify accuracy
assert torch.allclose(tilelang_output, ref_output, atol=1e-2, rtol=1e-2), \ assert torch.allclose(tilelang_output, ref_output, atol=1e-2, rtol=1e-2), "TileLang output doesn't match reference"
"TileLang output doesn't match reference"
print("Pass topk sparse attention test with qlen == klen") print("Pass topk sparse attention test with qlen == klen")
...@@ -215,42 +201,40 @@ def test_topk_sparse_attention_qlen_lt_klen(): ...@@ -215,42 +201,40 @@ def test_topk_sparse_attention_qlen_lt_klen():
    torch.manual_seed(0)
    # Create inputs.
    q = torch.randn(BATCH, N_HEADS, Q_LEN, D_HEAD, device="cuda", dtype=torch.float16)
    k = torch.randn(BATCH, N_HEADS, K_LEN, D_HEAD, device="cuda", dtype=torch.float16)
    v = torch.randn(BATCH, N_HEADS, K_LEN, D_HEAD, device="cuda", dtype=torch.float16)
    sm_scale = 1.0 / (D_HEAD**0.5)
    downsample_factor = BLOCK
    downsample_len = math.ceil(K_LEN / downsample_factor)  # number of blocks along one dimension
    x_ds = torch.randn(BATCH, N_HEADS, downsample_len, downsample_len, device="cuda", dtype=torch.float16)
    # Force the first column to be high so that the first block is always selected.
    x_ds[:, :, :, 0] = 100
    block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK)
    kernel = blocksparse_flashattn(BATCH, N_HEADS, Q_LEN, K_LEN, D_HEAD, downsample_len, is_causal=True)
    print(kernel.get_kernel_source())
    tilelang_output = kernel(q, k, v, block_mask.to(torch.int8))
    past_len = K_LEN - Q_LEN
    attn = torch.einsum("bhsd,bhtd->bhst", q, k) * sm_scale
    full_mask_full = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK, device="cuda")).bool()
    full_mask_full = full_mask_full[..., :K_LEN, :K_LEN]
    effective_mask = full_mask_full[..., past_len:K_LEN, :]  # shape: (B, H, Q_LEN, K_LEN)
    i_global = torch.arange(past_len, K_LEN, device=k.device).unsqueeze(1)  # shape: (Q_LEN, 1)
    j_global = torch.arange(K_LEN, device=k.device).unsqueeze(0)  # shape: (1, K_LEN)
    causal_mask = j_global <= i_global  # shape: (Q_LEN, K_LEN)
    final_mask = effective_mask & causal_mask  # shape: (B, H, Q_LEN, K_LEN)
    attn = attn.masked_fill(~final_mask, float("-inf"))
    attn = F.softmax(attn, dim=-1)
    ref_output = torch.einsum("bhst,bhtd->bhsd", attn, v)
    print("ref_output", ref_output)
    print("tilelang_output", tilelang_output)
...
@@ -15,10 +15,7 @@ def get_sparse_attn_mask_from_topk(x, topk, use_dense_for_last_block=False):
    bsz, num_head, downsample_len, _ = x.shape
    # N_CTX = downsample_len * BLOCK
    sparse_index = torch.topk(x, topk, dim=-1).indices
    dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len], False, dtype=torch.bool, device=x.device)
    dense_mask.scatter_(-1, sparse_index, True)
    if use_dense_for_last_block:
        dense_mask[:, :, -2:, :] = True
@@ -54,7 +51,6 @@ def _fwd_kernel_inner(
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    mask_val = tl.load(block_mask_ptr + k_block_col_idx * stride_bmask_n)
    if mask_val == True:
@@ -69,7 +65,7 @@ def _fwd_kernel_inner(
        qk *= sm_scale
        # the following is needed only when LAST_K_BLOCK or BLOCK_M < BLOCK_N
        qk += tl.where(offs_m[:, None] + past_len >= (start_n + offs_n[None, :]), 0, float("-inf"))
        m_ij = tl.maximum(m_i, tl.max(qk, 1))
        qk -= m_ij[:, None]
@@ -149,7 +145,7 @@ def _fwd_kernel(
    v_ptrs = V + off_v
    mask_ptrs = block_mask_ptr + start_m * stride_bmm
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
    acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
@@ -185,24 +181,12 @@ def _fwd_kernel(
    acc = acc * l_recip
    acc = acc.to(Out.dtype.element_ty)
    off_o = off_z * stride_oz + off_h * stride_oh + offs_m[:, None] * stride_om + offs_d[None, :] * stride_od
    out_ptrs = Out + off_o
    tl.store(out_ptrs, acc, mask=offs_m[:, None] < N_CTX)


def _forward(ctx, q, k, v, block_sparse_mask, sm_scale, BLOCK_M=64, BLOCK_N=64, num_warps=None, num_stages=1, out=None):
    assert q.shape[-1] == k.shape[-1] == v.shape[-1]
    assert k.shape[2] == v.shape[2]
    o = out if out is not None else torch.empty_like(q).contiguous()
@@ -247,7 +231,6 @@ def _forward(ctx,
class _sparse_attention(torch.autograd.Function):
    @staticmethod
    def forward(ctx, q, k, v, block_sparse_dense, sm_scale):
        # shape constraints
@@ -271,9 +254,9 @@ def test_topk_sparse_attention():
    torch.manual_seed(0)
    # Create inputs
    q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.bfloat16)
    k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.bfloat16)
    v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.bfloat16)
    sm_scale = 1.0 / (D_HEAD**0.5)
    # Create sparse mask (downsampled to block level)
@@ -281,9 +264,7 @@ def test_topk_sparse_attention():
    downsample_len = math.ceil(SEQ_LEN / downsample_factor)
    print("downsample_len", downsample_len)
    x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len], device="cuda", dtype=torch.bfloat16)
    x_ds[:, :, :, 0] = 100
    print("x_ds.shape", x_ds.shape)
    block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK)
@@ -295,22 +276,21 @@ def test_topk_sparse_attention():
    # Compute reference
    # Expand block mask to full attention matrix
    full_mask = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK, device="cuda"))
    full_mask = full_mask[..., :SEQ_LEN, :SEQ_LEN].bool()
    full_mask = full_mask & torch.tril(torch.ones_like(full_mask))  # Apply causal
    # PyTorch reference implementation
    attn = torch.einsum("bhsd,bhtd->bhst", q, k) * sm_scale
    attn = attn.masked_fill(~full_mask, float("-inf"))
    attn = F.softmax(attn, dim=-1)
    ref_output = torch.einsum("bhst,bhtd->bhsd", attn, v)
    # print("ref_output", ref_output)
    # print("triton_output", triton_output)
    # Verify accuracy
    assert torch.allclose(triton_output, ref_output, atol=1e-2, rtol=1e-2), "Triton output doesn't match reference"
    print("Pass topk sparse attention test with qlen == klen")
@@ -322,16 +302,15 @@ def test_topk_sparse_attention_qlt_kl():
    torch.manual_seed(0)
    # Create inputs.
    q = torch.randn(BATCH, N_HEADS, Q_LEN, D_HEAD, device="cuda", dtype=torch.bfloat16)
    k = torch.randn(BATCH, N_HEADS, K_LEN, D_HEAD, device="cuda", dtype=torch.bfloat16)
    v = torch.randn(BATCH, N_HEADS, K_LEN, D_HEAD, device="cuda", dtype=torch.bfloat16)
    # softmax scale
    sm_scale = 1.0 / (D_HEAD**0.5)
    downsample_factor = BLOCK
    downsample_len = math.ceil(K_LEN / downsample_factor)  # number of blocks along one dimension
    x_ds = torch.randn(BATCH, N_HEADS, downsample_len, downsample_len, device="cuda", dtype=torch.bfloat16)
    # Force the first column to be high so that the first block is always selected.
    x_ds[:, :, :, 0] = 100
    block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK)
@@ -340,26 +319,25 @@ def test_topk_sparse_attention_qlt_kl():
    past_len = K_LEN - Q_LEN
    attn = torch.einsum("bhsd,bhtd->bhst", q, k) * sm_scale
    full_mask_full = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK, device="cuda")).bool()
    full_mask_full = full_mask_full[..., :K_LEN, :K_LEN]
    effective_mask = full_mask_full[..., past_len:K_LEN, :]  # shape: (B, H, Q_LEN, K_LEN)
    i_global = torch.arange(past_len, K_LEN, device=k.device).unsqueeze(1)  # shape: (Q_LEN, 1)
    j_global = torch.arange(K_LEN, device=k.device).unsqueeze(0)  # shape: (1, K_LEN)
    causal_mask = j_global <= i_global  # shape: (Q_LEN, K_LEN)
    final_mask = effective_mask & causal_mask  # shape: (B, H, Q_LEN, K_LEN)
    attn = attn.masked_fill(~final_mask, float("-inf"))
    attn = F.softmax(attn, dim=-1)
    ref_output = torch.einsum("bhst,bhtd->bhsd", attn, v)
    # Verify accuracy.
    assert torch.allclose(triton_output, ref_output, atol=1e-2, rtol=1e-2), "Triton output doesn't match reference when qlen < klen"
    print("Pass topk sparse attention test with qlen < klen")
...
@@ -28,24 +28,22 @@ def matmul_sp(
    @T.prim_func
    def main(
        A_sparse: T.Tensor(A_sparse_shape, in_dtype),
        E: T.Tensor((M, K // 8), "uint8"),
        B: T.Tensor(B_shape, in_dtype),
        C: T.Tensor((M, N), out_dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
            A_shared = T.alloc_shared(A_shared_shape, in_dtype)
            B_shared = T.alloc_shared(B_shared_shape, in_dtype)
            E_shared = T.alloc_shared((block_M, block_K // 8), "uint8")
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            T.annotate_layout(
                {
                    E: make_cutlass_metadata_layout(E, mma_dtype="float16", arch="9.0", block_k=block_K),
                    E_shared: make_cutlass_metadata_layout(E_shared, mma_dtype="float16", arch="9.0", block_k=block_K),
                }
            )
            T.clear(C_local)
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                T.copy(E[by * block_M, k * block_K // 8], E_shared)
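The (M, K // 8) uint8 shape of E is consistent with 2:4 metadata packing for 16-bit inputs: each group of four input elements keeps two values, each addressed by a 2-bit index, so four metadata bits per group. A quick arithmetic check (our own sanity check with a hypothetical K, not part of the kernel):

# Hypothetical reduction dimension, for illustration only.
K = 256
groups_per_row = K // 4                   # 2:4 pattern operates on groups of 4 elements
meta_bits_per_row = groups_per_row * 4    # 2 bits for each of the 2 kept values per group
meta_bytes_per_row = meta_bits_per_row // 8
assert meta_bytes_per_row == K // 8       # matches E: T.Tensor((M, K // 8), "uint8")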
@@ -57,7 +55,7 @@ def matmul_sp(
    return main


def generate_2_to_4_sparse_tensor(shape, dtype=torch.float32, device="cpu"):
    if shape[-1] % 4 != 0:
        raise ValueError("Last dimension must be divisible by 4 for 2:4 sparsity.")
@@ -102,9 +100,9 @@ def run_gemm_sp(
        num_threads,
    )
    A = generate_2_to_4_sparse_tensor((M, K), dtype=torch.float16, device="cuda")
    A_sparse, E = compress_sm90(A, block_k=block_K, transposed=False)
    B = torch.randn((K, N), device="cuda", dtype=torch.float16)
    C_sp = kernel(A_sparse, E, B).half()
    C = torch.matmul(A, B)
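One way to sanity-check the generated 2:4 pattern independently of compress_sm90 is to count nonzeros per group of four along the last dimension; a minimal sketch (the helper name is ours, not part of the example):

# Sketch: every group of 4 consecutive elements should hold at most 2 nonzeros.
import torch

def check_2_to_4(A: torch.Tensor) -> bool:
    groups = A.reshape(*A.shape[:-1], A.shape[-1] // 4, 4)
    return bool(((groups != 0).sum(dim=-1) <= 2).all())

# e.g. check_2_to_4(generate_2_to_4_sparse_tensor((128, 64))) is expected to return True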
...
@@ -26,9 +26,9 @@ def tl_topk(
    @T.prim_func
    def topk_kernel(
        logits: T.Tensor([M, N], dtype),
        topk_gates: T.Tensor([M, topk], dtype),
        topk_indices: T.Tensor([M, topk], "int32"),
    ):
        with T.Kernel(T.ceildiv(M, blk_m), threads=threads) as bx:
            logits_frag = T.alloc_fragment([blk_m, N], dtype=dtype)
@@ -43,15 +43,12 @@ def tl_topk(
            T.reduce_max(logits_frag, max_val, dim=1, clear=True)
            for i, j in T.Parallel(blk_m, N):
                expand_max_idx[i, j] = T.if_then_else(max_val[i] == logits_frag[i, j], j, expand_max_idx[i, j])
            T.reduce_max(expand_max_idx, max_idx, dim=1, clear=True)
            for i, j in T.Parallel(blk_m, N):
                logits_frag[i, j] = T.if_then_else(max_val[i] == logits_frag[i, j], -10000.0, logits_frag[i, j])
            for i in T.Parallel(blk_m):
                topk_gates[bx * blk_m + i, k] = max_val[i]
@@ -61,7 +58,6 @@ def tl_topk(
def ref_program(logits, top_k):
    top_k_gates, top_k_indices = logits.topk(top_k, dim=1)
    return top_k_gates, top_k_indices.to(torch.int32)
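The kernel above selects the top-k entries one at a time by repeated row-wise max and masking, which a plain PyTorch loop can mimic (sketch only; tie-breaking differs slightly, since the kernel masks every position equal to the current max):

# Sketch of the iterative selection: take the row-wise max, record its index,
# mask it out with a large negative value, and repeat top_k times.
import torch

def iterative_topk(logits: torch.Tensor, top_k: int):
    work = logits.clone()
    gates, indices = [], []
    for _ in range(top_k):
        max_val, max_idx = work.max(dim=1)
        gates.append(max_val)
        indices.append(max_idx)
        work.scatter_(1, max_idx.unsqueeze(1), -10000.0)
    return torch.stack(gates, dim=1), torch.stack(indices, dim=1).to(torch.int32)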
...
@@ -7,15 +7,15 @@ import tilelang.language as T
    out_idx=[-1],
    pass_configs={
        tilelang.PassConfigKey.TL_LAYOUT_VISUALIZATION_ENABLE: True,
        tilelang.PassConfigKey.TL_LAYOUT_VISUALIZATION_FORMATS: "svg",
    },
)
def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
    @T.prim_func
    def gemm(
        A: T.Tensor((M, K), dtype),
        B: T.Tensor((K, N), dtype),
        C: T.Tensor((M, N), dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), dtype)
@@ -49,12 +49,12 @@ def main():
    print("All check passed.")
    # print the layout visualization result and save figures to ./tmp.
    """
    C_local inferenced layout:
    Shape: [32, 32] -> [8]
    Thread: _j // 16 * 64 + _i // 16 * 32 + _i % 8 * 4 + _j % 8 // 2
    Index: [_j % 16 // 8 * 4 + _i % 16 // 8 * 2 + _j % 2]
    """
if __name__ == "__main__":
    ...