"vscode:/vscode.git/clone" did not exist on "dae1c62424b96f24d2c41fb6996d75cfe7e8d21a"
Commit 6d44c465 authored by Lei Wang, committed by LeiWang1999

[Typo] Replace `kernel.func` with `kernel` in mla benchmark scripts (#354)

* [Refactor] Update import structure in benchmark_mla.py

- Moved the import of `flash_mla` functions to the `run_flash_mla` function for better encapsulation.
- Added a comment for `flashinfer` installation to clarify dependencies.
- Cleaned up unused imports to enhance code readability.

* lint fix
parent 7aa34977
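
The deferred-import refactor described above keeps an optional backend from breaking the whole script at startup: an ImportError now surfaces only when that backend's benchmark is actually selected. A minimal, self-contained sketch of the pattern (the function name `run_flash_infer_sketch` is illustrative, not from the benchmark):

import torch


@torch.inference_mode()
def run_flash_infer_sketch(q: torch.Tensor) -> torch.Tensor:
    # Deferred import: flashinfer is only required when this backend
    # actually runs, so the other benchmarks work without it installed.
    # pip install flashinfer-python
    import flashinfer  # noqa: F401

    # ... a real run would call into flashinfer here; returning q
    # unchanged keeps this sketch minimal.
    return q
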
@@ -3,15 +3,10 @@
 import argparse
 import math
 import random
-import flashinfer
 import torch
 import triton
 import triton.language as tl
-# pip install flashinfer-python
-from flash_mla import flash_mla_with_kvcache, get_mla_metadata
 import tilelang
 from tilelang.profiler import do_bench
 from example_mla_decode_paged import mla_decode_tilelang
@@ -68,6 +63,8 @@ def run_torch_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q,
 @torch.inference_mode()
 def run_flash_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q,
                   h_kv, d, dv, causal, dtype):
+    from flash_mla import flash_mla_with_kvcache, get_mla_metadata
+
     blocked_v = blocked_k[..., :dv]
     tile_scheduler_metadata, num_splits = get_mla_metadata(cache_seqlens, s_q * h_q // h_kv, h_kv)
@@ -92,7 +89,8 @@ def run_flash_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q,
 @torch.inference_mode()
 def run_flash_infer(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens,
                     h_q, h_kv, d, dv, causal, dtype):
+    # pip install flashinfer-python
+    import flashinfer
     assert d > dv, "mla with rope dim should be larger than no rope dim"
     q_nope, q_pe = q[..., :dv].contiguous(), q[..., dv:].contiguous()
     blocked_k_nope, blocked_k_pe = blocked_k[..., :dv].contiguous(), blocked_k[...,
@@ -441,7 +439,7 @@ def run_flash_mla_tilelang(q, block_table, blocked_k, max_seqlen_pad, block_size
     kernel = tilelang.compile(program, out_idx=[8])

     def flash_mla_tilelang():
-        out = kernel.func(
+        out = kernel(
             q_nope.view(-1, h_q, dv),
             q_pe.view(-1, h_q, dpe),
             blocked_k_nope.view(-1, h_kv, dv),
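
The last hunk is the typo fix named in the commit title: `tilelang.compile` returns a compiled kernel object that is meant to be invoked directly, so `kernel.func(...)` becomes `kernel(...)`. A hedged sketch of the corrected call, reusing the names from the diff (the surrounding setup and the trailing arguments are assumed from context, not shown in full):

import tilelang

# `program` is the MLA decode program built earlier in the benchmark;
# out_idx=[8] tells the JIT wrapper which argument to allocate and return.
kernel = tilelang.compile(program, out_idx=[8])

# Call the compiled kernel object directly rather than through `.func`.
out = kernel(
    q_nope.view(-1, h_q, dv),
    q_pe.view(-1, h_q, dpe),
    blocked_k_nope.view(-1, h_kv, dv),
    # ... remaining arguments are elided in the diff above ...
)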