sglang · Commits

Commit 5c0b38f3 (unverified)
Authored May 20, 2025 by HAI, committed by GitHub on May 20, 2025
Parent: 30ca18f4

aiter attention-backend (default enabled on AMD/ROCm) (#6381)

Showing 9 changed files with 552 additions and 23 deletions (+552 / -23)
Files changed:

  .github/workflows/pr-test-amd.yml                      +19  -19
  docker/Dockerfile.rocm                                   +1   -1
  python/sglang/srt/layers/attention/aiter_backend.py    +513   -0
  python/sglang/srt/model_executor/model_runner.py         +9   -1
  python/sglang/srt/server_args.py                         +1   -0
  scripts/amd_ci_install_dependency.sh                     +0   -1
  scripts/amd_ci_start_container.sh                        +1   -1
  test/srt/models/test_dummy_grok_models.py                +1   -0
  test/srt/test_eval_accuracy_large.py                     +7   -0
.github/workflows/pr-test-amd.yml

@@ -44,7 +44,7 @@ jobs:
         run: bash scripts/amd_ci_install_dependency.sh

       - name: Evaluate Accuracy
-        timeout-minutes: 20
+        timeout-minutes: 30
         run: |
           bash scripts/amd_ci_exec.sh python3 test_eval_accuracy_large.py
           bash scripts/amd_ci_exec.sh python3 test_eval_fp8_accuracy.py

@@ -70,7 +70,7 @@ jobs:
         run: bash scripts/amd_ci_install_dependency.sh

       - name: Evaluate accuracy (TP=2)
-        timeout-minutes: 20
+        timeout-minutes: 30
         run: |
           bash scripts/amd_ci_exec.sh python3 test_moe_eval_accuracy_large.py

@@ -94,7 +94,7 @@ jobs:
         run: bash scripts/amd_ci_install_dependency.sh

       - name: MLA TEST
-        timeout-minutes: 20
+        timeout-minutes: 30
         run: |
           bash scripts/amd_ci_exec.sh python3 test_mla.py

@@ -118,28 +118,28 @@ jobs:
         run: bash scripts/amd_ci_install_dependency.sh

       - name: Benchmark single latency
-        timeout-minutes: 10
+        timeout-minutes: 20
         run: |
           bash scripts/amd_ci_exec.sh python3 -m unittest test_bench_one_batch.TestBenchOneBatch.test_bs1_small
           bash scripts/amd_ci_exec.sh python3 -m unittest test_bench_one_batch.TestBenchOneBatch.test_bs1_default

       - name: Benchmark online latency
-        timeout-minutes: 10
+        timeout-minutes: 15
         run: |
           bash scripts/amd_ci_exec.sh python3 -m unittest test_bench_serving.TestBenchServing.test_online_latency_default

       - name: Benchmark offline throughput
-        timeout-minutes: 10
+        timeout-minutes: 15
         run: |
           bash scripts/amd_ci_exec.sh python3 -m unittest test_bench_serving.TestBenchServing.test_offline_throughput_default

       - name: Benchmark offline throughput (Non-streaming, small batch size)
-        timeout-minutes: 10
+        timeout-minutes: 15
         run: |
           bash scripts/amd_ci_exec.sh python3 -m unittest test_bench_serving.TestBenchServing.test_offline_throughput_non_stream_small_batch_size

       - name: Benchmark online latency (EAGLE)
-        timeout-minutes: 10
+        timeout-minutes: 15
         run: |
           bash scripts/amd_ci_exec.sh python3 -m unittest test_bench_serving.TestBenchServing.test_online_latency_eagle

@@ -163,17 +163,17 @@ jobs:
         run: bash scripts/amd_ci_install_dependency.sh

       - name: Benchmark offline throughput (w/o RadixAttention)
-        timeout-minutes: 10
+        timeout-minutes: 15
         run: |
           bash scripts/amd_ci_exec.sh python3 -m unittest test_bench_serving.TestBenchServing.test_offline_throughput_without_radix_cache

       - name: Benchmark offline throughput (w/ Triton)
-        timeout-minutes: 10
+        timeout-minutes: 15
         run: |
           bash scripts/amd_ci_exec.sh python3 -m unittest test_bench_serving.TestBenchServing.test_offline_throughput_with_triton_attention_backend

       - name: Benchmark offline throughput (w/ FP8)
-        timeout-minutes: 10
+        timeout-minutes: 15
         run: |
           bash scripts/amd_ci_exec.sh python3 -m unittest test_bench_serving.TestBenchServing.test_offline_throughput_default_fp8

@@ -197,27 +197,27 @@ jobs:
         run: bash scripts/amd_ci_install_dependency.sh

       - name: Benchmark dummy grok (TP=2)
-        timeout-minutes: 20
+        timeout-minutes: 30
         run: |
           bash scripts/amd_ci_exec.sh python3 models/test_dummy_grok_models.py

       - name: Benchmark single latency (TP=2)
-        timeout-minutes: 20
+        timeout-minutes: 25
         run: |
           bash scripts/amd_ci_exec.sh python3 -m unittest test_bench_one_batch.TestBenchOneBatch.test_moe_tp2_bs1

       - name: Benchmark single latency + torch.compile (TP=2)
-        timeout-minutes: 20
+        timeout-minutes: 25
         run: |
           bash scripts/amd_ci_exec.sh python3 -m unittest test_bench_one_batch.TestBenchOneBatch.test_torch_compile_tp2_bs1

       - name: Benchmark offline throughput (TP=2)
-        timeout-minutes: 20
+        timeout-minutes: 25
         run: |
           bash scripts/amd_ci_exec.sh python3 -m unittest test_bench_serving.TestBenchServing.test_moe_offline_throughput_default

       - name: Benchmark offline throughput (w/o RadixAttention) (TP=2)
-        timeout-minutes: 20
+        timeout-minutes: 25
         run: |
           bash scripts/amd_ci_exec.sh python3 -m unittest test_bench_serving.TestBenchServing.test_moe_offline_throughput_without_radix_cache

@@ -241,7 +241,7 @@ jobs:
         run: bash scripts/amd_ci_install_dependency.sh

       - name: Run test
-        timeout-minutes: 30
+        timeout-minutes: 40
         run: |
           bash scripts/amd_ci_exec.sh python3 run_suite.py --suite per-commit-amd

@@ -265,7 +265,7 @@ jobs:
         run: bash scripts/amd_ci_install_dependency.sh

       - name: Run test
-        timeout-minutes: 30
+        timeout-minutes: 40
         run: |
           bash scripts/amd_ci_exec.sh python3 run_suite.py --suite per-commit-2-gpu-amd

@@ -289,7 +289,7 @@ jobs:
         run: bash scripts/amd_ci_install_dependency.sh

       - name: Run test
-        timeout-minutes: 30
+        timeout-minutes: 40
         run: |
           bash scripts/amd_ci_exec.sh python3 run_suite.py --suite per-commit-8-gpu-amd
docker/Dockerfile.rocm

@@ -18,7 +18,7 @@ ARG TRITON_COMMIT="improve_fa_decode_3.0.0"

 ARG AITER_REPO="https://github.com/ROCm/aiter.git"
-ARG AITER_COMMIT="v0.1.1"
+ARG AITER_COMMIT="v0.1.2"

 RUN git clone ${SGL_REPO} \
     && cd sglang \
python/sglang/srt/layers/attention/aiter_backend.py (new file, mode 100644)

from __future__ import annotations

"""
end to end attention solution with aiter kernels
"""

import math
import os
from dataclasses import dataclass
from enum import Enum, auto
from functools import partial
from typing import TYPE_CHECKING, List, Optional, Union

import torch
import triton
import triton.language as tl

from sglang.global_config import global_config
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
from sglang.srt.layers.dp_attention import get_attention_tp_size
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode

if TYPE_CHECKING:
    from sglang.srt.layers.radix_attention import RadixAttention
    from sglang.srt.model_executor.model_runner import ModelRunner
    from sglang.srt.speculative.spec_info import SpecInfo

try:
    from aiter import mha_batch_prefill_func, paged_attention_ragged
except ImportError:
    print(
        "aiter is AMD specific kernel library. Please make sure aiter is installed on your AMD device."
    )


class WrapperDispatch(Enum):
    SLIDING_WINDOW = auto()
    CROSS_ATTENTION = auto()


@dataclass
class ForwardMetadata:
    kv_indptr: torch.Tensor
    kv_indices: torch.Tensor
    max_q_len: int
    max_kv_len: int


global_workspace_buffer = None

_AITER_PARTITION_SIZE_ROCM = 256

class AiterAttnBackend(AttentionBackend):
    def __init__(
        self,
        model_runner: ModelRunner,
        skip_prefill: bool = False,
        kv_indptr_buf: Optional[torch.Tensor] = None,
    ):
        super().__init__()

        self.device = model_runner.device
        self.is_multimodal = model_runner.model_config.is_multimodal
        self.num_head = (
            model_runner.model_config.num_attention_heads // get_attention_tp_size()
        )
        self.head_dim = model_runner.model_config.head_dim
        self.v_head_dim = model_runner.token_to_kv_pool.get_value_buffer(0).shape[-1]
        self.num_kv_head = model_runner.model_config.get_num_kv_heads(
            get_attention_tp_size()
        )
        self.kv_cache_dtype = model_runner.kv_cache_dtype
        self.req_to_token = model_runner.req_to_token_pool.req_to_token

        # Parse constants
        self.max_context_len = model_runner.model_config.context_len
        self.skip_prefill = skip_prefill

        max_bs = model_runner.req_to_token_pool.size

        if kv_indptr_buf is None:
            self.kv_indptr = torch.zeros(
                (max_bs + 1,), dtype=torch.int32, device=model_runner.device
            )
        else:
            self.kv_indptr = kv_indptr_buf

        self.kv_last_page_len = torch.ones(
            (max_bs,), dtype=torch.int32, device=model_runner.device
        )
        self.qo_indptr = torch.zeros(
            (max_bs + 1,), dtype=torch.int32, device=model_runner.device
        )

        # Create prefill indices updater
        if not skip_prefill:
            self.indices_updater_prefill = AiterIndicesUpdaterPrefill(
                model_runner, self
            )

        # aiter kernel related initialization
        self.max_num_partitions = (
            self.max_context_len + _AITER_PARTITION_SIZE_ROCM - 1
        ) // _AITER_PARTITION_SIZE_ROCM

        nbyes_per_qo_elem = torch.finfo(torch.float32).bits // 8
        self.workspace_buffer = torch.empty(
            (max_bs * self.num_head * self.max_num_partitions * self.head_dim)
            * nbyes_per_qo_elem
            + 2 * (max_bs * self.num_head * self.max_num_partitions) * 4,
            dtype=torch.uint8,
            device=self.device,
        )

        self.scale = float(1.0 / (self.head_dim**0.5))
        self.k_scale = self.v_scale = torch.tensor([1.0], dtype=torch.float32).to(
            self.device
        )
        self.kv_last_page_lens = torch.ones((max_bs,), dtype=torch.int32).to(self.device)
        self.logits_soft_cap = 0.0

        self.forward_metadata: ForwardMetadata = None

    def init_forward_metadata(self, forward_batch: ForwardBatch):
        if forward_batch.forward_mode.is_decode_or_idle():
            # update for aiter
            # create kv_indices and kv_inptr
            bs = forward_batch.batch_size
            kv_indptr = self.kv_indptr
            spec_info = forward_batch.spec_info
            if spec_info is None:
                kv_indptr[1 : bs + 1] = torch.cumsum(forward_batch.seq_lens, dim=0)
                kv_indptr = kv_indptr[: bs + 1]
                kv_indices = torch.zeros(
                    forward_batch.seq_lens_sum, dtype=torch.int32, device=self.device
                )
                create_flashinfer_kv_indices_triton[(bs,)](
                    self.req_to_token,
                    forward_batch.req_pool_indices,
                    forward_batch.seq_lens,
                    kv_indptr,
                    None,
                    kv_indices,
                    self.req_to_token.stride(0),
                )
            else:
                kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
                bs = kv_indptr.shape[0] - 1

            self.forward_metadata = ForwardMetadata(kv_indptr, kv_indices, None, None)
        elif forward_batch.forward_mode.is_draft_extend():
            self.indices_updater_prefill.update(
                forward_batch.req_pool_indices,
                forward_batch.seq_lens,
                forward_batch.seq_lens_sum,
                prefix_lens=None,
                encoder_lens=forward_batch.encoder_lens,
                spec_info=forward_batch.spec_info,
            )
            self.forward_metadata = ForwardMetadata(
                self.indices_updater_prefill.kv_indptr,
                self.indices_updater_prefill.kv_indices,
                self.indices_updater_prefill.max_q_len,
                self.indices_updater_prefill.max_kv_len,
            )
        elif forward_batch.forward_mode.is_target_verify():
            self.indices_updater_prefill.update(
                forward_batch.req_pool_indices,
                forward_batch.seq_lens,
                forward_batch.seq_lens_sum,
                prefix_lens=None,
                encoder_lens=forward_batch.encoder_lens,
                spec_info=forward_batch.spec_info,
            )
            self.forward_metadata = ForwardMetadata(
                self.indices_updater_prefill.kv_indptr,
                self.indices_updater_prefill.kv_indices,
                self.indices_updater_prefill.max_q_len,
                self.indices_updater_prefill.max_kv_len,
            )
        else:
            prefix_lens = forward_batch.extend_prefix_lens

            if self.is_multimodal:
                extend_no_prefix = False
            else:
                extend_no_prefix = not any(forward_batch.extend_prefix_lens_cpu)

            self.indices_updater_prefill.update(
                forward_batch.req_pool_indices,
                forward_batch.seq_lens,
                forward_batch.seq_lens_sum,
                prefix_lens,
                encoder_lens=forward_batch.encoder_lens,
                spec_info=None,
            )
            self.forward_metadata = ForwardMetadata(
                self.indices_updater_prefill.kv_indptr,
                self.indices_updater_prefill.kv_indices,
                self.indices_updater_prefill.max_q_len,
                self.indices_updater_prefill.max_kv_len,
            )

    def init_cuda_graph_state(
        self, max_bs: int, kv_indices_buf: Optional[torch.Tensor] = None
    ):
        if kv_indices_buf is None:
            self.cuda_graph_kv_indices = torch.zeros(
                (max_bs * self.max_context_len),
                dtype=torch.int32,
                device=self.device,
            )
        else:
            self.cuda_graph_kv_indices = kv_indices_buf

        if not self.skip_prefill:
            self.cuda_graph_custom_mask = torch.zeros(
                (max_bs * self.max_context_len),
                dtype=torch.uint8,
                device=self.device,
            )

    def init_forward_metadata_capture_cuda_graph(
        self,
        bs: int,
        num_tokens: int,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        encoder_lens: Optional[torch.Tensor],
        forward_mode: ForwardMode,
        spec_info: Optional[SpecInfo],
    ):
        if forward_mode.is_decode_or_idle():
            if spec_info is None:
                kv_indptr = self.kv_indptr
                kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0)
                kv_indptr = kv_indptr[: bs + 1]
                kv_indices = self.cuda_graph_kv_indices
                create_flashinfer_kv_indices_triton[(bs,)](
                    self.req_to_token,
                    req_pool_indices,
                    seq_lens,
                    kv_indptr,
                    None,
                    kv_indices,
                    self.req_to_token.stride(0),
                )
            else:
                kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices

            self.forward_metadata = ForwardMetadata(kv_indptr, kv_indices, None, None)
        elif forward_mode.is_target_verify():
            seq_lens_sum = seq_lens.sum().item()
            self.indices_updater_prefill.update(
                req_pool_indices,
                seq_lens,
                seq_lens_sum,
                prefix_lens=None,
                encoder_lens=encoder_lens,
                spec_info=spec_info,
            )
            self.forward_metadata = ForwardMetadata(
                self.indices_updater_prefill.kv_indptr,
                self.indices_updater_prefill.kv_indices,
                self.indices_updater_prefill.max_q_len,
                self.indices_updater_prefill.max_kv_len,
            )
        else:
            raise ValueError(f"Invalid mode: {forward_mode=}")

    def init_forward_metadata_replay_cuda_graph(
        self,
        bs: int,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        seq_lens_sum: int,
        encoder_lens: Optional[torch.Tensor],
        forward_mode: ForwardMode,
        spec_info: Optional[SpecInfo],
        seq_lens_cpu: Optional[torch.Tensor],
    ):
        if forward_mode.is_decode_or_idle():
            kv_indptr = self.kv_indptr
            kv_indices = self.cuda_graph_kv_indices
            if spec_info is None:
                kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens[:bs], dim=0)
                kv_indptr = kv_indptr[: bs + 1]
                create_flashinfer_kv_indices_triton[(bs,)](
                    self.req_to_token,
                    req_pool_indices[:bs],
                    seq_lens[:bs],
                    kv_indptr,
                    None,
                    kv_indices,
                    self.req_to_token.stride(0),
                )
            else:
                kv_indptr[: spec_info.kv_indptr.shape[0]] = spec_info.kv_indptr
                kv_indices[: spec_info.kv_indices.shape[0]] = spec_info.kv_indices
        elif forward_mode.is_target_verify():
            self.indices_updater_prefill.update(
                req_pool_indices[:bs],
                seq_lens[:bs],
                seq_lens_sum,
                prefix_lens=None,
                encoder_lens=encoder_lens[:bs] if encoder_lens is not None else None,
                spec_info=spec_info,
            )
        else:
            raise ValueError("Invalid forward mode")

    def get_cuda_graph_seq_len_fill_value(self):
        return 1

    def forward_extend(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        layer: RadixAttention,
        forward_batch: ForwardBatch,
        save_kv_cache=True,
    ):
        cache_loc = (
            forward_batch.out_cache_loc
            if not layer.is_cross_attention
            else forward_batch.encoder_out_cache_loc
        )

        self.logits_soft_cap = layer.logit_cap

        if k is not None:
            assert v is not None
            if save_kv_cache:
                forward_batch.token_to_kv_pool.set_kv_buffer(
                    layer, cache_loc, k, v, layer.k_scale, layer.v_scale
                )

        k_cache, v_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)

        bs0 = forward_batch.batch_size + 1
        o = mha_batch_prefill_func(
            q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
            k_cache,
            v_cache,
            self.qo_indptr[:bs0],
            self.forward_metadata.kv_indptr[:bs0],
            self.forward_metadata.kv_indices,
            self.forward_metadata.max_q_len,
            self.forward_metadata.max_kv_len,
            causal=True,
            logits_soft_cap=self.logits_soft_cap,
            alibi_slopes=None,
            return_lse=False,
            return_attn_probs=False,
        )

        return o.view(-1, layer.tp_q_head_num * layer.head_dim)

    def forward_decode(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        layer: RadixAttention,
        forward_batch: ForwardBatch,
        save_kv_cache=True,
    ):
        q = q.reshape(-1, layer.tp_q_head_num * layer.qk_head_dim)

        if layer.qk_head_dim != layer.v_head_dim:
            o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
        else:
            o = torch.empty_like(q)

        if save_kv_cache:
            forward_batch.token_to_kv_pool.set_kv_buffer(
                layer, forward_batch.out_cache_loc, k, v
            )

        self.logits_soft_cap = layer.logit_cap

        paged_attention_ragged(
            o.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
            self.workspace_buffer,
            q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
            forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id).view(
                -1, 1, layer.tp_k_head_num, layer.qk_head_dim
            ),
            forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id).view(
                -1, 1, layer.tp_v_head_num, layer.v_head_dim
            ),
            self.scale,
            self.forward_metadata.kv_indptr,
            self.forward_metadata.kv_indices,
            self.kv_last_page_lens,
            1,
            self.max_num_partitions,
            None,
            "auto",
            "NHD",
            self.logits_soft_cap,
            self.k_scale,
            self.v_scale,
            None,
            _AITER_PARTITION_SIZE_ROCM,
        )

        return o

class AiterIndicesUpdaterPrefill:
    def __init__(self, model_runner: ModelRunner, attn_backend: AttentionBackend):
        # Parse Constants
        self.num_qo_heads = (
            model_runner.model_config.num_attention_heads // get_attention_tp_size()
        )
        self.num_kv_heads = model_runner.model_config.get_num_kv_heads(
            get_attention_tp_size()
        )
        self.head_dim = model_runner.model_config.head_dim
        self.data_type = model_runner.kv_cache_dtype
        self.q_data_type = model_runner.dtype
        self.sliding_window_size = model_runner.sliding_window_size
        self.attn_backend = attn_backend

        # Buffers and wrappers
        self.kv_indptr = attn_backend.kv_indptr
        self.kv_last_page_len = attn_backend.kv_last_page_len
        self.qo_indptr = attn_backend.qo_indptr
        self.req_to_token = model_runner.req_to_token_pool.req_to_token
        self.update = self.update_single_wrapper

        self.kv_indices = None
        self.max_q_len = 0
        self.max_kv_len = 0

    def update(
        self,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        seq_lens_sum: int,
        prefix_lens: torch.Tensor,
        encoder_lens: Optional[torch.Tensor],
        spec_info: Optional[SpecInfo],
    ):
        # Keep the signature for type checking. It will be assigned during runtime.
        raise NotImplementedError()

    def update_single_wrapper(
        self,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        seq_lens_sum: int,
        prefix_lens: torch.Tensor,
        encoder_lens: Optional[torch.Tensor],
        spec_info: Optional[SpecInfo],
    ):
        kv_start_idx = None
        kv_indptr = self.kv_indptr
        qo_indptr = self.qo_indptr
        paged_kernel_lens = seq_lens
        paged_kernel_lens_sum = seq_lens_sum

        bs = len(req_pool_indices)
        if spec_info is None:
            # Normal extend
            kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
            kv_indptr = kv_indptr[: bs + 1]
            kv_indices = torch.empty(
                paged_kernel_lens_sum + 256,
                dtype=torch.int32,
                device=req_pool_indices.device,
            )
            create_flashinfer_kv_indices_triton[(bs,)](
                self.req_to_token,
                req_pool_indices,
                paged_kernel_lens,
                kv_indptr,
                kv_start_idx,
                kv_indices,
                self.req_to_token.shape[1],
            )

            self.max_kv_len = torch.max(paged_kernel_lens).item()

            extend_lens = seq_lens - prefix_lens
            self.max_q_len = torch.max(extend_lens).item()

            qo_indptr[1 : bs + 1] = torch.cumsum(extend_lens, dim=0)
            qo_indptr = qo_indptr[: bs + 1]
            custom_mask = None
        else:
            kv_indices, kv_indptr, qo_indptr, custom_mask = (
                spec_info.generate_attn_arg_prefill(
                    req_pool_indices,
                    paged_kernel_lens,
                    self.req_to_token,
                )
            )

        self.kv_indices = kv_indices
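
Both the decode and extend paths in this new backend hand the batch to the aiter kernels through the same CSR-style pair (kv_indptr, kv_indices): prefix sums of the per-request KV lengths plus a flat list of token slots gathered from req_to_token. The following plain-PyTorch sketch is for illustration only; build_kv_indices_reference is a hypothetical helper that mirrors what create_flashinfer_kv_indices_triton is assumed to produce, not the Triton kernel the backend actually launches.

import torch


def build_kv_indices_reference(
    req_to_token: torch.Tensor,      # [max_requests, max_context_len] slot table
    req_pool_indices: torch.Tensor,  # [bs] row of req_to_token used by each request
    seq_lens: torch.Tensor,          # [bs] current KV length of each request
):
    bs = req_pool_indices.numel()
    kv_indptr = torch.zeros(bs + 1, dtype=torch.int32)
    kv_indptr[1:] = torch.cumsum(seq_lens, dim=0)  # same prefix-sum step as the backend
    kv_indices = torch.empty(int(kv_indptr[-1]), dtype=torch.int32)
    for i in range(bs):
        row = req_pool_indices[i]
        # Gather the KV-cache slots this request occupies into one flat array.
        kv_indices[kv_indptr[i] : kv_indptr[i + 1]] = req_to_token[row, : seq_lens[i]]
    return kv_indptr, kv_indices


# Example: two requests with 3 and 2 cached tokens.
req_to_token = torch.arange(32, dtype=torch.int32).reshape(4, 8)
indptr, indices = build_kv_indices_reference(
    req_to_token,
    req_pool_indices=torch.tensor([2, 0]),
    seq_lens=torch.tensor([3, 2]),
)
print(indptr.tolist())   # [0, 3, 5]
print(indices.tolist())  # [16, 17, 18, 0, 1]

Keeping the batch in this flattened ragged form is what lets a single paged_attention_ragged or mha_batch_prefill_func call cover all requests without padding to a common length.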

python/sglang/srt/model_executor/model_runner.py

@@ -103,6 +103,8 @@ from sglang.srt.utils import (
     set_cuda_arch,
 )

+_is_hip = is_hip()
+
 # Use a small KV cache pool size for tests in CI
 SGLANG_CI_SMALL_KV_SIZE = os.getenv("SGLANG_CI_SMALL_KV_SIZE", None)

@@ -318,6 +320,8 @@ class ModelRunner:
                 and is_fa3_default_architecture(self.model_config.hf_config)
             ):
                 server_args.attention_backend = "fa3"
+            elif _is_hip:
+                server_args.attention_backend = "aiter"
             else:
                 server_args.attention_backend = (
                     "flashinfer" if is_flashinfer_available() else "triton"

@@ -794,7 +798,7 @@ class ModelRunner:
         if self.server_args.kv_cache_dtype == "auto":
             self.kv_cache_dtype = self.dtype
         elif self.server_args.kv_cache_dtype == "fp8_e5m2":
-            if is_hip():  # Using natively supported format
+            if _is_hip:  # Using natively supported format
                 self.kv_cache_dtype = torch.float8_e5m2fnuz
             else:
                 self.kv_cache_dtype = torch.float8_e5m2

@@ -972,6 +976,10 @@ class ModelRunner:
             )
             self.attn_backend = FlashInferMLAAttnBackend(self)
+        elif self.server_args.attention_backend == "aiter":
+            from sglang.srt.layers.attention.aiter_backend import AiterAttnBackend
+
+            self.attn_backend = AiterAttnBackend(self)
         elif self.server_args.attention_backend == "triton":
             assert self.sliding_window_size is None, (
                 "Window attention is not supported in the triton attention backend. "
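
The net effect of the model_runner.py hunks is that ROCm builds now select "aiter" automatically when no backend is specified, ahead of the existing flashinfer/triton fallback. A condensed, hypothetical sketch of that decision order follows; the function and its boolean arguments are stand-ins for the checks visible in the hunk (is_fa3_default_architecture, _is_hip, is_flashinfer_available), not the actual ModelRunner code.

def pick_default_attention_backend(
    is_fa3_arch: bool, is_hip_device: bool, has_flashinfer: bool
) -> str:
    # Order mirrors the elif chain in the hunk above.
    if is_fa3_arch:
        return "fa3"
    if is_hip_device:
        return "aiter"  # new in this commit: AMD/ROCm defaults to aiter
    return "flashinfer" if has_flashinfer else "triton"


# On a ROCm machine with no --attention-backend given:
assert pick_default_attention_backend(False, True, False) == "aiter"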

python/sglang/srt/server_args.py

@@ -957,6 +957,7 @@ class ServerArgs:
             "--attention-backend",
             type=str,
             choices=[
+                "aiter",
                 "flashinfer",
                 "triton",
                 "torch_native",
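
server_args.py only has to register the new value so that --attention-backend aiter is accepted on the command line. The minimal argparse sketch below mirrors just the choices visible in the hunk; the real ServerArgs parser defines many more options and backend names.

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--attention-backend",
    type=str,
    choices=["aiter", "flashinfer", "triton", "torch_native"],
    default=None,
)
args = parser.parse_args(["--attention-backend", "aiter"])
print(args.attention_backend)  # -> aiter

With the default change in model_runner.py, passing the flag explicitly matters mainly when overriding the auto-selection, for example to fall back to triton on ROCm.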

scripts/amd_ci_install_dependency.sh

@@ -5,7 +5,6 @@ set -euo pipefail
 docker exec ci_sglang pip install --upgrade pip
 docker exec ci_sglang pip uninstall sgl-kernel -y || true
 docker exec -w /sglang-checkout/sgl-kernel ci_sglang bash -c "rm -f pyproject.toml && mv pyproject_rocm.toml pyproject.toml && python3 setup_rocm.py install"
 docker exec ci_sglang pip install -e "python[dev_hip]"
 docker exec -w / ci_sglang git clone https://github.com/merrymercy/human-eval.git
 docker exec -w /human-eval ci_sglang pip install -e .

scripts/amd_ci_start_container.sh

@@ -9,7 +9,7 @@ else
 fi

 # Pull the image
-IMAGE="lmsysorg/sglang:v0.4.6.post3-rocm630"
+IMAGE="ghcr.io/saienduri/sglang-aiter-backend-v0.1.2:518"
 echo "Pulling Docker image: $IMAGE"
 docker pull "$IMAGE"

test/srt/models/test_dummy_grok_models.py

@@ -4,6 +4,7 @@ from sglang.test.test_utils import CustomTestCase, is_in_ci, run_bench_one_batch

 class TestDummyGrok1(CustomTestCase):
     def test_dummy_grok_1(self):
         output_throughput = run_bench_one_batch(
             None,

test/srt/test_eval_accuracy_large.py

@@ -3,6 +3,8 @@ Usage:
 python -m unittest test_eval_accuracy_large.TestEvalAccuracyLarge.test_mmlu
 """

+import os
+import time
 import unittest
 from types import SimpleNamespace

@@ -35,6 +37,11 @@ class TestEvalAccuracyLarge(CustomTestCase):
     def tearDownClass(cls):
         kill_process_tree(cls.process.pid)

+    def tearDown(self):
+        # Delay between tests to allow GPU memory cleanup
+        if os.getenv("SGLANG_AMD_CI") == "1":
+            time.sleep(180)
+
     def test_mmlu(self):
         args = SimpleNamespace(
             base_url=self.base_url,