Unverified Commit d9eb9358 authored by Wen-Heng (Jack) Chung's avatar Wen-Heng (Jack) Chung Committed by GitHub
Browse files

Tune paged attention parameters for AMD GPU. (#3255)

parent 959dca4f
...@@ -181,6 +181,9 @@ def _decode_att_m_fwd( ...@@ -181,6 +181,9 @@ def _decode_att_m_fwd(
logit_cap, logit_cap,
): ):
BLOCK = 64 BLOCK = 64
# [TODO] work around SGPR limit on MI3xx
if is_hip_:
BLOCK = 8
NUM_KV_SPLITS = num_kv_splits NUM_KV_SPLITS = num_kv_splits
Lk = k_buffer.shape[-1] Lk = k_buffer.shape[-1]
Lv = v_buffer.shape[-1] Lv = v_buffer.shape[-1]
...@@ -194,6 +197,8 @@ def _decode_att_m_fwd( ...@@ -194,6 +197,8 @@ def _decode_att_m_fwd(
num_warps = 4 num_warps = 4
else: else:
num_warps = 2 num_warps = 2
if is_hip_:
num_warps = 1
BLOCK_DMODEL = triton.next_power_of_2(Lk) BLOCK_DMODEL = triton.next_power_of_2(Lk)
BLOCK_DV = triton.next_power_of_2(Lv) BLOCK_DV = triton.next_power_of_2(Lv)
...@@ -433,10 +438,12 @@ def _decode_grouped_att_m_fwd( ...@@ -433,10 +438,12 @@ def _decode_grouped_att_m_fwd(
) )
extra_kargs = {} extra_kargs = {}
num_stages = 2
if is_hip_: if is_hip_:
# https://rocm.docs.amd.com/en/docs-6.2.0/how-to/llm-fine-tuning-optimization/optimizing-triton-kernel.html # https://rocm.docs.amd.com/en/docs-6.2.0/how-to/llm-fine-tuning-optimization/optimizing-triton-kernel.html
# https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/compiler.py # https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/compiler.py
extra_kargs = {"waves_per_eu": 4, "matrix_instr_nonkdim": 16, "kpack": 2} extra_kargs = {"waves_per_eu": 1, "matrix_instr_nonkdim": 16, "kpack": 2}
num_stages = 1
_fwd_grouped_kernel_stage1[grid]( _fwd_grouped_kernel_stage1[grid](
q, q,
...@@ -467,7 +474,7 @@ def _decode_grouped_att_m_fwd( ...@@ -467,7 +474,7 @@ def _decode_grouped_att_m_fwd(
NUM_KV_SPLITS=NUM_KV_SPLITS, NUM_KV_SPLITS=NUM_KV_SPLITS,
logit_cap=logit_cap, logit_cap=logit_cap,
num_warps=4, num_warps=4,
num_stages=2, num_stages=num_stages,
Lk=Lk, Lk=Lk,
Lv=Lv, Lv=Lv,
**extra_kargs, **extra_kargs,
......
...@@ -273,6 +273,10 @@ class ServerArgs: ...@@ -273,6 +273,10 @@ class ServerArgs:
) and check_gguf_file(self.model_path): ) and check_gguf_file(self.model_path):
self.quantization = self.load_format = "gguf" self.quantization = self.load_format = "gguf"
# AMD-specific Triton attention KV splits default number
if is_hip():
self.triton_attention_num_kv_splits = 16
@staticmethod @staticmethod
def add_cli_args(parser: argparse.ArgumentParser): def add_cli_args(parser: argparse.ArgumentParser):
# Model and port args # Model and port args
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment