"vscode:/vscode.git/clone" did not exist on "be3dfa5049d588d8e1e5dc7c256b45dbcb3af8d4"
Unverified Commit 67c424cc authored by HAI's avatar HAI Committed by GitHub
Browse files

[Performance, Triton Kernel Args] extend_attention, optimize kern args to _fwd_kernel (#1941)

parent 1ae270c5
......@@ -25,6 +25,7 @@ import triton.language as tl
from sglang.srt.layers.attention.triton_ops.prefill_attention import (
context_attention_fwd,
)
from sglang.srt.utils import is_hip
is_cuda_available = torch.cuda.is_available()
if is_cuda_available:
......@@ -311,6 +312,10 @@ def extend_attention_fwd(
num_warps = 4 if Lk <= 64 else 8
num_stages = 1
extra_kargs = {}
if is_hip():
extra_kargs = {"waves_per_eu": 4, "matrix_instr_nonkdim": 16, "kpack": 2}
_fwd_kernel[grid](
q_extend,
k_extend,
......@@ -348,6 +353,7 @@ def extend_attention_fwd(
Lv=Lv,
num_warps=num_warps,
num_stages=num_stages,
**extra_kargs,
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment