Unverified commit 8a5eb569, authored by Yu Cheng and committed by GitHub
Browse files

[Refactor] Use forceinline in `ldmatrix` and update mamba scan kernel (#1104)

parent 5683e6a6
...@@ -71,7 +71,12 @@ def get_configs(): ...@@ -71,7 +71,12 @@ def get_configs():
@autotune(configs=get_configs(), warmup=10, rep=10) @autotune(configs=get_configs(), warmup=10, rep=10)
@tilelang.jit(out_idx=[7]) @tilelang.jit(
out_idx=[7],
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
},
)
def chunk_scan_fwd(batch, def chunk_scan_fwd(batch,
seqlen, seqlen,
chunk_size, chunk_size,
...@@ -91,13 +96,16 @@ def chunk_scan_fwd(batch, ...@@ -91,13 +96,16 @@ def chunk_scan_fwd(batch,
p = 1.44269504 p = 1.44269504
@T.prim_func @T.prim_func
def main(cb: T.Tensor((batch, nchunks, ngroups, chunk_size, chunk_size), dtype), x: T.Tensor( def main(
(batch, seqlen, nheads, headdim), dtype), dt: T.Tensor( cb: T.Tensor((batch, nchunks, ngroups, chunk_size, chunk_size), dtype), # type: ignore
(batch, nheads, nchunks, chunk_size), dtype), dA_cumsum: T.Tensor( x: T.Tensor((batch, seqlen, nheads, headdim), dtype), # type: ignore
(batch, nheads, nchunks, chunk_size), dtype), dt: T.Tensor((batch, nheads, nchunks, chunk_size), dtype), # type: ignore
C: T.Tensor((batch, seqlen, ngroups, dstate), dtype), prev_states: T.Tensor( dA_cumsum: T.Tensor((batch, nheads, nchunks, chunk_size), dtype), # type: ignore
(batch, nchunks, nheads, headdim, dstate), dtype), D: T.Tensor( C: T.Tensor((batch, seqlen, ngroups, dstate), dtype), # type: ignore
(nheads), dtype), Output: T.Tensor((batch, seqlen, nheads, headdim), dtype)): prev_states: T.Tensor((batch, nchunks, nheads, headdim, dstate), dtype), # type: ignore
D: T.Tensor((nheads), dtype), # type: ignore
Output: T.Tensor((batch, seqlen, nheads, headdim), dtype) # type: ignore
):
with T.Kernel( with T.Kernel(
nheads, nheads,
T.ceildiv(chunk_size, block_M) * T.ceildiv(headdim, block_N), T.ceildiv(chunk_size, block_M) * T.ceildiv(headdim, block_N),
...@@ -134,6 +142,8 @@ def chunk_scan_fwd(batch, ...@@ -134,6 +142,8 @@ def chunk_scan_fwd(batch,
x_residual_shared: tilelang.layout.make_swizzled_layout(x_residual_shared) x_residual_shared: tilelang.layout.make_swizzled_layout(x_residual_shared)
}) })
T.no_set_max_nreg()
T.copy(dA_cumsum[batch_idx, bz, chunk_idx, m_idx * block_M:(m_idx + 1) * block_M], T.copy(dA_cumsum[batch_idx, bz, chunk_idx, m_idx * block_M:(m_idx + 1) * block_M],
dA_cs_m_shared) dA_cs_m_shared)
T.copy(dA_cs_m_shared, dA_cs_m_local) T.copy(dA_cs_m_shared, dA_cs_m_local)
......
...@@ -4,8 +4,8 @@ ...@@ -4,8 +4,8 @@
namespace tl { namespace tl {
TL_DEVICE_NOINLINE void ptx_ldmatrix_x1(void const *const smem_ptr, TL_DEVICE void ptx_ldmatrix_x1(void const *const smem_ptr,
void *const local_ptr) { void *const local_ptr) {
uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr); uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
int32_t *value = reinterpret_cast<int32_t *>(local_ptr); int32_t *value = reinterpret_cast<int32_t *>(local_ptr);
asm volatile("ldmatrix.sync.aligned.x1.m8n8.shared.b16 {%0}, [%1];\n" asm volatile("ldmatrix.sync.aligned.x1.m8n8.shared.b16 {%0}, [%1];\n"
...@@ -13,8 +13,8 @@ TL_DEVICE_NOINLINE void ptx_ldmatrix_x1(void const *const smem_ptr, ...@@ -13,8 +13,8 @@ TL_DEVICE_NOINLINE void ptx_ldmatrix_x1(void const *const smem_ptr,
: "r"(smem_int_ptr)); : "r"(smem_int_ptr));
} }
TL_DEVICE_NOINLINE void ptx_ldmatrix_x2(void const *const smem_ptr, TL_DEVICE void ptx_ldmatrix_x2(void const *const smem_ptr,
void *const local_ptr) { void *const local_ptr) {
uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr); uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
int32_t *value = reinterpret_cast<int32_t *>(local_ptr); int32_t *value = reinterpret_cast<int32_t *>(local_ptr);
asm volatile("ldmatrix.sync.aligned.x2.m8n8.shared.b16 {%0, %1}, [%2];\n" asm volatile("ldmatrix.sync.aligned.x2.m8n8.shared.b16 {%0, %1}, [%2];\n"
...@@ -22,8 +22,8 @@ TL_DEVICE_NOINLINE void ptx_ldmatrix_x2(void const *const smem_ptr, ...@@ -22,8 +22,8 @@ TL_DEVICE_NOINLINE void ptx_ldmatrix_x2(void const *const smem_ptr,
: "r"(smem_int_ptr)); : "r"(smem_int_ptr));
} }
TL_DEVICE_NOINLINE void ptx_ldmatrix_x4(void const *const smem_ptr, TL_DEVICE void ptx_ldmatrix_x4(void const *const smem_ptr,
void *const local_ptr) { void *const local_ptr) {
uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr); uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
int32_t *value = reinterpret_cast<int32_t *>(local_ptr); int32_t *value = reinterpret_cast<int32_t *>(local_ptr);
asm volatile( asm volatile(
...@@ -32,8 +32,8 @@ TL_DEVICE_NOINLINE void ptx_ldmatrix_x4(void const *const smem_ptr, ...@@ -32,8 +32,8 @@ TL_DEVICE_NOINLINE void ptx_ldmatrix_x4(void const *const smem_ptr,
: "r"(smem_int_ptr)); : "r"(smem_int_ptr));
} }
TL_DEVICE_NOINLINE void ptx_ldmatrix_x1_trans(void const *const smem_ptr, TL_DEVICE void ptx_ldmatrix_x1_trans(void const *const smem_ptr,
void *const local_ptr) { void *const local_ptr) {
uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr); uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
int32_t *value = reinterpret_cast<int32_t *>(local_ptr); int32_t *value = reinterpret_cast<int32_t *>(local_ptr);
asm volatile("ldmatrix.sync.aligned.x1.trans.m8n8.shared.b16 {%0}, [%1];\n" asm volatile("ldmatrix.sync.aligned.x1.trans.m8n8.shared.b16 {%0}, [%1];\n"
...@@ -41,8 +41,8 @@ TL_DEVICE_NOINLINE void ptx_ldmatrix_x1_trans(void const *const smem_ptr, ...@@ -41,8 +41,8 @@ TL_DEVICE_NOINLINE void ptx_ldmatrix_x1_trans(void const *const smem_ptr,
: "r"(smem_int_ptr)); : "r"(smem_int_ptr));
} }
TL_DEVICE_NOINLINE void ptx_ldmatrix_x2_trans(void const *const smem_ptr, TL_DEVICE void ptx_ldmatrix_x2_trans(void const *const smem_ptr,
void *const local_ptr) { void *const local_ptr) {
uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr); uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
int32_t *value = reinterpret_cast<int32_t *>(local_ptr); int32_t *value = reinterpret_cast<int32_t *>(local_ptr);
asm volatile( asm volatile(
...@@ -51,8 +51,8 @@ TL_DEVICE_NOINLINE void ptx_ldmatrix_x2_trans(void const *const smem_ptr, ...@@ -51,8 +51,8 @@ TL_DEVICE_NOINLINE void ptx_ldmatrix_x2_trans(void const *const smem_ptr,
: "r"(smem_int_ptr)); : "r"(smem_int_ptr));
} }
TL_DEVICE_NOINLINE void ptx_ldmatrix_x4_trans(void const *const smem_ptr, TL_DEVICE void ptx_ldmatrix_x4_trans(void const *const smem_ptr,
void *const local_ptr) { void *const local_ptr) {
uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr); uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
int32_t *value = reinterpret_cast<int32_t *>(local_ptr); int32_t *value = reinterpret_cast<int32_t *>(local_ptr);
asm volatile( asm volatile(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment