"src/include/threadwise_direct_convolution.hpp" did not exist on "216e3da60959ee5968d7424ac0943c86fbf55375"
Unverified Commit eb6e8973 authored by Zhengju Tang, committed by GitHub

[GQA] Add varlen decoding kernel with logits saving (#1223)

* [Example] Add GQA varlen decoding kernel with logits return

* [Example] Support Sink for GQA varlen decoding

* [Example] Add support for the non-varlen case

* [Tune] Add high performance logits saving

* [Lint]

* [Lint]

* [Rename]
parent 47039f06
@@ -40,9 +40,9 @@ def get_heuristic_config() -> Tuple[Dict, int]:
     sm_version = sm_major * 10 + sm_minor
     print(f"CUDA device capability: {sm_version}")
     if sm_version == 89:
-        cfg = dict(block_N=128, block_H=64, num_split=16, num_stages=0, threads=128)
+        cfg = dict(block_N=128, block_H=64, num_split=1, num_stages=0, threads=128)
     else:
-        cfg = dict(block_N=128, block_H=64, num_split=16, num_stages=2, threads=128)
+        cfg = dict(block_N=128, block_H=64, num_split=1, num_stages=2, threads=128)
     return cfg, sm_version
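
Note on the hunk above: the default num_split drops from 16 to 1 on both branches. For readers without the full file, a minimal sketch of how the surrounding get_heuristic_config helper might fit together is shown below; only the lines that appear in the hunk are taken from the diff, and the device-capability query via torch.cuda.get_device_capability() is an assumption.

from typing import Dict, Tuple

import torch

def get_heuristic_config() -> Tuple[Dict, int]:
    # Assumption: derive the SM version from the active CUDA device.
    sm_major, sm_minor = torch.cuda.get_device_capability()
    sm_version = sm_major * 10 + sm_minor
    print(f"CUDA device capability: {sm_version}")
    if sm_version == 89:
        # Ada (sm_89): no software pipelining stages.
        cfg = dict(block_N=128, block_H=64, num_split=1, num_stages=0, threads=128)
    else:
        # Other architectures: two pipeline stages.
        cfg = dict(block_N=128, block_H=64, num_split=1, num_stages=2, threads=128)
    return cfg, sm_version
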
@@ -459,8 +459,9 @@ def main(batch: int = 1,
     k = torch.randn(batch, kv_seqlen, groups, dim, device="cuda", dtype=torch.float16)
     v = torch.randn(batch, kv_seqlen, groups, dim, device="cuda", dtype=torch.float16)
     mask = torch.randint(0, 2, (batch, kv_seqlen, groups), device="cuda", dtype=torch.uint8)
-    glse = torch.empty(batch, heads, 16, device="cuda", dtype=torch.float16)
-    Output_partial = torch.empty(batch, heads, 16, dim, device="cuda", dtype=torch.float16)
+    split = config["num_split"]
+    glse = torch.empty(batch, heads, split, device="cuda", dtype=torch.float16)
+    Output_partial = torch.empty(batch, heads, split, dim, device="cuda", dtype=torch.float16)
     o = kernel(q, k, v, mask, glse, Output_partial)
     o_ref = ref_program(q, k, v, mask, glse, Output_partial)
     o_ref_split = ref_split_program(q, k, v, mask, glse, Output_partial)
......
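
The second hunk sizes the split-reduction scratch buffers from config["num_split"] instead of a hardcoded 16: glse holds one value per (batch, head, split) and Output_partial one dim-sized partial output per split. Below is a sketch of the standard log-sum-exp combine such buffers typically feed; the real reduction here happens inside the kernel and ref_split_program, so the helper name combine_splits and the assumption that glse stores per-split log-sum-exp values are purely illustrative.

import torch

def combine_splits(glse: torch.Tensor, output_partial: torch.Tensor) -> torch.Tensor:
    # glse:           [batch, heads, num_split]       per-split log-sum-exp (assumed)
    # output_partial: [batch, heads, num_split, dim]  per-split partial outputs
    # returns:        [batch, heads, dim]             merged attention output
    lse = glse.float()
    # Subtract the per-(batch, head) maximum before exponentiating for numerical stability.
    lse_max = lse.max(dim=-1, keepdim=True).values
    weights = torch.exp(lse - lse_max)                     # [B, H, S]
    weights = weights / weights.sum(dim=-1, keepdim=True)  # normalized split weights
    out = (weights.unsqueeze(-1) * output_partial.float()).sum(dim=2)
    return out.to(output_partial.dtype)

A quick shape check against the allocations above: glse is (batch, heads, split) and Output_partial is (batch, heads, split, dim), so the weighted sum over dim=2 collapses the split axis and yields a (batch, heads, dim) tensor.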