Commit b0122d74 authored by Lei Wang, committed by LeiWang1999
Browse files

[Refactor] Update barrier functions and remove argparse in...

[Refactor] Update barrier functions and remove argparse in example_warp_specialize_flashmla.py (#457)

* Refactored barrier functions to use new signatures for improved clarity and consistency.
* Replaced `mbarrier_arrive` and `mbarrier_wait_parity` with `barrier_arrive` and `barrier_wait` respectively.
* Removed argparse dependency and replaced it with hardcoded parameters for batch size and dimensions in the main function, simplifying the example script.
parent a91bc2a9
......@@ -63,6 +63,12 @@ jobs:
source tilelang_ci/bin/activate
python -m pip install .
- name: Run examples
run: |
source tilelang_ci/bin/activate
cd examples
python -m pytest **/test*.py
- name: Run tests
run: |
source tilelang_ci/bin/activate
......
......@@ -5,7 +5,6 @@ import torch.nn.functional as F
import tilelang
import tilelang.language as T
from einops import rearrange, einsum
import argparse
def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_H, num_split):
......@@ -159,15 +158,13 @@ def ref_program(q, q_pe, kv, k_pe, glse, Output_partial):
def main():
parser = argparse.ArgumentParser()
parser.add_argument('--batch', type=int, default=128, help='batch size')
parser.add_argument('--heads', type=int, default=128, help='q heads number')
parser.add_argument('--kv_heads', type=int, default=1, help='kv heads number')
parser.add_argument('--kv_ctx', type=int, default=8192, help='kv context length')
parser.add_argument('--dim', type=int, default=512, help='head dim')
parser.add_argument('--pe_dim', type=int, default=64, help='pe head dim')
args = parser.parse_args()
batch, heads, kv_heads, kv_ctx, dim, pe_dim = args.batch, args.heads, args.kv_heads, args.kv_ctx, args.dim, args.pe_dim
batch = 128
heads = 128
kv_heads = 1
kv_ctx = 8192
dim = 512
pe_dim = 64
qk_flops = 2 * batch * heads * kv_ctx * (dim + pe_dim)
pv_flops = 2 * batch * heads * kv_ctx * dim
total_flops = qk_flops + pv_flops
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment