Unverified Commit 8a5eb569 authored by Yu Cheng's avatar Yu Cheng Committed by GitHub
Browse files

[Refactor] Use forceinline in `ldmatrix` and update mamba scan kernel (#1104)

parent 5683e6a6
......@@ -71,7 +71,12 @@ def get_configs():
@autotune(configs=get_configs(), warmup=10, rep=10)
@tilelang.jit(out_idx=[7])
@tilelang.jit(
out_idx=[7],
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
},
)
def chunk_scan_fwd(batch,
seqlen,
chunk_size,
......@@ -91,13 +96,16 @@ def chunk_scan_fwd(batch,
p = 1.44269504
@T.prim_func
def main(cb: T.Tensor((batch, nchunks, ngroups, chunk_size, chunk_size), dtype), x: T.Tensor(
(batch, seqlen, nheads, headdim), dtype), dt: T.Tensor(
(batch, nheads, nchunks, chunk_size), dtype), dA_cumsum: T.Tensor(
(batch, nheads, nchunks, chunk_size), dtype),
C: T.Tensor((batch, seqlen, ngroups, dstate), dtype), prev_states: T.Tensor(
(batch, nchunks, nheads, headdim, dstate), dtype), D: T.Tensor(
(nheads), dtype), Output: T.Tensor((batch, seqlen, nheads, headdim), dtype)):
def main(
cb: T.Tensor((batch, nchunks, ngroups, chunk_size, chunk_size), dtype), # type: ignore
x: T.Tensor((batch, seqlen, nheads, headdim), dtype), # type: ignore
dt: T.Tensor((batch, nheads, nchunks, chunk_size), dtype), # type: ignore
dA_cumsum: T.Tensor((batch, nheads, nchunks, chunk_size), dtype), # type: ignore
C: T.Tensor((batch, seqlen, ngroups, dstate), dtype), # type: ignore
prev_states: T.Tensor((batch, nchunks, nheads, headdim, dstate), dtype), # type: ignore
D: T.Tensor((nheads), dtype), # type: ignore
Output: T.Tensor((batch, seqlen, nheads, headdim), dtype) # type: ignore
):
with T.Kernel(
nheads,
T.ceildiv(chunk_size, block_M) * T.ceildiv(headdim, block_N),
......@@ -134,6 +142,8 @@ def chunk_scan_fwd(batch,
x_residual_shared: tilelang.layout.make_swizzled_layout(x_residual_shared)
})
T.no_set_max_nreg()
T.copy(dA_cumsum[batch_idx, bz, chunk_idx, m_idx * block_M:(m_idx + 1) * block_M],
dA_cs_m_shared)
T.copy(dA_cs_m_shared, dA_cs_m_local)
......
......@@ -4,8 +4,8 @@
namespace tl {
TL_DEVICE_NOINLINE void ptx_ldmatrix_x1(void const *const smem_ptr,
void *const local_ptr) {
TL_DEVICE void ptx_ldmatrix_x1(void const *const smem_ptr,
void *const local_ptr) {
uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
int32_t *value = reinterpret_cast<int32_t *>(local_ptr);
asm volatile("ldmatrix.sync.aligned.x1.m8n8.shared.b16 {%0}, [%1];\n"
......@@ -13,8 +13,8 @@ TL_DEVICE_NOINLINE void ptx_ldmatrix_x1(void const *const smem_ptr,
: "r"(smem_int_ptr));
}
TL_DEVICE_NOINLINE void ptx_ldmatrix_x2(void const *const smem_ptr,
void *const local_ptr) {
TL_DEVICE void ptx_ldmatrix_x2(void const *const smem_ptr,
void *const local_ptr) {
uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
int32_t *value = reinterpret_cast<int32_t *>(local_ptr);
asm volatile("ldmatrix.sync.aligned.x2.m8n8.shared.b16 {%0, %1}, [%2];\n"
......@@ -22,8 +22,8 @@ TL_DEVICE_NOINLINE void ptx_ldmatrix_x2(void const *const smem_ptr,
: "r"(smem_int_ptr));
}
TL_DEVICE_NOINLINE void ptx_ldmatrix_x4(void const *const smem_ptr,
void *const local_ptr) {
TL_DEVICE void ptx_ldmatrix_x4(void const *const smem_ptr,
void *const local_ptr) {
uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
int32_t *value = reinterpret_cast<int32_t *>(local_ptr);
asm volatile(
......@@ -32,8 +32,8 @@ TL_DEVICE_NOINLINE void ptx_ldmatrix_x4(void const *const smem_ptr,
: "r"(smem_int_ptr));
}
TL_DEVICE_NOINLINE void ptx_ldmatrix_x1_trans(void const *const smem_ptr,
void *const local_ptr) {
TL_DEVICE void ptx_ldmatrix_x1_trans(void const *const smem_ptr,
void *const local_ptr) {
uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
int32_t *value = reinterpret_cast<int32_t *>(local_ptr);
asm volatile("ldmatrix.sync.aligned.x1.trans.m8n8.shared.b16 {%0}, [%1];\n"
......@@ -41,8 +41,8 @@ TL_DEVICE_NOINLINE void ptx_ldmatrix_x1_trans(void const *const smem_ptr,
: "r"(smem_int_ptr));
}
TL_DEVICE_NOINLINE void ptx_ldmatrix_x2_trans(void const *const smem_ptr,
void *const local_ptr) {
TL_DEVICE void ptx_ldmatrix_x2_trans(void const *const smem_ptr,
void *const local_ptr) {
uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
int32_t *value = reinterpret_cast<int32_t *>(local_ptr);
asm volatile(
......@@ -51,8 +51,8 @@ TL_DEVICE_NOINLINE void ptx_ldmatrix_x2_trans(void const *const smem_ptr,
: "r"(smem_int_ptr));
}
TL_DEVICE_NOINLINE void ptx_ldmatrix_x4_trans(void const *const smem_ptr,
void *const local_ptr) {
TL_DEVICE void ptx_ldmatrix_x4_trans(void const *const smem_ptr,
void *const local_ptr) {
uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
int32_t *value = reinterpret_cast<int32_t *>(local_ptr);
asm volatile(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment