Commit 2d4ce1b7 (unverified), authored by HAI, committed by GitHub
Browse files

[Performance, Triton Kernel Args] _decode_grouped_softmax_reducev_fwd… (#1845)

parent 4ba815b8
@@ -24,6 +24,8 @@ It supports page size = 1.
 import triton
 import triton.language as tl
+from sglang.srt.utils import is_hip
 @triton.jit
 def tanh(x):
@@ -553,6 +555,12 @@ def _decode_grouped_softmax_reducev_fwd(
     Lv = v_buffer.shape[-1]
     BLOCK_DMODEL = triton.next_power_of_2(Lv)
+    extra_kargs = {}
+    if is_hip():
+        # https://rocm.docs.amd.com/en/docs-6.2.0/how-to/llm-fine-tuning-optimization/optimizing-triton-kernel.html
+        # https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/compiler.py
+        extra_kargs = {"waves_per_eu": 4, "matrix_instr_nonkdim": 16, "kpack": 2}
     _fwd_grouped_kernel_stage2[grid](
         logits,
         v_buffer,
@@ -575,6 +583,7 @@ def _decode_grouped_softmax_reducev_fwd(
         Lv=Lv,
         num_warps=num_warps,
         num_stages=1,
+        **extra_kargs,
     )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment