Commit 8fdfdf03 authored by Yuqing Xia's avatar Yuqing Xia Committed by LeiWang1999
Browse files

[Example] Add sparse gqa decode example (#332)



* add example gqa decode wgmma pipelined

* add sparse gqa

* support num split

* support num split

* add if condition

* add heuristic num split

* clean code

* add ref

* fix bug

* add torch ref

* fix bug

* integrate to torch

* symbolic

* clean mask

* rm actual_num_blocks

* clean code

* get num_sm via torch

* add sparse gqa decode example

* format

* rm example_gqa_decode_wgmma_pipelined.py

* Add license headers to example scripts

* format

* Remove commented-out cache disabling lines

---------
Co-authored-by: default avatarLei Wang <34334180+LeiWang1999@users.noreply.github.com>
parent eb757608
...@@ -21,7 +21,6 @@ def num_splits_heuristic(total_mblocks, num_SMs, num_n_blocks, num_m_blocks, siz ...@@ -21,7 +21,6 @@ def num_splits_heuristic(total_mblocks, num_SMs, num_n_blocks, num_m_blocks, siz
# If we have enough m_blocks to almost fill the SMs, prefer 1 split unless memory constraints apply. # If we have enough m_blocks to almost fill the SMs, prefer 1 split unless memory constraints apply.
if total_mblocks >= 0.8 * num_SMs: if total_mblocks >= 0.8 * num_SMs:
size_l2 = 50 * 1024 * 1024 # L2 cache size assumption (50MB) size_l2 = 50 * 1024 * 1024 # L2 cache size assumption (50MB)
# Only split if each KV head is too large for L2 and there are enough m_blocks # Only split if each KV head is too large for L2 and there are enough m_blocks
if size_one_kv_head > size_l2 and num_m_blocks >= num_SMs * 2 and not is_causal_or_local: if size_one_kv_head > size_l2 and num_m_blocks >= num_SMs * 2 and not is_causal_or_local:
return min((size_one_kv_head + size_l2 - 1) // size_l2, max_splits) return min((size_one_kv_head + size_l2 - 1) // size_l2, max_splits)
...@@ -42,7 +41,6 @@ def num_splits_heuristic(total_mblocks, num_SMs, num_n_blocks, num_m_blocks, siz ...@@ -42,7 +41,6 @@ def num_splits_heuristic(total_mblocks, num_SMs, num_n_blocks, num_m_blocks, siz
for num_splits in range(1, max_splits + 1): for num_splits in range(1, max_splits + 1):
n_waves = (total_mblocks * num_splits) / num_SMs n_waves = (total_mblocks * num_splits) / num_SMs
eff = n_waves / math.ceil(n_waves) eff = n_waves / math.ceil(n_waves)
# Track max efficiency # Track max efficiency
if eff > max_efficiency: if eff > max_efficiency:
max_efficiency = eff max_efficiency = eff
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment