Unverified commit 29051439 authored by Lei Wang, committed by GitHub

[Lint] Phaseout Yapf format and embrace ruff format (#1417)

parent e84b24bc
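
The diff below is purely mechanical: ruff format (typically invoked as ruff format) replaces yapf, so strings switch from single to double quotes, yapf's column-limit argument wrapping is collapsed wherever a call now fits on one line, and calls that stay multi-line are laid out one argument per line with a trailing comma. A minimal sketch of those three changes follows; this is my own illustration under an assumed stock ruff format configuration, not an excerpt from the commit.

import argparse


def build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser()
    # ruff format style: double-quoted strings, and short calls kept on one line.
    parser.add_argument("--batch", type=int, default=64, help="Batch size")
    # Calls that do not fit on one line are exploded with one argument per line
    # and a trailing comma (the "magic trailing comma" then keeps them exploded).
    parser.add_argument(
        "--dim",
        type=int,
        default=32768,
        help="Dimension, must be a power of 2",
    )
    return parser


if __name__ == "__main__":
    # Prints Namespace(batch=64, dim=32768) when run with no arguments.
    print(build_parser().parse_args([]))
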
......@@ -17,7 +17,7 @@ def is_pow_of_2(n):
def hadamard(b, n, dtype):
assert is_pow_of_2(n), "n must be a power of 2"
assert 2 <= n <= 32768, "n must be in [2, 32768]"
elem_size = {'float32': 4, 'float16': 2, 'bfloat16': 2}[dtype]
elem_size = {"float32": 4, "float16": 2, "bfloat16": 2}[dtype]
logN = int(math.log2(n))
threads = [0, 1, 1, 1, 2, 4, 8, 16, 32, 32, 128, 256, 256, 256, 256, 256][logN]
......@@ -40,23 +40,21 @@ def hadamard(b, n, dtype):
# print(f'{exchange_round=}')
@T.macro
def warp_shfl(local: T.Tensor((thread_elem,), dtype), buf: T.Tensor((thread_elem,), dtype),
round: int):
def warp_shfl(local: T.Tensor((thread_elem,), dtype), buf: T.Tensor((thread_elem,), dtype), round: int):
tx = T.get_thread_binding(0)
for i in T.serial(round):
tx_stride = 1 << i
another_tx = tx ^ tx_stride
sign = (
tx >> i
) & 1 # get i-th lowest bit of tx, which determines the operation type for shared[tx, :]
sign = (tx >> i) & 1 # get i-th lowest bit of tx, which determines the operation type for shared[tx, :]
for j in T.Pipelined(thread_elem, num_stages=1):
buf[j] = T.tvm_warp_shuffle(
0xffffffff, # mask of all threads
0xFFFFFFFF, # mask of all threads
local[j],
another_tx % warp_size,
warp_size,
warp_size)
warp_size,
)
local[j] = T.if_then_else(sign == 0, local[j] + buf[j], buf[j] - local[j])
@T.prim_func
......@@ -78,10 +76,8 @@ def hadamard(b, n, dtype):
for j in T.serial(chunknum):
chunkbase = j * chunksize
for k in T.serial(chunksize // 2):
local[chunkbase +
k] = local[chunkbase + k] + local[chunkbase + k + chunksize // 2]
local[chunkbase + k + chunksize //
2] = local[chunkbase + k] - 2 * local[chunkbase + k + chunksize // 2]
local[chunkbase + k] = local[chunkbase + k] + local[chunkbase + k + chunksize // 2]
local[chunkbase + k + chunksize // 2] = local[chunkbase + k] - 2 * local[chunkbase + k + chunksize // 2]
# 3. Hadamard inside warp, n<=512
# In warp level, we rely on warp shuffle to exchange data inside each warp, without using shared memory
......@@ -131,28 +127,27 @@ def ref_program(x: torch.Tensor):
assert x.ndim == 2
dim = x.shape[-1]
assert is_pow_of_2(dim)
return F.linear(
x, torch.tensor(scipy.linalg.hadamard(dim, dtype=float), dtype=x.dtype, device=x.device))
return F.linear(x, torch.tensor(scipy.linalg.hadamard(dim, dtype=float), dtype=x.dtype, device=x.device))
def main():
parser = argparse.ArgumentParser()
parser.add_argument('--batch', type=int, default=64, help='Batch size')
parser.add_argument('--dim', type=int, default=32768, help='Dimension')
parser.add_argument("--batch", type=int, default=64, help="Batch size")
parser.add_argument("--dim", type=int, default=32768, help="Dimension")
args = parser.parse_args()
B, D = args.batch, args.dim
x = torch.randn((B, D), device='cuda')
kernel = hadamard(B, D, 'float32')
x = torch.randn((B, D), device="cuda")
kernel = hadamard(B, D, "float32")
y = kernel(x)
y_ref = ref_program(x)
torch.testing.assert_close(y, y_ref, atol=1e-2, rtol=1e-2)
print('All tests passed.')
print("All tests passed.")
profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Auto)
latency = profiler.do_bench(warmup=100)
print("Tile-lang: {:.2f} ms".format(latency))
if __name__ == '__main__':
if __name__ == "__main__":
main()
......@@ -9,6 +9,7 @@
"source": [
"import sys\n",
"from pathlib import Path\n",
"\n",
"sys.path.insert(0, str(Path.cwd().parent.parent.absolute()))\n",
"import tilelang\n",
"import torch\n",
......@@ -61,7 +62,7 @@
" out_dtype: T.dtype = T.float32,\n",
" block_M: int = 128,\n",
" block_N: int = 128,\n",
" block_K: int = 32\n",
" block_K: int = 32,\n",
"):\n",
" M, K = A.shape\n",
" K, N = B.shape\n",
......@@ -94,8 +95,8 @@
"metadata": {},
"outputs": [],
"source": [
"A = torch.randn(1024, 512, dtype=torch.float16, device='cuda')\n",
"B = torch.randn(512, 256, dtype=torch.float16, device='cuda')\n",
"A = torch.randn(1024, 512, dtype=torch.float16, device=\"cuda\")\n",
"B = torch.randn(512, 256, dtype=torch.float16, device=\"cuda\")\n",
"C = gemm(A, B)\n",
"\n",
"# check output is correct\n",
......@@ -118,8 +119,8 @@
"metadata": {},
"outputs": [],
"source": [
"A = torch.randn(1024, 512, dtype=torch.float16, device='cuda')\n",
"B = torch.randn(512, 1024, dtype=torch.float16, device='cuda')\n",
"A = torch.randn(1024, 512, dtype=torch.float16, device=\"cuda\")\n",
"B = torch.randn(512, 1024, dtype=torch.float16, device=\"cuda\")\n",
"C = gemm(A, B, block_M=64, block_N=64)"
]
},
......@@ -218,8 +219,8 @@
"source": [
"@tilelang.lazy_jit\n",
"def gemm_dyn_K(\n",
" A: T.Tensor[[int, T.dyn['K']], T.float16], # noqa: F821\n",
" B: T.Tensor[[T.dyn['K'], int], T.float16], # noqa: F821\n",
" A: T.Tensor[[int, T.dyn[\"K\"]], T.float16], # noqa: F821\n",
" B: T.Tensor[[T.dyn[\"K\"], int], T.float16], # noqa: F821\n",
"):\n",
" M, K = A.shape\n",
" K, N = B.shape\n",
......@@ -265,8 +266,8 @@
"metadata": {},
"outputs": [],
"source": [
"A = torch.randn(1024, 512, dtype=torch.float16, device='cuda')\n",
"B = torch.randn(512, 256, dtype=torch.float16, device='cuda')\n",
"A = torch.randn(1024, 512, dtype=torch.float16, device=\"cuda\")\n",
"B = torch.randn(512, 256, dtype=torch.float16, device=\"cuda\")\n",
"C = gemm_dyn_K(A, B)\n",
"C_ref = (A @ B).float()\n",
"torch.testing.assert_close(C, C_ref, rtol=1e-2, atol=1e-2)"
......@@ -295,18 +296,17 @@
"source": [
"from typing import Any\n",
"\n",
"\n",
"@tilelang.lazy_jit\n",
"def as_contingious(\n",
" A: T.StridedTensor[[T.dyn, T.dyn], [T.dyn, T.dyn], Any]\n",
"):\n",
"def as_contingious(A: T.StridedTensor[[T.dyn, T.dyn], [T.dyn, T.dyn], Any]):\n",
" M, N = A.shape\n",
" B = T.empty((M, N), A.dtype)\n",
" block_M = 128\n",
" block_N = 128\n",
" with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), threads=128) as (bx, by):\n",
" T.copy(\n",
" A[bx * block_M: (bx + 1) * block_M, by * block_N: (by + 1) * block_N],\n",
" B[bx * block_M: (bx + 1) * block_M, by * block_N: (by + 1) * block_N]\n",
" A[bx * block_M : (bx + 1) * block_M, by * block_N : (by + 1) * block_N],\n",
" B[bx * block_M : (bx + 1) * block_M, by * block_N : (by + 1) * block_N],\n",
" )\n",
" return B"
]
......@@ -318,7 +318,7 @@
"metadata": {},
"outputs": [],
"source": [
"A = torch.randn(1024, 1024, device='cuda')\n",
"A = torch.randn(1024, 1024, device=\"cuda\")\n",
"B = as_contingious(A[::2, ::2])\n",
"B_ref = A[::2, ::2].contiguous()\n",
"torch.testing.assert_close(B, B_ref)"
......@@ -370,8 +370,8 @@
"metadata": {},
"outputs": [],
"source": [
"A = torch.randn(1024, 512, dtype=torch.float16, device='cuda')\n",
"B = torch.randn(512, 256, dtype=torch.float16, device='cuda')\n",
"A = torch.randn(1024, 512, dtype=torch.float16, device=\"cuda\")\n",
"B = torch.randn(512, 256, dtype=torch.float16, device=\"cuda\")\n",
"C = gemm_ptr(A, B, 1024, 256, 512)\n",
"C_ref = (A @ B).float()\n",
"torch.testing.assert_close(C, C_ref, rtol=1e-2, atol=1e-2)"
......@@ -416,8 +416,8 @@
"metadata": {},
"outputs": [],
"source": [
"A = torch.randn(1024, 512, dtype=torch.float16, device='cuda')\n",
"B = torch.randn(512, 256, dtype=torch.float16, device='cuda')\n",
"A = torch.randn(1024, 512, dtype=torch.float16, device=\"cuda\")\n",
"B = torch.randn(512, 256, dtype=torch.float16, device=\"cuda\")\n",
"C = gemm_ptr_dyn(A, B, 1024, 256, 512)\n",
"C_ref = (A @ B).float()\n",
"torch.testing.assert_close(C, C_ref, rtol=1e-2, atol=1e-2)"
......@@ -496,18 +496,20 @@
"source": [
"from itertools import product\n",
"\n",
"\n",
"def get_configs():\n",
" return [\n",
" {\n",
" 'A': T.Tensor((1024, 1024), T.float32),\n",
" 'B': T.Tensor((1024, 1024), T.float32),\n",
" 'block_M': block_M,\n",
" 'block_N': block_N,\n",
" 'block_K': block_K,\n",
" \"A\": T.Tensor((1024, 1024), T.float32),\n",
" \"B\": T.Tensor((1024, 1024), T.float32),\n",
" \"block_M\": block_M,\n",
" \"block_N\": block_N,\n",
" \"block_K\": block_K,\n",
" }\n",
" for block_M, block_N, block_K in product([32, 64], repeat=3)\n",
" ]\n",
"\n",
"\n",
"gemm.par_compile(get_configs())"
]
},
......@@ -581,6 +583,7 @@
"def macro_with_ref(x: T.Ref):\n",
" x = 1 # noqa: F841\n",
"\n",
"\n",
"@T.prim_func\n",
"def foo(x: T.Tensor((2,))):\n",
" with T.Kernel(1) as _:\n",
......@@ -591,6 +594,7 @@
" idx = T.alloc_var(T.int32, 0)\n",
" macro_with_ref(x[idx])\n",
"\n",
"\n",
"foo"
]
},
......@@ -616,7 +620,7 @@
" A: T.Tensor[[T.dyn], Any],\n",
" fn,\n",
"):\n",
" N, = A.shape\n",
" (N,) = A.shape\n",
" B = T.empty((N,), dtype=A.dtype)\n",
" block_N = 128\n",
" with T.Kernel(T.ceildiv(N, block_N), threads=128) as bx:\n",
......@@ -624,6 +628,8 @@
" idx = bx * block_N + i\n",
" B[idx] = fn(A[idx])\n",
" return B\n",
"\n",
"\n",
"@T.macro\n",
"def add_one(x):\n",
" return x + 1"
......@@ -636,7 +642,7 @@
"metadata": {},
"outputs": [],
"source": [
"A = torch.randn(1024, device='cuda')\n",
"A = torch.randn(1024, device=\"cuda\")\n",
"B = element_wise(A, add_one)\n",
"B_ref = A + 1\n",
"torch.testing.assert_close(B, B_ref)"
......@@ -670,10 +676,11 @@
" var = var * 3 + 1\n",
" n31(x * 3 + 1, var)\n",
"\n",
"\n",
"@tilelang.lazy_jit\n",
"def foo(A: T.Tensor[[1], T.int32], n: int):\n",
" with T.Kernel(1) as _:\n",
" n31(n, A[0])\n"
" n31(n, A[0])"
]
},
{
......@@ -694,7 +701,7 @@
}
],
"source": [
"A = torch.tensor([100], dtype=torch.int32, device='cuda')\n",
"A = torch.tensor([100], dtype=torch.int32, device=\"cuda\")\n",
"foo(A, 5)\n",
"A"
]
......@@ -745,12 +752,15 @@
"def sincos(x):\n",
" return T.sin(x), T.cos(x)\n",
"\n",
"\n",
"@T.prim_func\n",
"def foo():\n",
" with T.Kernel(32) as x:\n",
" s, c = sincos(x)\n",
" a = s + c # noqa: F841\n",
" b = s - c # noqa: F841\n",
"\n",
"\n",
"foo"
]
}
......
......@@ -13,20 +13,20 @@ from typing import Optional, Tuple
pass_configs={
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
})
}
)
def tl_fused_chunk_bwd_kernel(
B,
S,
H,
DK,
DV,
dtype: str = 'float16',
dtype: str = "float16",
scale: float = None,
) -> torch.Tensor:
if scale is None:
scale = DK**-0.5
accum_dtype = 'float'
accum_dtype = "float"
chunk_size = 64
BK = BV = 64 # Set to 128 can be faster, but has some numerical differences with FLA
......@@ -66,11 +66,13 @@ def tl_fused_chunk_bwd_kernel(
dh = T.alloc_fragment([BK, BV], accum_dtype)
dh_shared = T.alloc_shared([BK, BV], dtype)
T.annotate_layout({
T.annotate_layout(
{
dq_shared: tilelang.layout.make_swizzled_layout(dq_shared),
dk_shared: tilelang.layout.make_swizzled_layout(dk_shared),
dv_shared: tilelang.layout.make_swizzled_layout(dv_shared)
})
dv_shared: tilelang.layout.make_swizzled_layout(dv_shared),
}
)
T.use_swizzle(10)
T.clear(h)
......@@ -78,10 +80,9 @@ def tl_fused_chunk_bwd_kernel(
# Calculate dQ
for i in T.Pipelined(0, NT):
T.copy(K[i_b, i * chunk_size:(i + 1) * chunk_size, i_h, i_k * BK:(i_k + 1) * BK], k)
T.copy(V[i_b, i * chunk_size:(i + 1) * chunk_size, i_h, i_v * BV:(i_v + 1) * BV], v)
T.copy(dO[i_b, i * chunk_size:(i + 1) * chunk_size, i_h, i_v * BV:(i_v + 1) * BV],
do)
T.copy(K[i_b, i * chunk_size : (i + 1) * chunk_size, i_h, i_k * BK : (i_k + 1) * BK], k)
T.copy(V[i_b, i * chunk_size : (i + 1) * chunk_size, i_h, i_v * BV : (i_v + 1) * BV], v)
T.copy(dO[i_b, i * chunk_size : (i + 1) * chunk_size, i_h, i_v * BV : (i_v + 1) * BV], do)
T.gemm(do, v, ds, transpose_B=True, clear_accum=True)
for row, col in T.Parallel(chunk_size, chunk_size):
......@@ -94,29 +95,19 @@ def tl_fused_chunk_bwd_kernel(
for row, col in T.Parallel(chunk_size, BK):
dq[row, col] *= scale
T.copy(dq, dq_shared)
T.atomic_add(
dQ[i_b, i * chunk_size:(i + 1) * chunk_size, i_h, i_k * BK:(i_k + 1) * BK],
dq_shared)
T.atomic_add(dQ[i_b, i * chunk_size : (i + 1) * chunk_size, i_h, i_k * BK : (i_k + 1) * BK], dq_shared)
# Calculate dK, dV (reversely)
for i in T.Pipelined(1, NT + 1):
start = NT - i
for row, col in T.Parallel(chunk_size, BK):
q[row, col] = Q[i_b, start * chunk_size + row, i_h, i_k * BK + col] * scale
T.copy(
K[i_b, start * chunk_size:(start + 1) * chunk_size, i_h,
i_k * BK:(i_k + 1) * BK], k)
T.copy(
V[i_b, start * chunk_size:(start + 1) * chunk_size, i_h,
i_v * BV:(i_v + 1) * BV], v)
T.copy(
dO[i_b, start * chunk_size:(start + 1) * chunk_size, i_h,
i_v * BV:(i_v + 1) * BV], do)
T.copy(K[i_b, start * chunk_size : (start + 1) * chunk_size, i_h, i_k * BK : (i_k + 1) * BK], k)
T.copy(V[i_b, start * chunk_size : (start + 1) * chunk_size, i_h, i_v * BV : (i_v + 1) * BV], v)
T.copy(dO[i_b, start * chunk_size : (start + 1) * chunk_size, i_h, i_v * BV : (i_v + 1) * BV], do)
# Calculate dk
T.gemm(
v, do, ds, transpose_B=True, clear_accum=True
) # ds here actually means `s`, but we simply reuse the buffer `ds`
T.gemm(v, do, ds, transpose_B=True, clear_accum=True) # ds here actually means `s`, but we simply reuse the buffer `ds`
for row, col in T.Parallel(chunk_size, chunk_size):
ds_shared[row, col] = T.if_then_else(row <= col, ds[row, col], 0)
T.gemm(ds_shared, q, dk, clear_accum=True)
......@@ -134,13 +125,9 @@ def tl_fused_chunk_bwd_kernel(
T.gemm(q, do, dh, transpose_A=True)
T.copy(dk, dk_shared)
T.atomic_add(
dK[i_b, start * chunk_size:(start + 1) * chunk_size, i_h,
i_k * BK:(i_k + 1) * BK], dk_shared)
T.atomic_add(dK[i_b, start * chunk_size : (start + 1) * chunk_size, i_h, i_k * BK : (i_k + 1) * BK], dk_shared)
T.copy(dv, dv_shared)
T.atomic_add(
dV[i_b, start * chunk_size:(start + 1) * chunk_size, i_h,
i_v * BV:(i_v + 1) * BV], dv_shared)
T.atomic_add(dV[i_b, start * chunk_size : (start + 1) * chunk_size, i_h, i_v * BV : (i_v + 1) * BV], dv_shared)
return fused_chunk_linear_attn_bwd
......@@ -155,34 +142,31 @@ def tl_fused_chunk_bwd(Q, K, V, dO):
return dQ.to(torch.float16), dK.to(torch.float16), dV.to(torch.float16)
def ref_program(q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
scale: Optional[float] = None) -> Tuple[torch.Tensor, torch.Tensor]:
def ref_program(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, scale: Optional[float] = None) -> Tuple[torch.Tensor, torch.Tensor]:
q, k, v = q.float(), k.float(), v.float()
if scale is None:
scale = q.shape[-1]**-0.5
scale = q.shape[-1] ** -0.5
chunk_size = 64
q = rearrange(q, 'b (n c) h d -> b h n c d', c=chunk_size) * scale
k = rearrange(k, 'b (n c) h d -> b h n c d', c=chunk_size)
v = rearrange(v, 'b (n c) h d -> b h n c d', c=chunk_size)
q = rearrange(q, "b (n c) h d -> b h n c d", c=chunk_size) * scale
k = rearrange(k, "b (n c) h d -> b h n c d", c=chunk_size)
v = rearrange(v, "b (n c) h d -> b h n c d", c=chunk_size)
kv = k.transpose(-1, -2) @ v
kv = kv.cumsum(2)
h = kv[:, :, -1, :, :]
kv = torch.cat([torch.zeros_like(kv[:, :, :1]), kv[:, :, :-1]], dim=2)
inter = q @ kv
intra = ((q @ k.transpose(-1, -2)).masked_fill_(
torch.triu(torch.ones(chunk_size, chunk_size, dtype=bool, device=q.device), diagonal=1),
0)) @ v
intra = (
(q @ k.transpose(-1, -2)).masked_fill_(torch.triu(torch.ones(chunk_size, chunk_size, dtype=bool, device=q.device), diagonal=1), 0)
) @ v
o = inter + intra
return rearrange(o, 'b h n c d -> b (n c) h d'), h
return rearrange(o, "b h n c d -> b (n c) h d"), h
def main(B=1, S=1024, H=16, D=128):
q = torch.randn((B, S, H, D), device='cuda', dtype=torch.float16, requires_grad=True)
k = torch.randn((B, S, H, D), device='cuda', dtype=torch.float16, requires_grad=True)
v = torch.randn((B, S, H, D), device='cuda', dtype=torch.float16, requires_grad=True)
do = torch.randn((B, S, H, D), device='cuda', dtype=torch.float16)
q = torch.randn((B, S, H, D), device="cuda", dtype=torch.float16, requires_grad=True)
k = torch.randn((B, S, H, D), device="cuda", dtype=torch.float16, requires_grad=True)
v = torch.randn((B, S, H, D), device="cuda", dtype=torch.float16, requires_grad=True)
do = torch.randn((B, S, H, D), device="cuda", dtype=torch.float16)
# qk norm is necessary for linear attn
q = l2norm_fwd(q)[0].requires_grad_(True)
......@@ -193,30 +177,27 @@ def main(B=1, S=1024, H=16, D=128):
o_ref, _ = ref_program(q, k, v)
o_ref.backward(do, retain_graph=True)
assert torch.allclose(
dq, q.grad, atol=1e-2, rtol=1e-2), f'dq max err: {(dq - q.grad).abs().max()}'
assert torch.allclose(
dk, k.grad, atol=1e-2, rtol=1e-2), f'dk max err: {(dk - k.grad).abs().max()}'
assert torch.allclose(
dv, v.grad, atol=1e-2, rtol=1e-2), f'dv max err: {(dv - v.grad).abs().max()}'
print('Passed all tests!✅')
assert torch.allclose(dq, q.grad, atol=1e-2, rtol=1e-2), f"dq max err: {(dq - q.grad).abs().max()}"
assert torch.allclose(dk, k.grad, atol=1e-2, rtol=1e-2), f"dk max err: {(dk - k.grad).abs().max()}"
assert torch.allclose(dv, v.grad, atol=1e-2, rtol=1e-2), f"dv max err: {(dv - v.grad).abs().max()}"
print("Passed all tests!✅")
# Benchmark
q.grad = k.grad = v.grad = None
o_ref, _ = fused_chunk_linear_attn(q, k, v, output_final_state=True, normalize=False)
t1 = do_bench(lambda: o_ref.backward(do, retain_graph=True), backend='cupti')
t2 = do_bench(lambda: tl_fused_chunk_bwd(q, k, v, do), backend='cupti')
print(f'Triton latency: {t1:.3f} ms')
print(f'TileLang latency: {t2:.3f} ms')
print(f'Speedup: {t1/t2:.3f}x')
t1 = do_bench(lambda: o_ref.backward(do, retain_graph=True), backend="cupti")
t2 = do_bench(lambda: tl_fused_chunk_bwd(q, k, v, do), backend="cupti")
print(f"Triton latency: {t1:.3f} ms")
print(f"TileLang latency: {t2:.3f} ms")
print(f"Speedup: {t1 / t2:.3f}x")
if __name__ == '__main__':
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--B', type=int, default=8, help='Batch size')
parser.add_argument('--S', type=int, default=1024, help='Seq len')
parser.add_argument('--H', type=int, default=32, help='Num heads')
parser.add_argument('--D', type=int, default=128, help='Head dim')
parser.add_argument("--B", type=int, default=8, help="Batch size")
parser.add_argument("--S", type=int, default=1024, help="Seq len")
parser.add_argument("--H", type=int, default=32, help="Num heads")
parser.add_argument("--D", type=int, default=128, help="Head dim")
args = parser.parse_args()
main(args.B, args.S, args.H, args.D)
......@@ -14,20 +14,20 @@ from typing import Optional, Tuple
pass_configs={
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
})
},
)
def tl_fused_chunk_fwd_kernel(
B,
S,
H,
DK,
DV,
dtype: str = 'float16',
dtype: str = "float16",
scale: float = None,
) -> torch.Tensor:
if scale is None:
scale = DK**-0.5
accum_dtype = 'float'
accum_dtype = "float"
chunk_size = 64
BK = BV = 64 # Set to 128 can be faster, but has some numerical differences with FLA
......@@ -42,7 +42,8 @@ def tl_fused_chunk_fwd_kernel(
K: T.Tensor([B, S, H, DK], dtype), # type: ignore
V: T.Tensor([B, S, H, DV], dtype), # type: ignore
O: T.Tensor([B, S, H, DV], accum_dtype), # type: ignore
final_state: T.Tensor([B, H, DK, DV], accum_dtype)): # type: ignore
final_state: T.Tensor([B, H, DK, DV], accum_dtype),
): # type: ignore
with T.Kernel(NV, NK, B * H) as (i_v, i_k, i_bh):
i_b = i_bh // H
i_h = i_bh % H
......@@ -65,8 +66,8 @@ def tl_fused_chunk_fwd_kernel(
for i in T.Pipelined(0, NT):
for row, col in T.Parallel(chunk_size, BK):
q[row, col] = Q[i_b, i * chunk_size + row, i_h, i_k * BK + col] * scale
T.copy(K[i_b, i * chunk_size:(i + 1) * chunk_size, i_h, i_k * BK:(i_k + 1) * BK], k)
T.copy(V[i_b, i * chunk_size:(i + 1) * chunk_size, i_h, i_v * BV:(i_v + 1) * BV], v)
T.copy(K[i_b, i * chunk_size : (i + 1) * chunk_size, i_h, i_k * BK : (i_k + 1) * BK], k)
T.copy(V[i_b, i * chunk_size : (i + 1) * chunk_size, i_h, i_v * BV : (i_v + 1) * BV], v)
T.gemm(q, k, s, clear_accum=True, transpose_B=True)
for row, col in T.Parallel(chunk_size, chunk_size):
......@@ -77,12 +78,10 @@ def tl_fused_chunk_fwd_kernel(
T.gemm(k, v, h, transpose_A=True)
T.gemm(q, h_shared, o)
T.copy(o, o_shared)
T.atomic_add(
O[i_b, i * chunk_size:(i + 1) * chunk_size, i_h, i_v * BV:(i_v + 1) * BV],
o_shared)
T.atomic_add(O[i_b, i * chunk_size : (i + 1) * chunk_size, i_h, i_v * BV : (i_v + 1) * BV], o_shared)
# Output final state
T.copy(h, final_state[i_b, i_h, i_k * BK:(i_k + 1) * BK, i_v * BV:(i_v + 1) * BV])
T.copy(h, final_state[i_b, i_h, i_k * BK : (i_k + 1) * BK, i_v * BV : (i_v + 1) * BV])
return fused_chunk_linear_attn_fwd
......@@ -91,38 +90,35 @@ def tl_fused_chunk_fwd(q, k, v):
B, S, H, D = q.shape
kernel = tl_fused_chunk_fwd_kernel(B, S, H, D, D)
print(kernel.get_kernel_source())
o = torch.zeros((B, S, H, D), device='cuda', dtype=torch.float32)
o = torch.zeros((B, S, H, D), device="cuda", dtype=torch.float32)
h = kernel(q, k, v, o)
return o, h
def ref_program(q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
scale: Optional[float] = None) -> Tuple[torch.Tensor, torch.Tensor]:
def ref_program(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, scale: Optional[float] = None) -> Tuple[torch.Tensor, torch.Tensor]:
q, k, v = q.float(), k.float(), v.float()
if scale is None:
scale = q.shape[-1]**-0.5
scale = q.shape[-1] ** -0.5
chunk_size = 64
q = rearrange(q, 'b (n c) h d -> b h n c d', c=chunk_size) * scale
k = rearrange(k, 'b (n c) h d -> b h n c d', c=chunk_size)
v = rearrange(v, 'b (n c) h d -> b h n c d', c=chunk_size)
q = rearrange(q, "b (n c) h d -> b h n c d", c=chunk_size) * scale
k = rearrange(k, "b (n c) h d -> b h n c d", c=chunk_size)
v = rearrange(v, "b (n c) h d -> b h n c d", c=chunk_size)
kv = k.transpose(-1, -2) @ v
kv = kv.cumsum(2)
h = kv[:, :, -1, :, :]
kv = torch.cat([torch.zeros_like(kv[:, :, :1]), kv[:, :, :-1]], dim=2)
inter = q @ kv
intra = ((q @ k.transpose(-1, -2)).masked_fill_(
torch.triu(torch.ones(chunk_size, chunk_size, dtype=bool, device=q.device), diagonal=1),
0)) @ v
intra = (
(q @ k.transpose(-1, -2)).masked_fill_(torch.triu(torch.ones(chunk_size, chunk_size, dtype=bool, device=q.device), diagonal=1), 0)
) @ v
o = inter + intra
return rearrange(o, 'b h n c d -> b (n c) h d'), h
return rearrange(o, "b h n c d -> b (n c) h d"), h
def main(B=1, S=512, H=16, D=128):
q = torch.randn((B, S, H, D), device='cuda', dtype=torch.float16)
k = torch.randn((B, S, H, D), device='cuda', dtype=torch.float16)
v = torch.randn((B, S, H, D), device='cuda', dtype=torch.float16)
q = torch.randn((B, S, H, D), device="cuda", dtype=torch.float16)
k = torch.randn((B, S, H, D), device="cuda", dtype=torch.float16)
v = torch.randn((B, S, H, D), device="cuda", dtype=torch.float16)
# qk norm is necessary for linear attn
q, _ = l2norm_fwd(q)
......@@ -131,25 +127,23 @@ def main(B=1, S=512, H=16, D=128):
o, h = tl_fused_chunk_fwd(q, k, v)
o_ref, h_ref = ref_program(q, k, v)
assert torch.allclose(o, o_ref, atol=1e-2, rtol=1e-2), f'o max err: {(o - o_ref).abs().max()}'
assert torch.allclose(h, h_ref, atol=1e-2, rtol=1e-2), f'h max err: {(h - h_ref).abs().max()}'
print('Passed all tests!✅')
assert torch.allclose(o, o_ref, atol=1e-2, rtol=1e-2), f"o max err: {(o - o_ref).abs().max()}"
assert torch.allclose(h, h_ref, atol=1e-2, rtol=1e-2), f"h max err: {(h - h_ref).abs().max()}"
print("Passed all tests!✅")
t1 = do_bench(
lambda: fused_chunk_linear_attn(q, k, v, output_final_state=True, normalize=False),
backend='cupti')
t2 = do_bench(lambda: tl_fused_chunk_fwd(q, k, v), backend='cupti')
print(f'Triton latency: {t1:.3f} ms')
print(f'TileLang latency: {t2:.3f} ms')
print(f'Speedup: {t1/t2:.3f}x')
t1 = do_bench(lambda: fused_chunk_linear_attn(q, k, v, output_final_state=True, normalize=False), backend="cupti")
t2 = do_bench(lambda: tl_fused_chunk_fwd(q, k, v), backend="cupti")
print(f"Triton latency: {t1:.3f} ms")
print(f"TileLang latency: {t2:.3f} ms")
print(f"Speedup: {t1 / t2:.3f}x")
if __name__ == '__main__':
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--B', type=int, default=8, help='Batch size')
parser.add_argument('--S', type=int, default=1024, help='Seq len')
parser.add_argument('--H', type=int, default=32, help='Num heads')
parser.add_argument('--D', type=int, default=128, help='Head dim')
parser.add_argument("--B", type=int, default=8, help="Batch size")
parser.add_argument("--S", type=int, default=1024, help="Seq len")
parser.add_argument("--H", type=int, default=32, help="Num heads")
parser.add_argument("--D", type=int, default=128, help="Head dim")
args = parser.parse_args()
main(args.B, args.S, args.H, args.D)
......@@ -9,6 +9,7 @@ import itertools
def chunk_scan_triton(cb, x, dt, dA_cumsum, C, states, D):
from mamba_ssm.ops.triton.ssd_chunk_scan import _chunk_scan_fwd
out, _ = _chunk_scan_fwd(cb, x, dt, dA_cumsum, C, states, D)
return out
......@@ -43,14 +44,15 @@ def ref_program(cb, x, dt, dA_cumsum, C, prev_states, D):
dt_segment_sum = dA_cumsum[:, :, :, :, None] - dA_cumsum[:, :, :, None, :]
decay = torch.exp(dt_segment_sum)
scores_decay = cb * rearrange(decay, "b h c l s -> b c h l s")
causal_mask = torch.tril(
torch.ones(chunk_size, chunk_size, device=x.device, dtype=bool), diagonal=0)
causal_mask = torch.tril(torch.ones(chunk_size, chunk_size, device=x.device, dtype=bool), diagonal=0)
scores_decay = scores_decay.masked_fill(~causal_mask, 0)
out = torch.einsum('bchls,bhcs,bcshp->bclhp', scores_decay.to(x.dtype), dt.to(x.dtype),
rearrange(x, "b (c s) h p -> b c s h p", c=nchunks))
out = torch.einsum(
"bchls,bhcs,bcshp->bclhp", scores_decay.to(x.dtype), dt.to(x.dtype), rearrange(x, "b (c s) h p -> b c s h p", c=nchunks)
)
state_decay_out = torch.exp(rearrange(dA_cumsum, "b h c l -> b c l h 1"))
out_prev = torch.einsum('bclhn,bchpn->bclhp', rearrange(
C, "b (c l) h n -> b c l h n", c=nchunks), prev_states.to(C.dtype)) * state_decay_out
out_prev = (
torch.einsum("bclhn,bchpn->bclhp", rearrange(C, "b (c l) h n -> b c l h n", c=nchunks), prev_states.to(C.dtype)) * state_decay_out
)
out = out + out_prev
out = rearrange(out, "b c l h p -> b (c l) h p")
if D is not None:
......@@ -61,12 +63,7 @@ def ref_program(cb, x, dt, dA_cumsum, C, prev_states, D):
def get_configs():
iter_params = dict(
block_M=[64, 128, 256],
block_N=[32, 64],
block_K=[64, 128, 256],
block_Dstate=[128],
num_stages=[1, 2, 3, 4, 5])
iter_params = dict(block_M=[64, 128, 256], block_N=[32, 64], block_K=[64, 128, 256], block_Dstate=[128], num_stages=[1, 2, 3, 4, 5])
return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())]
......@@ -77,7 +74,8 @@ def get_configs():
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
},
)
def chunk_scan_fwd(batch,
def chunk_scan_fwd(
batch,
seqlen,
chunk_size,
ngroups,
......@@ -89,7 +87,8 @@ def chunk_scan_fwd(batch,
block_K=64,
block_Dstate=128,
num_stages=2,
threads=128):
threads=128,
):
dtype = "float16"
accum_dtype = "float"
nchunks = T.ceildiv(seqlen, chunk_size)
......@@ -104,13 +103,13 @@ def chunk_scan_fwd(batch,
C: T.Tensor((batch, seqlen, ngroups, dstate), dtype), # type: ignore
prev_states: T.Tensor((batch, nchunks, nheads, headdim, dstate), dtype), # type: ignore
D: T.Tensor((nheads), dtype), # type: ignore
Output: T.Tensor((batch, seqlen, nheads, headdim), dtype) # type: ignore
Output: T.Tensor((batch, seqlen, nheads, headdim), dtype), # type: ignore
):
with T.Kernel(nheads, T.ceildiv(chunk_size, block_M) * T.ceildiv(headdim, block_N), batch * nchunks, threads=threads) as (
bz,
bx,
by,
):
with T.Kernel(
nheads,
T.ceildiv(chunk_size, block_M) * T.ceildiv(headdim, block_N),
batch * nchunks,
threads=threads) as (bz, bx, by):
acc_o = T.alloc_fragment((block_M, block_N), accum_dtype)
acc_o_shared = T.alloc_shared((block_M, block_N), dtype)
cb_shared = T.alloc_shared((block_M, block_K), dtype, scope="shared.dyn")
......@@ -136,27 +135,32 @@ def chunk_scan_fwd(batch,
m_idx = bx // T.ceildiv(headdim, block_N)
n_idx = bx % T.ceildiv(headdim, block_N)
T.annotate_layout({
T.annotate_layout(
{
acc_o_shared: tilelang.layout.make_swizzled_layout(acc_o_shared),
cb_shared: tilelang.layout.make_swizzled_layout(cb_shared),
x_residual_shared: tilelang.layout.make_swizzled_layout(x_residual_shared)
})
x_residual_shared: tilelang.layout.make_swizzled_layout(x_residual_shared),
}
)
T.no_set_max_nreg()
T.copy(dA_cumsum[batch_idx, bz, chunk_idx, m_idx * block_M:(m_idx + 1) * block_M],
dA_cs_m_shared)
T.copy(dA_cumsum[batch_idx, bz, chunk_idx, m_idx * block_M : (m_idx + 1) * block_M], dA_cs_m_shared)
T.copy(dA_cs_m_shared, dA_cs_m_local)
T.clear(acc_o)
for i in T.Parallel(block_M):
scale_m_local[i] = T.exp2(dA_cs_m_local[i] * p)
T.copy(
C[batch_idx, chunk_idx * chunk_size + m_idx * block_M:chunk_idx * chunk_size +
(m_idx + 1) * block_M, bz // (nheads // ngroups), 0:block_Dstate], C_shared)
T.copy(
prev_states[batch_idx, chunk_idx, bz, n_idx * block_N:(n_idx + 1) * block_N,
0:block_Dstate], prev_state_shared)
C[
batch_idx,
chunk_idx * chunk_size + m_idx * block_M : chunk_idx * chunk_size + (m_idx + 1) * block_M,
bz // (nheads // ngroups),
0:block_Dstate,
],
C_shared,
)
T.copy(prev_states[batch_idx, chunk_idx, bz, n_idx * block_N : (n_idx + 1) * block_N, 0:block_Dstate], prev_state_shared)
T.gemm(C_shared, prev_state_shared, acc_o, transpose_B=True)
for i, j in T.Parallel(block_M, block_N):
acc_o[i, j] *= scale_m_local[i]
......@@ -165,34 +169,47 @@ def chunk_scan_fwd(batch,
for k in T.Pipelined(loop_range, num_stages=num_stages):
T.copy(
cb[batch_idx, chunk_idx, bz // (nheads // ngroups),
m_idx * block_M:(m_idx + 1) * block_M, k * block_K:(k + 1) * block_K],
cb_shared)
cb[
batch_idx,
chunk_idx,
bz // (nheads // ngroups),
m_idx * block_M : (m_idx + 1) * block_M,
k * block_K : (k + 1) * block_K,
],
cb_shared,
)
T.copy(cb_shared, cb_local)
T.copy(dA_cumsum[batch_idx, bz, chunk_idx, k * block_K:(k + 1) * block_K],
dA_cs_k_shared)
T.copy(dA_cumsum[batch_idx, bz, chunk_idx, k * block_K : (k + 1) * block_K], dA_cs_k_shared)
T.copy(dA_cs_k_shared, dA_cs_k_local)
for i, j in T.Parallel(block_M, block_K):
cb_local[i,
j] = cb_local[i,
j] * T.exp2(dA_cs_m_local[i] * p - dA_cs_k_local[j] * p)
T.copy(dt[batch_idx, bz, chunk_idx, k * block_K:(k + 1) * block_K], dt_shared)
cb_local[i, j] = cb_local[i, j] * T.exp2(dA_cs_m_local[i] * p - dA_cs_k_local[j] * p)
T.copy(dt[batch_idx, bz, chunk_idx, k * block_K : (k + 1) * block_K], dt_shared)
T.copy(dt_shared, dt_local)
for i, j in T.Parallel(block_M, block_K):
cb_local[i, j] *= dt_local[j]
for i, j in T.Parallel(block_M, block_K):
cb_local[i, j] = T.if_then_else(m_idx * block_M + i >= k * block_K + j,
cb_local[i, j], 0)
cb_local[i, j] = T.if_then_else(m_idx * block_M + i >= k * block_K + j, cb_local[i, j], 0)
T.copy(
x[batch_idx, chunk_idx * chunk_size + k * block_K:chunk_idx * chunk_size +
(k + 1) * block_K, bz, n_idx * block_N:(n_idx + 1) * block_N], x_shared)
x[
batch_idx,
chunk_idx * chunk_size + k * block_K : chunk_idx * chunk_size + (k + 1) * block_K,
bz,
n_idx * block_N : (n_idx + 1) * block_N,
],
x_shared,
)
T.gemm(cb_local, x_shared, acc_o)
D_local[0] = D[bz]
T.copy(
x[batch_idx, chunk_idx * chunk_size + m_idx * block_M:chunk_idx * chunk_size +
(m_idx + 1) * block_M, bz, n_idx * block_N:(n_idx + 1) * block_N],
x_residual_shared)
x[
batch_idx,
chunk_idx * chunk_size + m_idx * block_M : chunk_idx * chunk_size + (m_idx + 1) * block_M,
bz,
n_idx * block_N : (n_idx + 1) * block_N,
],
x_residual_shared,
)
T.copy(x_residual_shared, x_residual_local)
for i, j in T.Parallel(block_M, block_N):
acc_o[i, j] += x_residual_local[i, j] * D_local[0]
......@@ -200,27 +217,40 @@ def chunk_scan_fwd(batch,
T.copy(acc_o, acc_o_shared)
T.copy(
acc_o_shared,
Output[batch_idx, chunk_idx * chunk_size + m_idx * block_M:chunk_idx * chunk_size +
(m_idx + 1) * block_M, bz, n_idx * block_N:(n_idx + 1) * block_N])
Output[
batch_idx,
chunk_idx * chunk_size + m_idx * block_M : chunk_idx * chunk_size + (m_idx + 1) * block_M,
bz,
n_idx * block_N : (n_idx + 1) * block_N,
],
)
return main
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--batch', type=int, default=8, help='batch size')
parser.add_argument('--heads', type=int, default=80, help='heads')
parser.add_argument('--groups', type=int, default=1, help='groups')
parser.add_argument('--seq_len', type=int, default=4096, help='sequence length')
parser.add_argument('--chunk_size', type=int, default=256, help='chunk size')
parser.add_argument('--dim', type=int, default=64, help='dim')
parser.add_argument('--dstate', type=int, default=128, help='dstate')
parser.add_argument('--tune', action='store_true', help='tune configs')
parser.add_argument("--batch", type=int, default=8, help="batch size")
parser.add_argument("--heads", type=int, default=80, help="heads")
parser.add_argument("--groups", type=int, default=1, help="groups")
parser.add_argument("--seq_len", type=int, default=4096, help="sequence length")
parser.add_argument("--chunk_size", type=int, default=256, help="chunk size")
parser.add_argument("--dim", type=int, default=64, help="dim")
parser.add_argument("--dstate", type=int, default=128, help="dstate")
parser.add_argument("--tune", action="store_true", help="tune configs")
args = parser.parse_args()
batch, heads, groups, seq_len, chunk_size, dim, dstate = args.batch, args.heads, args.groups, args.seq_len, args.chunk_size, args.dim, args.dstate
batch, heads, groups, seq_len, chunk_size, dim, dstate = (
args.batch,
args.heads,
args.groups,
args.seq_len,
args.chunk_size,
args.dim,
args.dstate,
)
total_flops = 2 * batch * seq_len * chunk_size * heads * dim * 0.5 + 2 * batch * seq_len * heads * dim * dstate
if (not args.tune):
if not args.tune:
kernel = chunk_scan_fwd(
batch,
seq_len,
......@@ -234,7 +264,8 @@ if __name__ == "__main__":
block_K=64,
block_Dstate=128,
num_stages=2,
threads=128)
threads=128,
)
profiler = kernel.get_profiler(tilelang.TensorSupplyType.Normal)
profiler.assert_allclose(ref_program, rtol=0.01, atol=0.01)
print("All checks pass.")
......
......@@ -10,6 +10,7 @@ import itertools
def chunk_state_triton(B, x, dt, dA_cumsum):
from mamba_ssm.ops.triton.ssd_chunk_state import _chunk_state_fwd
return _chunk_state_fwd(B, x, dt, dA_cumsum, states_in_fp32=False)
......@@ -41,46 +42,33 @@ def ref_program(B, x, dt, dA_cumsum):
x = rearrange(x, "b (c l) h p -> b c l h p", l=chunk_size)
B = rearrange(B, "b (c l) ... -> b c l ...", l=chunk_size)
decay_states = torch.exp((dA_cumsum[:, :, :, -1:] - dA_cumsum))
return torch.einsum("bclhn,bhcl,bhcl,bclhp->bchpn", B.to(x.dtype), decay_states.to(x.dtype),
dt.to(x.dtype), x)
return torch.einsum("bclhn,bhcl,bhcl,bclhp->bchpn", B.to(x.dtype), decay_states.to(x.dtype), dt.to(x.dtype), x)
def get_configs():
iter_params = dict(
block_M=[64, 128], block_N=[32, 64, 128], block_K=[32, 64], num_stages=[1, 2, 3, 4, 5])
iter_params = dict(block_M=[64, 128], block_N=[32, 64, 128], block_K=[32, 64], num_stages=[1, 2, 3, 4, 5])
return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())]
@autotune(configs=get_configs(), warmup=10, rep=10)
@tilelang.jit(out_idx=[4])
def chunk_state_fwd(batch,
seqlen,
chunk_size,
ngroups,
nheads,
headdim,
dstate,
block_M=64,
block_N=64,
block_K=64,
num_stages=2,
threads=128):
def chunk_state_fwd(
batch, seqlen, chunk_size, ngroups, nheads, headdim, dstate, block_M=64, block_N=64, block_K=64, num_stages=2, threads=128
):
dtype = "float16"
accum_dtype = "float"
nchunks = T.ceildiv(seqlen, chunk_size)
p = 1.44269504
@T.prim_func
def main(B: T.Tensor((batch, seqlen, ngroups, dstate), dtype), x: T.Tensor(
(batch, seqlen, nheads, headdim), dtype), dt: T.Tensor(
(batch, nheads, nchunks, chunk_size), dtype), dA_cumsum: T.Tensor(
(batch, nheads, nchunks, chunk_size), dtype), Output: T.Tensor(
(batch, nchunks, nheads, headdim, dstate), dtype)):
with T.Kernel(
nheads,
T.ceildiv(headdim, block_M) * T.ceildiv(dstate, block_N),
batch * nchunks,
threads=threads) as (bz, bx, by):
def main(
B: T.Tensor((batch, seqlen, ngroups, dstate), dtype),
x: T.Tensor((batch, seqlen, nheads, headdim), dtype),
dt: T.Tensor((batch, nheads, nchunks, chunk_size), dtype),
dA_cumsum: T.Tensor((batch, nheads, nchunks, chunk_size), dtype),
Output: T.Tensor((batch, nchunks, nheads, headdim, dstate), dtype),
):
with T.Kernel(nheads, T.ceildiv(headdim, block_M) * T.ceildiv(dstate, block_N), batch * nchunks, threads=threads) as (bz, bx, by):
x_shared = T.alloc_shared((block_K, block_M), dtype)
x_local = T.alloc_fragment((block_K, block_M), dtype)
xt_local = T.alloc_fragment((block_M, block_K), dtype)
......@@ -101,20 +89,24 @@ def chunk_state_fwd(batch,
m_idx = bx // T.ceildiv(dstate, block_N)
n_idx = bx % T.ceildiv(dstate, block_N)
T.annotate_layout({
x_shared: tilelang.layout.make_swizzled_layout(x_shared),
acc_o_shared: tilelang.layout.make_swizzled_layout(acc_o_shared)
})
T.annotate_layout(
{x_shared: tilelang.layout.make_swizzled_layout(x_shared), acc_o_shared: tilelang.layout.make_swizzled_layout(acc_o_shared)}
)
dA_cs_last[0] = dA_cumsum[batch_idx, bz, chunk_idx, chunk_size - 1]
T.clear(acc_o)
for k in T.Pipelined(loop_range, num_stages=num_stages):
T.copy(
x[batch_idx, chunk_idx * chunk_size + k * block_K:chunk_idx * chunk_size +
(k + 1) * block_K, bz, m_idx * block_M:(m_idx + 1) * block_M], x_shared)
T.copy(dA_cumsum[batch_idx, bz, chunk_idx, k * block_K:(k + 1) * block_K],
dA_cumsum_shared)
T.copy(dt[batch_idx, bz, chunk_idx, k * block_K:(k + 1) * block_K], dt_shared)
x[
batch_idx,
chunk_idx * chunk_size + k * block_K : chunk_idx * chunk_size + (k + 1) * block_K,
bz,
m_idx * block_M : (m_idx + 1) * block_M,
],
x_shared,
)
T.copy(dA_cumsum[batch_idx, bz, chunk_idx, k * block_K : (k + 1) * block_K], dA_cumsum_shared)
T.copy(dt[batch_idx, bz, chunk_idx, k * block_K : (k + 1) * block_K], dt_shared)
T.copy(dA_cumsum_shared, dA_cumsum_local)
T.copy(dt_shared, dt_local)
for i in T.Parallel(block_K):
......@@ -123,47 +115,50 @@ def chunk_state_fwd(batch,
for i, j in T.Parallel(block_M, block_K):
xt_local[i, j] = x_local[j, i] * scale[j]
T.copy(
B[batch_idx, chunk_idx * chunk_size + k * block_K:chunk_idx * chunk_size +
(k + 1) * block_K, bz // (nheads // ngroups),
n_idx * block_N:(n_idx + 1) * block_N], B_shared)
B[
batch_idx,
chunk_idx * chunk_size + k * block_K : chunk_idx * chunk_size + (k + 1) * block_K,
bz // (nheads // ngroups),
n_idx * block_N : (n_idx + 1) * block_N,
],
B_shared,
)
T.gemm(xt_local, B_shared, acc_o)
T.copy(acc_o, acc_o_shared)
T.copy(
acc_o_shared,
Output[batch_idx, chunk_idx, bz, m_idx * block_M:(m_idx + 1) * block_M,
n_idx * block_N:(n_idx + 1) * block_N])
Output[batch_idx, chunk_idx, bz, m_idx * block_M : (m_idx + 1) * block_M, n_idx * block_N : (n_idx + 1) * block_N],
)
return main
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--batch', type=int, default=8, help='batch size')
parser.add_argument('--heads', type=int, default=80, help='heads')
parser.add_argument('--groups', type=int, default=1, help='groups')
parser.add_argument('--seq_len', type=int, default=4096, help='sequence length')
parser.add_argument('--chunk_size', type=int, default=256, help='chunk size')
parser.add_argument('--dim', type=int, default=64, help='dim')
parser.add_argument('--dstate', type=int, default=128, help='dstate')
parser.add_argument('--tune', action='store_true', help='tune configs')
parser.add_argument("--batch", type=int, default=8, help="batch size")
parser.add_argument("--heads", type=int, default=80, help="heads")
parser.add_argument("--groups", type=int, default=1, help="groups")
parser.add_argument("--seq_len", type=int, default=4096, help="sequence length")
parser.add_argument("--chunk_size", type=int, default=256, help="chunk size")
parser.add_argument("--dim", type=int, default=64, help="dim")
parser.add_argument("--dstate", type=int, default=128, help="dstate")
parser.add_argument("--tune", action="store_true", help="tune configs")
args = parser.parse_args()
batch, heads, groups, seq_len, chunk_size, dim, dstate = args.batch, args.heads, args.groups, args.seq_len, args.chunk_size, args.dim, args.dstate
batch, heads, groups, seq_len, chunk_size, dim, dstate = (
args.batch,
args.heads,
args.groups,
args.seq_len,
args.chunk_size,
args.dim,
args.dstate,
)
total_flops = 2 * batch * seq_len * heads * dim * dstate
if (not args.tune):
if not args.tune:
kernel = chunk_state_fwd(
batch,
seq_len,
chunk_size,
groups,
heads,
dim,
dstate,
block_M=64,
block_N=128,
block_K=64,
num_stages=4,
threads=128)
batch, seq_len, chunk_size, groups, heads, dim, dstate, block_M=64, block_N=128, block_K=64, num_stages=4, threads=128
)
profiler = kernel.get_profiler(tilelang.TensorSupplyType.Normal)
profiler.assert_allclose(ref_program, rtol=0.01, atol=0.01)
print("All checks pass.")
......
......@@ -13,13 +13,12 @@ def chunk_retention_fwd_kernel(
H,
DK,
DV,
dtype: str = 'float16',
dtype: str = "float16",
scale: float = None,
) -> torch.Tensor:
if scale is None:
scale = DK**-0.5
accum_dtype = 'float'
accum_dtype = "float"
chunk_size = 64
BK = BV = 64 # Set to 128 can be faster, but has some numerical differences with FLA
......@@ -38,8 +37,8 @@ def chunk_retention_fwd_kernel(
with T.Kernel(NV, NK, B * H) as (i_v, i_k, i_bh):
i_b = i_bh // H
i_h = i_bh % H
log_decay = T.alloc_var('float32')
log_decay = T.log2(1 - T.exp2(-5. - 1. * i_h)) # Head-specific log decay
log_decay = T.alloc_var("float32")
log_decay = T.log2(1 - T.exp2(-5.0 - 1.0 * i_h)) # Head-specific log decay
q = T.alloc_shared([chunk_size, BK], dtype)
k = T.alloc_shared([chunk_size, BK], dtype)
......@@ -56,14 +55,12 @@ def chunk_retention_fwd_kernel(
for i in T.Pipelined(0, NT):
for row, col in T.Parallel(chunk_size, BK):
q[row, col] = Q[i_b, i * chunk_size + row, i_h, i_k * BK + col] * scale
T.copy(K[i_b, i * chunk_size:(i + 1) * chunk_size, i_h, i_k * BK:(i_k + 1) * BK], k)
T.copy(V[i_b, i * chunk_size:(i + 1) * chunk_size, i_h, i_v * BV:(i_v + 1) * BV], v)
T.copy(K[i_b, i * chunk_size : (i + 1) * chunk_size, i_h, i_k * BK : (i_k + 1) * BK], k)
T.copy(V[i_b, i * chunk_size : (i + 1) * chunk_size, i_h, i_v * BV : (i_v + 1) * BV], v)
T.gemm(q, k, s, clear_accum=True, transpose_B=True)
for row, col in T.Parallel(chunk_size, chunk_size):
s_shared[row,
col] = T.if_then_else(row >= col, s[row, col] * T.exp2(
(row - col) * log_decay), 0)
s_shared[row, col] = T.if_then_else(row >= col, s[row, col] * T.exp2((row - col) * log_decay), 0)
T.copy(h, h_shared)
T.gemm(q, h_shared, o, clear_accum=True)
......@@ -75,9 +72,7 @@ def chunk_retention_fwd_kernel(
v[row, col] = v[row, col] * T.exp2((chunk_size - row - 1) * log_decay)
for row, col in T.Parallel(BK, BV):
h[row, col] = T.exp2(chunk_size * log_decay) * h[row, col]
T.copy(
o, O[i_k, i_b, i * chunk_size:(i + 1) * chunk_size, i_h,
i_v * BV:(i_v + 1) * BV])
T.copy(o, O[i_k, i_b, i * chunk_size : (i + 1) * chunk_size, i_h, i_v * BV : (i_v + 1) * BV])
T.gemm(k, v, h, transpose_A=True)
return chunk_retention_fwd
......@@ -89,24 +84,24 @@ def postprocess(o):
def main():
parser = argparse.ArgumentParser()
parser.add_argument('--B', type=int, default=8, help='Batch size')
parser.add_argument('--S', type=int, default=4096, help='Seq len')
parser.add_argument('--H', type=int, default=32, help='Num heads')
parser.add_argument('--D', type=int, default=128, help='Head dim')
parser.add_argument("--B", type=int, default=8, help="Batch size")
parser.add_argument("--S", type=int, default=4096, help="Seq len")
parser.add_argument("--H", type=int, default=32, help="Num heads")
parser.add_argument("--D", type=int, default=128, help="Head dim")
args = parser.parse_args()
B, S, H, D = args.B, args.S, args.H, args.D
total_flops = 2.0 * B * S * S * H * D # causal
q = torch.randn((B, S, H, D), device='cuda', dtype=torch.float16)
k = torch.randn((B, S, H, D), device='cuda', dtype=torch.float16)
v = torch.randn((B, S, H, D), device='cuda', dtype=torch.float16)
q = torch.randn((B, S, H, D), device="cuda", dtype=torch.float16)
k = torch.randn((B, S, H, D), device="cuda", dtype=torch.float16)
v = torch.randn((B, S, H, D), device="cuda", dtype=torch.float16)
kernel = chunk_retention_fwd_kernel(B, S, H, D, D)
t = do_bench(lambda: postprocess(kernel(q, k, v)), warmup=25, rep=100)
print(f'Tilelang latency: {t:.3f} ms')
print(f'Tilelang TFLOPs: {total_flops/t * 1e-9}')
print(f"Tilelang latency: {t:.3f} ms")
print(f"Tilelang TFLOPs: {total_flops / t * 1e-9}")
if __name__ == '__main__':
if __name__ == "__main__":
main()
......@@ -15,12 +15,11 @@ from tilelang.profiler import do_bench
@tilelang.jit(out_idx=[3])
def _tl_vs_sparse_flashattn(batch, heads, seq_len, dim, vertical_size, slash_size):
block_M = 64
block_N = 64
num_stages = 2
threads = 128
scale = (1.0 / dim)**0.5 * 1.44269504
scale = (1.0 / dim) ** 0.5 * 1.44269504
shape = [batch, heads, seq_len, dim]
seq_blocks = (seq_len + block_M - 1) // block_M
......@@ -30,15 +29,13 @@ def _tl_vs_sparse_flashattn(batch, heads, seq_len, dim, vertical_size, slash_siz
offset_shape = count_shape + [slash_size]
index_shape = count_shape + [vertical_size]
vertical_size_round, slash_size_round = tilelang.next_power_of_2(
vertical_size), tilelang.next_power_of_2(slash_size)
vertical_size_round, slash_size_round = tilelang.next_power_of_2(vertical_size), tilelang.next_power_of_2(slash_size)
dtype = "float16"
accum_dtype = "float"
int_dtype = "int32"
def kernel_func(block_M, block_N, num_stages, threads):
@T.macro
def Prefetch(
K: T.Tensor(shape, dtype),
......@@ -53,13 +50,11 @@ def _tl_vs_sparse_flashattn(batch, heads, seq_len, dim, vertical_size, slash_siz
):
with T.attr("default", "async_scope", 1):
for i, j in T.Parallel(block_N, dim):
K_shared[i, j] = T.if_then_else(k + i < column_count,
K[bz, by, column_index[k + i], j], 0)
K_shared[i, j] = T.if_then_else(k + i < column_count, K[bz, by, column_index[k + i], j], 0)
with T.attr("default", "async_scope", 1):
for i, j in T.Parallel(block_N, dim):
V_shared[i, j] = T.if_then_else(k + i < column_count,
V[bz, by, column_index[k + i], j], 0)
V_shared[i, j] = T.if_then_else(k + i < column_count, V[bz, by, column_index[k + i], j], 0)
T.ptx_commit_group()
......@@ -118,7 +113,6 @@ def _tl_vs_sparse_flashattn(batch, heads, seq_len, dim, vertical_size, slash_siz
ColumnIndex: T.Tensor(index_shape, int_dtype),
):
with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=256) as (bc, by, bz):
bx = T.ceildiv(seq_len, block_M) - 1 - bc
Q_shared = T.alloc_shared([block_M, dim], dtype)
K_shared = T.alloc_shared([2, block_N, dim], dtype)
......@@ -143,9 +137,11 @@ def _tl_vs_sparse_flashattn(batch, heads, seq_len, dim, vertical_size, slash_siz
T.create_list_of_mbarrier([128] * 9)
T.annotate_layout({
T.annotate_layout(
{
O_shared: tilelang.layout.make_swizzled_layout(O_shared),
})
}
)
block_count[0] = BlockCount[bz, by, bx]
column_count[0] = ColumnCount[bz, by, bx]
......@@ -162,15 +158,15 @@ def _tl_vs_sparse_flashattn(batch, heads, seq_len, dim, vertical_size, slash_siz
if tid >= 128:
T.annotate_producer_reg_dealloc()
T.copy(Q[bz, by, bx * block_M:(bx + 1) * block_M, :], Q_shared)
T.copy(Q[bz, by, bx * block_M : (bx + 1) * block_M, :], Q_shared)
T.mbarrier_arrive(mbarrier=8)
for bi in T.serial(block_count[0]):
k = block_offset[bi]
T.mbarrier_wait_parity(mbarrier=bi % 2 + 4, parity=(((bi & 3) >> 1) ^ 1))
T.copy(K[bz, by, k:k + block_N, :], K_shared[bi % 2, :, :])
T.copy(K[bz, by, k : k + block_N, :], K_shared[bi % 2, :, :])
T.mbarrier_arrive(mbarrier=bi % 2)
T.mbarrier_wait_parity(mbarrier=bi % 2 + 6, parity=(((bi & 3) >> 1) ^ 1))
T.copy(V[bz, by, k:k + block_N, :], V_shared[bi % 2, :, :])
T.copy(V[bz, by, k : k + block_N, :], V_shared[bi % 2, :, :])
T.mbarrier_arrive(mbarrier=bi % 2 + 2)
else:
T.annotate_consumer_reg_alloc()
......@@ -181,16 +177,10 @@ def _tl_vs_sparse_flashattn(batch, heads, seq_len, dim, vertical_size, slash_siz
for bi in T.serial(block_count[0]):
k = block_offset[bi]
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else(bx * block_M + i >= k + j, 0,
-T.infinity(acc_s.dtype))
acc_s[i, j] = T.if_then_else(bx * block_M + i >= k + j, 0, -T.infinity(acc_s.dtype))
T.mbarrier_wait_parity(mbarrier=bi % 2, parity=((bi & 3) >> 1))
T.gemm(
Q_shared,
K_shared[bi % 2, :, :],
acc_s,
transpose_B=True,
policy=T.GemmWarpPolicy.FullRow)
T.gemm(Q_shared, K_shared[bi % 2, :, :], acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
T.mbarrier_arrive(mbarrier=bi % 2 + 4)
T.copy(scores_max, scores_max_prev)
......@@ -200,20 +190,15 @@ def _tl_vs_sparse_flashattn(batch, heads, seq_len, dim, vertical_size, slash_siz
scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
for i in T.Parallel(block_M):
scores_scale[i] = T.exp2(scores_max_prev[i] * scale -
scores_max[i] * scale)
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] = acc_o[i, j] * scores_scale[i]
T.copy(acc_s, acc_s_cast)
T.mbarrier_wait_parity(mbarrier=bi % 2 + 2, parity=(((bi & 3) >> 1)))
T.gemm(
acc_s_cast,
V_shared[bi % 2, :, :],
acc_o,
policy=T.GemmWarpPolicy.FullRow)
T.mbarrier_wait_parity(mbarrier=bi % 2 + 2, parity=((bi & 3) >> 1))
T.gemm(acc_s_cast, V_shared[bi % 2, :, :], acc_o, policy=T.GemmWarpPolicy.FullRow)
T.mbarrier_arrive(mbarrier=bi % 2 + 6)
......@@ -223,38 +208,85 @@ def _tl_vs_sparse_flashattn(batch, heads, seq_len, dim, vertical_size, slash_siz
logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
if column_count[0] != 0:
Prefetch(K, V, K_shared_1, V_shared_1, column_index, column_count[0], 0, bz,
by)
Prefetch(K, V, K_shared_1, V_shared_1, column_index, column_count[0], 0, bz, by)
for bi in T.serial(T.ceildiv(column_count[0], block_N) - 1):
k = bi * block_N
if bi % 2 == 0:
Prefetch(K, V, K_shared_2, V_shared_2, column_index,
column_count[0], k + block_N, bz, by)
Prefetch(K, V, K_shared_2, V_shared_2, column_index, column_count[0], k + block_N, bz, by)
Compute(acc_s, acc_s_cast, acc_o, scores_max, scores_max_prev, k,
column_count[0], Q_shared, K_shared_1, V_shared_1,
scores_scale, scores_sum, logsum, 1)
Compute(
acc_s,
acc_s_cast,
acc_o,
scores_max,
scores_max_prev,
k,
column_count[0],
Q_shared,
K_shared_1,
V_shared_1,
scores_scale,
scores_sum,
logsum,
1,
)
else:
Prefetch(K, V, K_shared_1, V_shared_1, column_index,
column_count[0], k + block_N, bz, by)
Prefetch(K, V, K_shared_1, V_shared_1, column_index, column_count[0], k + block_N, bz, by)
Compute(acc_s, acc_s_cast, acc_o, scores_max, scores_max_prev, k,
column_count[0], Q_shared, K_shared_2, V_shared_2,
scores_scale, scores_sum, logsum, 1)
Compute(
acc_s,
acc_s_cast,
acc_o,
scores_max,
scores_max_prev,
k,
column_count[0],
Q_shared,
K_shared_2,
V_shared_2,
scores_scale,
scores_sum,
logsum,
1,
)
if T.ceildiv(column_count[0], block_N) % 2 == 0:
Compute(acc_s, acc_s_cast, acc_o, scores_max, scores_max_prev,
Compute(
acc_s,
acc_s_cast,
acc_o,
scores_max,
scores_max_prev,
T.ceildiv(column_count[0], block_N) * block_N - block_N,
column_count[0], Q_shared, K_shared_2, V_shared_2, scores_scale,
scores_sum, logsum, 0)
column_count[0],
Q_shared,
K_shared_2,
V_shared_2,
scores_scale,
scores_sum,
logsum,
0,
)
else:
Compute(acc_s, acc_s_cast, acc_o, scores_max, scores_max_prev,
Compute(
acc_s,
acc_s_cast,
acc_o,
scores_max,
scores_max_prev,
T.ceildiv(column_count[0], block_N) * block_N - block_N,
column_count[0], Q_shared, K_shared_1, V_shared_1, scores_scale,
scores_sum, logsum, 0)
column_count[0],
Q_shared,
K_shared_1,
V_shared_1,
scores_scale,
scores_sum,
logsum,
0,
)
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] /= logsum[i]
T.copy(acc_o, O_shared)
T.copy(O_shared, Output[bz, by, bx * block_M:(bx + 1) * block_M, :])
T.copy(O_shared, Output[bz, by, bx * block_M : (bx + 1) * block_M, :])
return vs_sparse_flashattn_ws
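For readers decoding the double-buffer synchronization above: the parity argument `(bi & 3) >> 1` flips each time a given shared-memory slot (`bi % 2`) is reused, which is the convention `mbarrier_wait_parity` expects for a 2-deep ring buffer. A minimal pure-Python sketch (illustrative only, not part of this change):

# slot/parity schedule for a 2-slot ring buffer
for bi in range(8):
    slot = bi % 2
    parity = (bi & 3) >> 1
    print(bi, slot, parity)
# bi:     0 1 2 3 4 5 6 7
# slot:   0 1 0 1 0 1 0 1
# parity: 0 0 1 1 0 0 1 1   -> each slot's parity toggles on every reuse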
......@@ -470,11 +502,8 @@ def vertical_slash_sparse_attention(
import os
current_dir = os.path.dirname(os.path.abspath(__file__))
sources = [
os.path.join(current_dir, 'ops', 'kernels.cpp'),
os.path.join(current_dir, 'ops', 'vertical_slash_index.cu')
]
ops = load(name='convert', sources=sources, verbose=False)
sources = [os.path.join(current_dir, "ops", "kernels.cpp"), os.path.join(current_dir, "ops", "vertical_slash_index.cu")]
ops = load(name="convert", sources=sources, verbose=False)
convert_vertical_slash_indexes = ops.convert_vertical_slash_indexes
batch_size, num_heads, context_size, head_dim = query.shape
pad = (block_size_M - context_size) & (block_size_M - 1)
......@@ -485,15 +514,13 @@ def vertical_slash_sparse_attention(
value = torch.nn.functional.pad(value, [0, 0, 0, pad, 0, 0, 0, 0])
if head_dim not in [16, 32, 64, 128, 256, 512]:
target_dim = 2**math.ceil(math.log2(head_dim)) - head_dim
target_dim = 2 ** math.ceil(math.log2(head_dim)) - head_dim
query = torch.nn.functional.pad(query, [0, target_dim, 0, 0, 0, 0, 0, 0])
key = torch.nn.functional.pad(key, [0, target_dim, 0, 0, 0, 0, 0, 0])
value = torch.nn.functional.pad(value, [0, target_dim, 0, 0, 0, 0, 0, 0])
v_idx = v_idx.to(torch.int32).reshape((batch_size, num_heads, -1)).sort(
dim=-1, descending=False)[0]
s_idx = s_idx.to(torch.int32).reshape((batch_size, num_heads, -1)).sort(
dim=-1, descending=True)[0]
v_idx = v_idx.to(torch.int32).reshape((batch_size, num_heads, -1)).sort(dim=-1, descending=False)[0]
s_idx = s_idx.to(torch.int32).reshape((batch_size, num_heads, -1)).sort(dim=-1, descending=True)[0]
seqlens = torch.tensor([context_size] * query.shape[0], dtype=torch.int32, device=query.device)
sm_scale = head_dim**-0.5
......@@ -506,8 +533,7 @@ def vertical_slash_sparse_attention(
block_size_N,
)
tl_kernel = _tl_vs_sparse_flashattn(batch_size, num_heads, context_size, head_dim,
v_idx.shape[2], s_idx.shape[2])
tl_kernel = _tl_vs_sparse_flashattn(batch_size, num_heads, context_size, head_dim, v_idx.shape[2], s_idx.shape[2])
def run(is_triton: bool = True):
if is_triton:
......@@ -525,8 +551,7 @@ def vertical_slash_sparse_attention(
block_size_N,
)
else:
out = tl_kernel(query, key, value, block_count, block_offset, column_count,
column_index)
out = tl_kernel(query, key, value, block_count, block_offset, column_count, column_index)
return out[..., :context_size, :head_dim]
return run
......@@ -536,8 +561,7 @@ def sum_all_diagonal_matrix(mat: torch.tensor):
b, h, n, m = mat.shape
zero_mat = torch.zeros((b, h, n, n)).to(mat.device) # Zero matrix used for padding
mat_padded = torch.cat((zero_mat, mat, zero_mat), -1) # pads the matrix on left and right
mat_strided = mat_padded.as_strided(
(1, 1, n, n + m), (1, n * (2 * n + m), 2 * n + m + 1, 1)) # Change the strides
mat_strided = mat_padded.as_strided((1, 1, n, n + m), (1, n * (2 * n + m), 2 * n + m + 1, 1)) # Change the strides
sum_diags = torch.sum(mat_strided, 2) # Sums the resulting matrix's columns
return sum_diags[:, :, 1:]
......@@ -559,24 +583,23 @@ def main(argv=None):
vertical_size, slash_size = args.vertical_size, args.slash_size
torch.manual_seed(0)
q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16)
k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16)
v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16)
q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16)
k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16)
v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16)
q_len = SEQ_LEN
vertical_size, slash_size = min(q_len, vertical_size), min(q_len, slash_size)
last_q = 64
qk = torch.einsum('bhmk, bhnk -> bhmn', q[:, :, -last_q:, :], k)
qk = torch.einsum("bhmk, bhnk -> bhmn", q[:, :, -last_q:, :], k)
arange = torch.arange(last_q, device="cuda")
qk[:, :, :, -last_q:] = torch.where(arange[None, None, :, None] >= arange[None, None, None, :],
qk[:, :, :, -last_q:], -torch.inf)
qk[:, :, :, -last_q:] = torch.where(arange[None, None, :, None] >= arange[None, None, None, :], qk[:, :, :, -last_q:], -torch.inf)
qk = torch.nn.functional.softmax(qk, dim=-1, dtype=torch.float32)
vertical = qk.sum(-2, keepdim=True)
vertical[..., :30] = torch.inf
vertical_topk = torch.topk(vertical, vertical_size, -1).indices
slash = sum_all_diagonal_matrix(qk)[..., :-last_q + 1]
slash = sum_all_diagonal_matrix(qk)[..., : -last_q + 1]
slash[..., -30:] = torch.inf
slash = (q_len - 1) - torch.topk(slash, slash_size, -1).indices
......
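The `scores_max` / `scores_scale` / `logsum` updates in the hunks above implement the standard online-softmax rescale. A self-contained PyTorch sketch (illustrative only, two hypothetical score chunks; `online_softmax` is not part of this change) showing the streamed result matches a full softmax-weighted sum:

import torch

def online_softmax(chunks, v_chunks):
    # Running max (m), running denominator (l), and running output (acc),
    # rescaled by exp(m_old - m_new) whenever the max grows, as in the kernel.
    rows = chunks[0].shape[0]
    m = torch.full((rows,), float("-inf"))
    l = torch.zeros(rows)
    acc = torch.zeros(rows, v_chunks[0].shape[1])
    for s, v in zip(chunks, v_chunks):
        m_new = torch.maximum(m, s.max(dim=1).values)
        scale = torch.exp(m - m_new)
        p = torch.exp(s - m_new[:, None])
        l = l * scale + p.sum(dim=1)
        acc = acc * scale[:, None] + p @ v
        m = m_new
    return acc / l[:, None]

s = torch.randn(4, 128)
v = torch.randn(128, 16)
ref = torch.softmax(s, dim=1) @ v
out = online_softmax([s[:, :64], s[:, 64:]], [v[:64], v[64:]])
torch.testing.assert_close(out, ref, atol=1e-5, rtol=1e-5)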
......@@ -45,7 +45,7 @@ def rms_norm(M, N, blk_m):
A_local = T.alloc_fragment((blk_m, N), dtype)
A_powsum = T.alloc_fragment((blk_m,), dtype)
T.copy(A[bx * blk_m:(bx + 1) * blk_m, :], A_shared)
T.copy(A[bx * blk_m : (bx + 1) * blk_m, :], A_shared)
T.copy(A_shared, A_local)
for i, j in T.Parallel(blk_m, N):
A_pow_local[i, j] = A_local[i, j] * A_local[i, j]
......@@ -54,7 +54,7 @@ def rms_norm(M, N, blk_m):
A_powsum[i] = T.rsqrt(A_powsum[i] / N + 1e-12)
for i, j in T.Parallel(blk_m, N):
A_local[i, j] *= A_powsum[i]
T.copy(A_local, B[bx * blk_m:(bx + 1) * blk_m, :])
T.copy(A_local, B[bx * blk_m : (bx + 1) * blk_m, :])
return main
......
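For reference, the per-row math of the rms_norm kernel above reduces to x * rsqrt(mean(x**2) + 1e-12). A plain-PyTorch sketch (illustrative only; `rms_norm_ref` is a hypothetical helper, not part of this change):

import torch

def rms_norm_ref(A: torch.Tensor) -> torch.Tensor:
    # Row-wise RMS normalization with the same 1e-12 epsilon as the kernel.
    eps = 1e-12
    return A * torch.rsqrt(A.pow(2).mean(dim=-1, keepdim=True) + eps)

A = torch.randn(256, 512)
# B_ref = rms_norm_ref(A)  # compare against the TileLang kernel's output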
......@@ -33,7 +33,7 @@ def softmax_kernel(
T.fill(lse, -T.infinity(accum_dtype))
for i_n in T.Pipelined(0, NN):
T.copy(X[i_m, i_n * BN:(i_n + 1) * BN], x)
T.copy(X[i_m, i_n * BN : (i_n + 1) * BN], x)
T.reduce_max(x, max_x, dim=0, clear=True)
......@@ -45,12 +45,12 @@ def softmax_kernel(
lse[0] = max_x[0] * scale + T.log2(T.exp2(lse[0] - max_x[0] * scale) + sum_exp_x[0])
for i_n in T.Pipelined(0, NN):
T.copy(X[i_m, i_n * BN:(i_n + 1) * BN], x)
T.copy(X[i_m, i_n * BN : (i_n + 1) * BN], x)
for j in T.Parallel(BN):
y[j] = T.exp2(x[j] * scale - lse[0])
T.copy(y, Y[i_m, i_n * BN:(i_n + 1) * BN])
T.copy(y, Y[i_m, i_n * BN : (i_n + 1) * BN])
return main
......@@ -69,4 +69,4 @@ t1 = do_bench(lambda: X.softmax(dim=1), warmup=25, rep=100)
t2 = do_bench(lambda: kernel(X), warmup=25, rep=100)
print(f"torch latency: {t1:.3f} ms")
print(f"TileLang latency: {t2:.3f} ms")
print(f"Speedup: {t1/t2:.3f}x")
print(f"Speedup: {t1 / t2:.3f}x")
......@@ -11,10 +11,9 @@ from tilelang.intrinsics.mfma_layout import (
)
def make_mfma_load_base_layout(dtype: str = "float16",
matrix: Literal["A", "B"] = "A",
k_dim: int = 16,
transposed: bool = False) -> T.Fragment:
def make_mfma_load_base_layout(
dtype: str = "float16", matrix: Literal["A", "B"] = "A", k_dim: int = 16, transposed: bool = False
) -> T.Fragment:
"""
Create a layout function for storing MFMA results into a fragment buffer.
This layout is used in conjunction with `inverse_mfma_store_layout` to
......@@ -72,12 +71,10 @@ def make_mfma_load_base_layout(dtype: str = "float16",
# so the b matrix expected a transposed basic layout
transform_func: Callable = None
if matrix == "A":
transform_func = transform_func_sr_a if is_sr_axis_order else lambda i, j: transform_func_sr_a(
j, i)
transform_func = transform_func_sr_a if is_sr_axis_order else lambda i, j: transform_func_sr_a(j, i)
micro_size_s, micro_size_r = micro_size_x, micro_size_k
elif matrix == "B":
transform_func = transform_func_sr_b if is_sr_axis_order else lambda i, j: transform_func_sr_b(
j, i)
transform_func = transform_func_sr_b if is_sr_axis_order else lambda i, j: transform_func_sr_b(j, i)
micro_size_s, micro_size_r = micro_size_k, micro_size_y
else:
raise ValueError(f"Unsupported matrix {matrix}")
......@@ -120,14 +117,11 @@ print(base_layout)
plot_layout(base_layout, name="base_layout")
# warp layout 32x32
warp_layout = base_layout.repeat([warp_rows, warp_cols],
repeat_on_thread=False,
lower_dim_first=False)
warp_layout = base_layout.repeat([warp_rows, warp_cols], repeat_on_thread=False, lower_dim_first=False)
print(warp_layout)
plot_layout(warp_layout, name="warp_layout")
# block layout 64x32
block_layout = warp_layout.repeat([block_rows, 1], repeat_on_thread=True,
lower_dim_first=True).replicate(block_cols)
block_layout = warp_layout.repeat([block_rows, 1], repeat_on_thread=True, lower_dim_first=True).replicate(block_cols)
print(block_layout)
plot_layout(block_layout, name="block_layout")
......@@ -5,9 +5,7 @@ from tvm.tir import IndexMap
from tilelang.intrinsics.utils import get_mma_micro_size
def make_mma_load_base_layout(dtype: str = "float16",
matrix: Literal["A", "B"] = "A",
transposed: bool = False) -> T.Fragment:
def make_mma_load_base_layout(dtype: str = "float16", matrix: Literal["A", "B"] = "A", transposed: bool = False) -> T.Fragment:
"""
Create a layout function for storing MMA results into a fragment buffer.
This layout is used in conjunction with `inverse_mma_store_layout` to
......@@ -36,6 +34,7 @@ def make_mma_load_base_layout(dtype: str = "float16",
shared_16x16_to_mma_32x8_layout_sr_b,
shared_16x32_to_mma_32x16_layout_sr_b,
)
assert matrix in ["A", "B"], "matrix should be either A or B"
dtype_bits = DataType(dtype).bits
# s represents spatial axis
......@@ -67,12 +66,10 @@ def make_mma_load_base_layout(dtype: str = "float16",
# so the b matrix expected a transposed basic layout
transform_func: Callable = None
if matrix == "A":
transform_func = transform_func_sr_a if is_sr_axis_order else lambda i, j: transform_func_sr_a(
j, i)
transform_func = transform_func_sr_a if is_sr_axis_order else lambda i, j: transform_func_sr_a(j, i)
micro_size_s, micro_size_r = micro_size_x, micro_size_k
elif matrix == "B":
transform_func = transform_func_sr_b if is_sr_axis_order else lambda i, j: transform_func_sr_b(
j, i)
transform_func = transform_func_sr_b if is_sr_axis_order else lambda i, j: transform_func_sr_b(j, i)
micro_size_s, micro_size_r = micro_size_k, micro_size_y
else:
raise ValueError(f"Unsupported matrix {matrix}")
......
......@@ -7,7 +7,6 @@ import tilelang.language as T
# if not specified, it will be inferred from the input tensors during compile time
@tilelang.jit
def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
@T.prim_func
def matmul_relu_kernel(
A: T.Tensor((M, K), dtype),
......
......@@ -10,10 +10,7 @@ def get_sparse_attn_mask_from_topk(x, topk, use_dense_for_last_block=False):
bsz, num_head, downsample_len, _ = x.shape
# N_CTX = downsample_len * BLOCK
sparse_index = torch.topk(x, topk, dim=-1).indices
dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len],
False,
dtype=torch.bool,
device=x.device)
dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len], False, dtype=torch.bool, device=x.device)
dense_mask.scatter_(-1, sparse_index, True)
if use_dense_for_last_block:
dense_mask[:, :, -2:, :] = True
......@@ -30,15 +27,17 @@ def get_sparse_attn_mask_from_threshold(x, threshold, use_dense_for_last_block=F
@tilelang.jit(
out_idx=[4], pass_configs={
out_idx=[4],
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
})
},
)
def blocksparse_flashattn(batch, heads, seq_q, seq_kv, dim, downsample_len, is_causal):
block_M = 64
block_N = 64
num_stages = 0
threads = 128
scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e)
scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e)
q_shape = [batch, heads, seq_q, dim]
kv_shape = [batch, heads, seq_kv, dim]
block_mask_shape = [batch, heads, downsample_len, downsample_len]
......@@ -48,7 +47,6 @@ def blocksparse_flashattn(batch, heads, seq_q, seq_kv, dim, downsample_len, is_c
block_mask_dtype = "int8"
def kernel_func(block_M, block_N, num_stages, threads):
@T.macro
def Softmax(
acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype),
......@@ -112,7 +110,7 @@ def blocksparse_flashattn(batch, heads, seq_q, seq_kv, dim, downsample_len, is_c
logsum = T.alloc_fragment([block_M], accum_dtype)
block_mask = T.alloc_local([downsample_len], block_mask_dtype)
T.copy(Q[bz, by, bx * block_M:(bx + 1) * block_M, :], Q_shared)
T.copy(Q[bz, by, bx * block_M : (bx + 1) * block_M, :], Q_shared)
T.fill(acc_o, 0)
T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype))
......@@ -124,33 +122,25 @@ def blocksparse_flashattn(batch, heads, seq_q, seq_kv, dim, downsample_len, is_c
for k in T.Pipelined(loop_range, num_stages=num_stages):
if block_mask[k] != 0:
T.copy(K[bz, by, k * block_N:(k + 1) * block_N, :], K_shared)
T.copy(K[bz, by, k * block_N : (k + 1) * block_N, :], K_shared)
if is_causal:
past_len = seq_kv - seq_q
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else(
bx * block_M + i + past_len >= k * block_N + j, 0,
-T.infinity(acc_s.dtype))
acc_s[i, j] = T.if_then_else(bx * block_M + i + past_len >= k * block_N + j, 0, -T.infinity(acc_s.dtype))
else:
T.clear(acc_s)
T.gemm(
Q_shared,
K_shared,
acc_s,
transpose_B=True,
policy=T.GemmWarpPolicy.FullRow)
Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale,
scores_sum, logsum)
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, logsum)
Rescale(acc_o, scores_scale)
T.copy(V[bz, by, k * block_N:(k + 1) * block_N, :], V_shared)
T.copy(V[bz, by, k * block_N : (k + 1) * block_N, :], V_shared)
T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] /= logsum[i]
T.copy(acc_o, O_shared)
T.copy(O_shared, Output[bz, by, bx * block_M:(bx + 1) * block_M, :])
T.copy(O_shared, Output[bz, by, bx * block_M : (bx + 1) * block_M, :])
return main
......@@ -165,44 +155,40 @@ def test_topk_sparse_attention():
torch.manual_seed(0)
# Create inputs
q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16)
k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16)
v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16)
q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16)
k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16)
v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16)
sm_scale = 1.0 / (D_HEAD**0.5)
# Create sparse mask (downsampled to block level)
downsample_factor = BLOCK
downsample_len = math.ceil(SEQ_LEN / downsample_factor)
x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len],
device='cuda',
dtype=torch.float16)
x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len], device="cuda", dtype=torch.float16)
x_ds[:, :, :, 0] = 100
block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK)
# Run tilelang kernel
kernel = blocksparse_flashattn(
BATCH, N_HEADS, SEQ_LEN, SEQ_LEN, D_HEAD, downsample_len, is_causal=True)
kernel = blocksparse_flashattn(BATCH, N_HEADS, SEQ_LEN, SEQ_LEN, D_HEAD, downsample_len, is_causal=True)
tilelang_output = kernel(q, k, v, block_mask.to(torch.int8))
# Compute reference
# Expand block mask to full attention matrix
full_mask = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK, device='cuda'))
full_mask = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK, device="cuda"))
full_mask = full_mask[..., :SEQ_LEN, :SEQ_LEN].bool()
full_mask = full_mask & torch.tril(torch.ones_like(full_mask)) # Apply causal
# PyTorch reference implementation
attn = torch.einsum('bhsd,bhtd->bhst', q, k) * sm_scale
attn = attn.masked_fill(~full_mask, float('-inf'))
attn = torch.einsum("bhsd,bhtd->bhst", q, k) * sm_scale
attn = attn.masked_fill(~full_mask, float("-inf"))
attn = F.softmax(attn, dim=-1)
ref_output = torch.einsum('bhst,bhtd->bhsd', attn, v)
ref_output = torch.einsum("bhst,bhtd->bhsd", attn, v)
print("ref_output", ref_output)
print("tilelang_output", tilelang_output)
# Verify accuracy
assert torch.allclose(tilelang_output, ref_output, atol=1e-2, rtol=1e-2), \
"TileLang output doesn't match reference"
assert torch.allclose(tilelang_output, ref_output, atol=1e-2, rtol=1e-2), "TileLang output doesn't match reference"
print("Pass topk sparse attention test with qlen == klen")
......@@ -215,42 +201,40 @@ def test_topk_sparse_attention_qlen_lt_klen():
torch.manual_seed(0)
# Create inputs.
q = torch.randn(BATCH, N_HEADS, Q_LEN, D_HEAD, device='cuda', dtype=torch.float16)
k = torch.randn(BATCH, N_HEADS, K_LEN, D_HEAD, device='cuda', dtype=torch.float16)
v = torch.randn(BATCH, N_HEADS, K_LEN, D_HEAD, device='cuda', dtype=torch.float16)
q = torch.randn(BATCH, N_HEADS, Q_LEN, D_HEAD, device="cuda", dtype=torch.float16)
k = torch.randn(BATCH, N_HEADS, K_LEN, D_HEAD, device="cuda", dtype=torch.float16)
v = torch.randn(BATCH, N_HEADS, K_LEN, D_HEAD, device="cuda", dtype=torch.float16)
sm_scale = 1.0 / (D_HEAD**0.5)
downsample_factor = BLOCK
downsample_len = math.ceil(K_LEN / downsample_factor) # number of blocks along one dimension
x_ds = torch.randn(
BATCH, N_HEADS, downsample_len, downsample_len, device='cuda', dtype=torch.float16)
x_ds = torch.randn(BATCH, N_HEADS, downsample_len, downsample_len, device="cuda", dtype=torch.float16)
# Force the first column to be high so that the first block is always selected.
x_ds[:, :, :, 0] = 100
block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK)
kernel = blocksparse_flashattn(
BATCH, N_HEADS, Q_LEN, K_LEN, D_HEAD, downsample_len, is_causal=True)
kernel = blocksparse_flashattn(BATCH, N_HEADS, Q_LEN, K_LEN, D_HEAD, downsample_len, is_causal=True)
print(kernel.get_kernel_source())
tilelang_output = kernel(q, k, v, block_mask.to(torch.int8))
past_len = K_LEN - Q_LEN
attn = torch.einsum('bhsd,bhtd->bhst', q, k) * sm_scale
attn = torch.einsum("bhsd,bhtd->bhst", q, k) * sm_scale
full_mask_full = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK, device='cuda')).bool()
full_mask_full = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK, device="cuda")).bool()
full_mask_full = full_mask_full[..., :K_LEN, :K_LEN]
effective_mask = full_mask_full[..., past_len:K_LEN, :] # shape: (B, H, Q_LEN, K_LEN)
i_global = torch.arange(past_len, K_LEN, device=k.device).unsqueeze(1) # shape: (Q_LEN, 1)
j_global = torch.arange(K_LEN, device=k.device).unsqueeze(0) # shape: (1, K_LEN)
causal_mask = (j_global <= i_global) # shape: (Q_LEN, K_LEN)
causal_mask = j_global <= i_global # shape: (Q_LEN, K_LEN)
final_mask = effective_mask & causal_mask # shape: (B, H, Q_LEN, K_LEN)
attn = attn.masked_fill(~final_mask, float('-inf'))
attn = attn.masked_fill(~final_mask, float("-inf"))
attn = F.softmax(attn, dim=-1)
ref_output = torch.einsum('bhst,bhtd->bhsd', attn, v)
ref_output = torch.einsum("bhst,bhtd->bhsd", attn, v)
print("ref_output", ref_output)
print("tilelang_output", tilelang_output)
......
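About the `1.44269504` factor used in the kernels above: folding log2(e) into the softmax scale lets the kernel evaluate exp2 instead of exp. A quick check (illustrative only, hypothetical `scale` value):

import math
import torch

x = torch.randn(1024)
scale = 0.125  # e.g. 1/sqrt(dim) for a hypothetical dim = 64
log2e = math.log2(math.e)  # ~1.44269504, the constant baked into the kernels
torch.testing.assert_close(torch.exp2(x * scale * log2e), torch.exp(x * scale), atol=1e-5, rtol=1e-5)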
......@@ -15,10 +15,7 @@ def get_sparse_attn_mask_from_topk(x, topk, use_dense_for_last_block=False):
bsz, num_head, downsample_len, _ = x.shape
# N_CTX = downsample_len * BLOCK
sparse_index = torch.topk(x, topk, dim=-1).indices
dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len],
False,
dtype=torch.bool,
device=x.device)
dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len], False, dtype=torch.bool, device=x.device)
dense_mask.scatter_(-1, sparse_index, True)
if use_dense_for_last_block:
dense_mask[:, :, -2:, :] = True
......@@ -54,7 +51,6 @@ def _fwd_kernel_inner(
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
):
mask_val = tl.load(block_mask_ptr + k_block_col_idx * stride_bmask_n)
if mask_val == True:
......@@ -69,7 +65,7 @@ def _fwd_kernel_inner(
qk *= sm_scale
# the following is needed only when LAST_K_BLOCK or BLOCK_M < BLOCK_N
qk += tl.where(offs_m[:, None] + past_len >= (start_n + offs_n[None, :]), 0, float('-inf'))
qk += tl.where(offs_m[:, None] + past_len >= (start_n + offs_n[None, :]), 0, float("-inf"))
m_ij = tl.maximum(m_i, tl.max(qk, 1))
qk -= m_ij[:, None]
......@@ -149,7 +145,7 @@ def _fwd_kernel(
v_ptrs = V + off_v
mask_ptrs = block_mask_ptr + start_m * stride_bmm
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float('inf')
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
......@@ -185,24 +181,12 @@ def _fwd_kernel(
acc = acc * l_recip
acc = acc.to(Out.dtype.element_ty)
off_o = off_z * stride_oz + off_h * stride_oh + offs_m[:, None] * stride_om + offs_d[
None, :] * stride_od
off_o = off_z * stride_oz + off_h * stride_oh + offs_m[:, None] * stride_om + offs_d[None, :] * stride_od
out_ptrs = Out + off_o
tl.store(out_ptrs, acc, mask=offs_m[:, None] < N_CTX)
def _forward(ctx,
q,
k,
v,
block_sparse_mask,
sm_scale,
BLOCK_M=64,
BLOCK_N=64,
num_warps=None,
num_stages=1,
out=None):
def _forward(ctx, q, k, v, block_sparse_mask, sm_scale, BLOCK_M=64, BLOCK_N=64, num_warps=None, num_stages=1, out=None):
assert q.shape[-1] == k.shape[-1] == v.shape[-1]
assert k.shape[2] == v.shape[2]
o = out if out is not None else torch.empty_like(q).contiguous()
......@@ -247,7 +231,6 @@ def _forward(ctx,
class _sparse_attention(torch.autograd.Function):
@staticmethod
def forward(ctx, q, k, v, block_sparse_dense, sm_scale):
# shape constraints
......@@ -271,9 +254,9 @@ def test_topk_sparse_attention():
torch.manual_seed(0)
# Create inputs
q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.bfloat16)
k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.bfloat16)
v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.bfloat16)
q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.bfloat16)
k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.bfloat16)
v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.bfloat16)
sm_scale = 1.0 / (D_HEAD**0.5)
# Create sparse mask (downsampled to block level)
......@@ -281,9 +264,7 @@ def test_topk_sparse_attention():
downsample_len = math.ceil(SEQ_LEN / downsample_factor)
print("downsample_len", downsample_len)
x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len],
device='cuda',
dtype=torch.bfloat16)
x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len], device="cuda", dtype=torch.bfloat16)
x_ds[:, :, :, 0] = 100
print("x_ds.shape", x_ds.shape)
block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK)
......@@ -295,22 +276,21 @@ def test_topk_sparse_attention():
# Compute reference
# Expand block mask to full attention matrix
full_mask = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK, device='cuda'))
full_mask = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK, device="cuda"))
full_mask = full_mask[..., :SEQ_LEN, :SEQ_LEN].bool()
full_mask = full_mask & torch.tril(torch.ones_like(full_mask)) # Apply causal
# PyTorch reference implementation
attn = torch.einsum('bhsd,bhtd->bhst', q, k) * sm_scale
attn = attn.masked_fill(~full_mask, float('-inf'))
attn = torch.einsum("bhsd,bhtd->bhst", q, k) * sm_scale
attn = attn.masked_fill(~full_mask, float("-inf"))
attn = F.softmax(attn, dim=-1)
ref_output = torch.einsum('bhst,bhtd->bhsd', attn, v)
ref_output = torch.einsum("bhst,bhtd->bhsd", attn, v)
# print("ref_output", ref_output)
# print("triton_output", triton_output)
# Verify accuracy
assert torch.allclose(triton_output, ref_output, atol=1e-2, rtol=1e-2), \
"Triton output doesn't match reference"
assert torch.allclose(triton_output, ref_output, atol=1e-2, rtol=1e-2), "Triton output doesn't match reference"
print("Pass topk sparse attention test with qlen == klen")
......@@ -322,16 +302,15 @@ def test_topk_sparse_attention_qlt_kl():
torch.manual_seed(0)
# Create inputs.
q = torch.randn(BATCH, N_HEADS, Q_LEN, D_HEAD, device='cuda', dtype=torch.bfloat16)
k = torch.randn(BATCH, N_HEADS, K_LEN, D_HEAD, device='cuda', dtype=torch.bfloat16)
v = torch.randn(BATCH, N_HEADS, K_LEN, D_HEAD, device='cuda', dtype=torch.bfloat16)
q = torch.randn(BATCH, N_HEADS, Q_LEN, D_HEAD, device="cuda", dtype=torch.bfloat16)
k = torch.randn(BATCH, N_HEADS, K_LEN, D_HEAD, device="cuda", dtype=torch.bfloat16)
v = torch.randn(BATCH, N_HEADS, K_LEN, D_HEAD, device="cuda", dtype=torch.bfloat16)
# softmax scale
sm_scale = 1.0 / (D_HEAD**0.5)
downsample_factor = BLOCK
downsample_len = math.ceil(K_LEN / downsample_factor) # number of blocks along one dimension
x_ds = torch.randn(
BATCH, N_HEADS, downsample_len, downsample_len, device='cuda', dtype=torch.bfloat16)
x_ds = torch.randn(BATCH, N_HEADS, downsample_len, downsample_len, device="cuda", dtype=torch.bfloat16)
# Force the first column to be high so that the first block is always selected.
x_ds[:, :, :, 0] = 100
block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK)
......@@ -340,26 +319,25 @@ def test_topk_sparse_attention_qlt_kl():
past_len = K_LEN - Q_LEN
attn = torch.einsum('bhsd,bhtd->bhst', q, k) * sm_scale
attn = torch.einsum("bhsd,bhtd->bhst", q, k) * sm_scale
full_mask_full = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK, device='cuda')).bool()
full_mask_full = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK, device="cuda")).bool()
full_mask_full = full_mask_full[..., :K_LEN, :K_LEN]
effective_mask = full_mask_full[..., past_len:K_LEN, :] # shape: (B, H, Q_LEN, K_LEN)
i_global = torch.arange(past_len, K_LEN, device=k.device).unsqueeze(1) # shape: (Q_LEN, 1)
j_global = torch.arange(K_LEN, device=k.device).unsqueeze(0) # shape: (1, K_LEN)
causal_mask = (j_global <= i_global) # shape: (Q_LEN, K_LEN)
causal_mask = j_global <= i_global # shape: (Q_LEN, K_LEN)
final_mask = effective_mask & causal_mask # shape: (B, H, Q_LEN, K_LEN)
attn = attn.masked_fill(~final_mask, float('-inf'))
attn = attn.masked_fill(~final_mask, float("-inf"))
attn = F.softmax(attn, dim=-1)
ref_output = torch.einsum('bhst,bhtd->bhsd', attn, v)
ref_output = torch.einsum("bhst,bhtd->bhsd", attn, v)
# Verify accuracy.
assert torch.allclose(triton_output, ref_output, atol=1e-2, rtol=1e-2), \
"Triton output doesn't match reference when qlen < klen"
assert torch.allclose(triton_output, ref_output, atol=1e-2, rtol=1e-2), "Triton output doesn't match reference when qlen < klen"
print("Pass topk sparse attention test with qlen < klen")
......
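The `i_global` / `j_global` mask built in the reference above is exactly the last Q_LEN rows of a full lower-triangular mask over K_LEN positions. A quick equivalence check (illustrative only, hypothetical sizes):

import torch

Q_LEN, K_LEN = 64, 256                                   # hypothetical sizes
past_len = K_LEN - Q_LEN
i_global = torch.arange(past_len, K_LEN).unsqueeze(1)    # absolute query positions
j_global = torch.arange(K_LEN).unsqueeze(0)              # key positions
causal_mask = j_global <= i_global                       # (Q_LEN, K_LEN)
full_tril = torch.tril(torch.ones(K_LEN, K_LEN, dtype=torch.bool))
assert torch.equal(causal_mask, full_tril[past_len:])    # same as slicing the full causal mask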
......@@ -29,23 +29,21 @@ def matmul_sp(
@T.prim_func
def main(
A_sparse: T.Tensor(A_sparse_shape, in_dtype),
E: T.Tensor((M, K // 8), 'uint8'),
E: T.Tensor((M, K // 8), "uint8"),
B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype)
B_shared = T.alloc_shared(B_shared_shape, in_dtype)
E_shared = T.alloc_shared((block_M, block_K // 8), 'uint8')
E_shared = T.alloc_shared((block_M, block_K // 8), "uint8")
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
T.annotate_layout({
E:
make_cutlass_metadata_layout(
E, mma_dtype="float16", arch="9.0", block_k=block_K),
E_shared:
make_cutlass_metadata_layout(
E_shared, mma_dtype="float16", arch="9.0", block_k=block_K),
})
T.annotate_layout(
{
E: make_cutlass_metadata_layout(E, mma_dtype="float16", arch="9.0", block_k=block_K),
E_shared: make_cutlass_metadata_layout(E_shared, mma_dtype="float16", arch="9.0", block_k=block_K),
}
)
T.clear(C_local)
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
T.copy(E[by * block_M, k * block_K // 8], E_shared)
......@@ -57,7 +55,7 @@ def matmul_sp(
return main
def generate_2_to_4_sparse_tensor(shape, dtype=torch.float32, device='cpu'):
def generate_2_to_4_sparse_tensor(shape, dtype=torch.float32, device="cpu"):
if shape[-1] % 4 != 0:
raise ValueError("Last dimension must be divisible by 4 for 2:4 sparsity.")
......@@ -102,9 +100,9 @@ def run_gemm_sp(
num_threads,
)
A = generate_2_to_4_sparse_tensor((M, K), dtype=torch.float16, device='cuda')
A = generate_2_to_4_sparse_tensor((M, K), dtype=torch.float16, device="cuda")
A_sparse, E = compress_sm90(A, block_k=block_K, transposed=False)
B = torch.randn((K, N), device='cuda', dtype=torch.float16)
B = torch.randn((K, N), device="cuda", dtype=torch.float16)
C_sp = kernel(A_sparse, E, B).half()
C = torch.matmul(A, B)
......
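`generate_2_to_4_sparse_tensor` (its body is elided in this diff) must emit the 2:4 pattern expected by `compress_sm90`. One possible construction, sketched in plain PyTorch (illustrative only; `sketch_2_to_4_sparse` is a hypothetical helper, not the function's actual implementation):

import torch

def sketch_2_to_4_sparse(shape, dtype=torch.float16, device="cuda"):
    # Keep the 2 largest-magnitude values in every contiguous group of 4 along the
    # last dim and zero the rest, so each group satisfies the 2:4 constraint.
    assert shape[-1] % 4 == 0, "last dimension must be divisible by 4"
    dense = torch.randn(shape, dtype=dtype, device=device)
    groups = dense.reshape(*shape[:-1], shape[-1] // 4, 4)
    keep = groups.abs().topk(2, dim=-1).indices
    mask = torch.zeros_like(groups, dtype=torch.bool).scatter_(-1, keep, True)
    return (groups * mask).reshape(shape)

# A = sketch_2_to_4_sparse((1024, 1024))  # then compress with compress_sm90 as above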
......@@ -43,15 +43,12 @@ def tl_topk(
T.reduce_max(logits_frag, max_val, dim=1, clear=True)
for i, j in T.Parallel(blk_m, N):
expand_max_idx[i, j] = T.if_then_else(max_val[i] == logits_frag[i, j], j,
expand_max_idx[i, j])
expand_max_idx[i, j] = T.if_then_else(max_val[i] == logits_frag[i, j], j, expand_max_idx[i, j])
T.reduce_max(expand_max_idx, max_idx, dim=1, clear=True)
for i, j in T.Parallel(blk_m, N):
logits_frag[i, j] = T.if_then_else(max_val[i] == logits_frag[i, j], -10000.0,
logits_frag[i, j])
logits_frag[i, j] = T.if_then_else(max_val[i] == logits_frag[i, j], -10000.0, logits_frag[i, j])
for i in T.Parallel(blk_m):
topk_gates[bx * blk_m + i, k] = max_val[i]
......@@ -61,7 +58,6 @@ def tl_topk(
def ref_program(logits, top_k):
top_k_gates, top_k_indices = logits.topk(top_k, dim=1)
return top_k_gates, top_k_indices.to(torch.int32)
......
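The kernel above extracts the top-k by repeatedly taking the row max, recording its index, and suppressing the selected entry with -10000.0. An equivalent plain-PyTorch loop (illustrative only; `iterative_topk` is a hypothetical helper) that can be checked against torch.topk:

import torch

def iterative_topk(logits: torch.Tensor, k: int):
    # Select the row max k times, masking each selected position out with -10000.0,
    # mirroring the select/suppress loop in the kernel above.
    work = logits.clone().float()
    vals, idxs = [], []
    for _ in range(k):
        v, i = work.max(dim=1)
        vals.append(v)
        idxs.append(i)
        work.scatter_(1, i.unsqueeze(1), -10000.0)
    return torch.stack(vals, 1), torch.stack(idxs, 1).to(torch.int32)

x = torch.randn(8, 64)
gates, indices = iterative_topk(x, 4)
ref_gates, ref_indices = x.topk(4, dim=1)
torch.testing.assert_close(gates, ref_gates)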
......@@ -7,10 +7,10 @@ import tilelang.language as T
out_idx=[-1],
pass_configs={
tilelang.PassConfigKey.TL_LAYOUT_VISUALIZATION_ENABLE: True,
tilelang.PassConfigKey.TL_LAYOUT_VISUALIZATION_FORMATS: "svg"
})
tilelang.PassConfigKey.TL_LAYOUT_VISUALIZATION_FORMATS: "svg",
},
)
def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
@T.prim_func
def gemm(
A: T.Tensor((M, K), dtype),
......@@ -49,12 +49,12 @@ def main():
print("All check passed.")
# print the layout visualization result and save figures to ./tmp.
'''
"""
C_local inferenced layout:
Shape: [32, 32] -> [8]
Thread: _j // 16 * 64 + _i // 16 * 32 + _i % 8 * 4 + _j % 8 // 2
Index: [_j % 16 // 8 * 4 + _i % 16 // 8 * 2 + _j % 2]
'''
"""
if __name__ == "__main__":
......