sglang · Commit d9eb9358 (unverified)

Tune paged attention parameters for AMD GPU. (#3255)

Authored Feb 01, 2025 by Wen-Heng (Jack) Chung; committed via GitHub on Feb 01, 2025.
Parent: 959dca4f
Showing 2 changed files with 13 additions and 2 deletions (+13, -2):

  python/sglang/srt/layers/attention/triton_ops/decode_attention.py  (+9, -2)
  python/sglang/srt/server_args.py  (+4, -0)
python/sglang/srt/layers/attention/triton_ops/decode_attention.py

@@ -181,6 +181,9 @@ def _decode_att_m_fwd(
     logit_cap,
 ):
     BLOCK = 64
+    # [TODO] work around SGPR limit on MI3xx
+    if is_hip_:
+        BLOCK = 8
     NUM_KV_SPLITS = num_kv_splits
     Lk = k_buffer.shape[-1]
     Lv = v_buffer.shape[-1]
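The new branch caps BLOCK at 8 on ROCm to stay under the scalar-register (SGPR) budget of MI3xx-series GPUs: a larger tile needs more scalar registers for address arithmetic, and spilling them stalls the kernel. A minimal sketch of how such a platform flag is commonly derived (sglang's real helper lives in sglang.srt.utils and may differ in detail):

# Hedged sketch: detect a ROCm (HIP) build of PyTorch and size the tile.
import torch

def is_hip() -> bool:
    # torch.version.hip is a version string on ROCm wheels, None on CUDA ones.
    return torch.version.hip is not None

is_hip_ = is_hip()

# Smaller BLOCK -> fewer KV rows per program step -> lower SGPR pressure.
BLOCK = 8 if is_hip_ else 64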
@@ -194,6 +197,8 @@ def _decode_att_m_fwd(
         num_warps = 4
     else:
         num_warps = 2
+    if is_hip_:
+        num_warps = 1

     BLOCK_DMODEL = triton.next_power_of_2(Lk)
     BLOCK_DV = triton.next_power_of_2(Lv)
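Dropping to num_warps = 1 on ROCm is less drastic than it looks: Triton maps each "warp" to a hardware wavefront, and CDNA GPUs such as the MI300 use 64-lane wavefronts, so a single AMD "warp" already carries as many lanes as two 32-lane NVIDIA warps. Rough lane arithmetic (the widths are assumptions about the target hardware, not taken from this commit):

# Lane-count comparison; 32 and 64 are the usual warp/wavefront widths
# on NVIDIA and AMD CDNA parts respectively (verify for your target).
NV_WARP = 32
AMD_WAVEFRONT = 64

assert 2 * NV_WARP == 1 * AMD_WAVEFRONT  # num_warps=2 (CUDA) ~ num_warps=1 (ROCm)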
@@ -433,10 +438,12 @@ def _decode_grouped_att_m_fwd(
     )

     extra_kargs = {}
+    num_stages = 2
     if is_hip_:
         # https://rocm.docs.amd.com/en/docs-6.2.0/how-to/llm-fine-tuning-optimization/optimizing-triton-kernel.html
         # https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/compiler.py
-        extra_kargs = {"waves_per_eu": 4, "matrix_instr_nonkdim": 16, "kpack": 2}
+        extra_kargs = {"waves_per_eu": 1, "matrix_instr_nonkdim": 16, "kpack": 2}
+        num_stages = 1

     _fwd_grouped_kernel_stage1[grid](
         q,
@@ -467,7 +474,7 @@ def _decode_grouped_att_m_fwd(
...
@@ -467,7 +474,7 @@ def _decode_grouped_att_m_fwd(
NUM_KV_SPLITS
=
NUM_KV_SPLITS
,
NUM_KV_SPLITS
=
NUM_KV_SPLITS
,
logit_cap
=
logit_cap
,
logit_cap
=
logit_cap
,
num_warps
=
4
,
num_warps
=
4
,
num_stages
=
2
,
num_stages
=
num_stages
,
Lk
=
Lk
,
Lk
=
Lk
,
Lv
=
Lv
,
Lv
=
Lv
,
**
extra_kargs
,
**
extra_kargs
,
...
...
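The waves_per_eu, matrix_instr_nonkdim, and kpack keys are AMD backend compile hints described in the ROCm tuning guide linked in the hunk above; Triton accepts them as extra launch keywords, which is why the diff threads them through as **extra_kargs and makes num_stages a variable instead of a literal. A self-contained sketch of that forwarding pattern on a trivial kernel (illustrative only; this is not sglang's attention kernel):

import torch
import triton
import triton.language as tl

@triton.jit
def _copy_kernel(src, dst, n, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    tl.store(dst + offs, tl.load(src + offs, mask=mask), mask=mask)

def copy(x: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(x)
    n = x.numel()
    extra = {}
    num_stages = 2
    if torch.version.hip is not None:
        # ROCm-only compile hints; the CUDA backend would reject these keys.
        extra = {"waves_per_eu": 1, "kpack": 2}
        num_stages = 1  # shallower software pipelining on ROCm
    grid = (triton.cdiv(n, 1024),)
    _copy_kernel[grid](x, out, n, BLOCK=1024,
                       num_warps=1, num_stages=num_stages, **extra)
    return out

waves_per_eu=1 trades occupancy for per-wave registers, which pairs with the SGPR workaround earlier in this file.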
python/sglang/srt/server_args.py
@@ -273,6 +273,10 @@ class ServerArgs:
         ) and check_gguf_file(self.model_path):
             self.quantization = self.load_format = "gguf"

+        # AMD-specific Triton attention KV splits default number
+        if is_hip():
+            self.triton_attention_num_kv_splits = 16
+
     @staticmethod
     def add_cli_args(parser: argparse.ArgumentParser):
         # Model and port args
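The server-side change only swaps the default on ROCm; more KV splits partition the KV cache across more program instances in the split-KV decode kernel, keeping more compute units busy at the cost of a larger reduction pass. A hedged sketch of the defaulting logic in isolation (the dataclass scaffolding and the non-HIP default of 8 are assumptions, since the surrounding code is elided in this diff):

import torch
from dataclasses import dataclass

def is_hip() -> bool:
    return torch.version.hip is not None

@dataclass
class ServerArgsSketch:
    triton_attention_num_kv_splits: int = 8  # assumed non-HIP default

    def __post_init__(self):
        # More KV splits keep more AMD compute units busy during decode.
        if is_hip():
            self.triton_attention_num_kv_splits = 16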