Unverified Commit 67c424cc authored by HAI, committed by GitHub

[Performance, Triton Kernel Args] extend_attention: optimize kernel args passed to _fwd_kernel (#1941)

parent 1ae270c5
@@ -25,6 +25,7 @@ import triton.language as tl
 from sglang.srt.layers.attention.triton_ops.prefill_attention import (
     context_attention_fwd,
 )
+from sglang.srt.utils import is_hip
 
 is_cuda_available = torch.cuda.is_available()
 if is_cuda_available:
@@ -311,6 +312,10 @@ def extend_attention_fwd(
     num_warps = 4 if Lk <= 64 else 8
     num_stages = 1
 
+    extra_kargs = {}
+    if is_hip():
+        extra_kargs = {"waves_per_eu": 4, "matrix_instr_nonkdim": 16, "kpack": 2}
+
     _fwd_kernel[grid](
         q_extend,
         k_extend,
@@ -348,6 +353,7 @@ def extend_attention_fwd(
         Lv=Lv,
         num_warps=num_warps,
         num_stages=num_stages,
+        **extra_kargs,
     )
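
The whole change is a conditional set of extra launch kwargs splatted into the Triton kernel call. Below is a minimal, self-contained sketch of the same pattern, assuming a trivial stand-in kernel; `_copy_kernel` and `launch_copy` are hypothetical names used only for illustration, while the real call site is `_fwd_kernel` in extend_attention.py as shown in the diff above.

```python
import torch
import triton
import triton.language as tl


def is_hip() -> bool:
    # Same check sglang.srt.utils.is_hip performs: torch.version.hip is
    # only set on ROCm builds of PyTorch.
    return torch.version.hip is not None


@triton.jit
def _copy_kernel(src_ptr, dst_ptr, n_elements, BLOCK: tl.constexpr):
    # Hypothetical stand-in for _fwd_kernel: copies one block of elements.
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n_elements
    vals = tl.load(src_ptr + offs, mask=mask)
    tl.store(dst_ptr + offs, vals, mask=mask)


def launch_copy(src: torch.Tensor, dst: torch.Tensor) -> None:
    n = src.numel()
    grid = (triton.cdiv(n, 1024),)
    # Collect backend-specific launch kwargs in a dict; on NVIDIA it stays
    # empty, so the **splat below is a no-op and the call is unchanged.
    extra_kargs = {}
    if is_hip():
        # AMD-specific tuning knobs, with the values from this commit.
        extra_kargs = {"waves_per_eu": 4, "matrix_instr_nonkdim": 16, "kpack": 2}
    _copy_kernel[grid](src, dst, n, BLOCK=1024, num_warps=4, num_stages=1, **extra_kargs)
```

These kwargs are only accepted by ROCm builds of Triton, which is why the dict must stay empty on CUDA. Broadly, `waves_per_eu` is an occupancy hint per execution unit, `matrix_instr_nonkdim` selects the MFMA instruction tile shape, and `kpack` controls operand packing along the K dimension.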