sglang · Commit 0beea450 (unverified)
Authored Mar 07, 2025 by HAI; committed by GitHub on Mar 07, 2025

ROCm: Flex Attention Enablement with custom backends (#4178)

Co-authored-by: linsun12 <linsun12@amd.com>

Parent: c827c671
Showing 7 changed files with 1435 additions and 36 deletions (+1435 -36):

  docker/Dockerfile.rocm                                        +3    -2
  python/sglang/srt/layers/attention/aiter_backend.py           +605  -0
  python/sglang/srt/layers/attention/aiter_decode_backend.py    +535  -0
  python/sglang/srt/model_executor/model_runner.py              +59   -27
  python/sglang/srt/server_args.py                              +17   -7
  sgl-kernel/src/sgl-kernel/csrc/moe/moe_align_kernel.hip       +118  -0
  sgl-kernel/src/sgl-kernel/include/utils_hip.h                 +98   -0
docker/Dockerfile.rocm (view file @ 0beea450)

@@ -2,7 +2,7 @@
 # docker build --build-arg SGL_BRANCH=v0.4.3.post4 -t v0.4.3.post4-rocm630 -f Dockerfile.rocm .
 # default base image
-ARG BASE_IMAGE="rocm/sgl-dev:vllm20250114"
+ARG BASE_IMAGE="rocm/sgl-dev:20250114vllm-blas-flash"
 FROM $BASE_IMAGE AS base
 USER root

@@ -16,10 +16,10 @@ ARG SGL_BRANCH=${SGL_DEFAULT}
 ARG TRITON_REPO="https://github.com/ROCm/triton.git"
 ARG TRITON_COMMIT="improve_fa_decode_3.0.0"
 ARG AITER_REPO="https://github.com/ROCm/aiter.git"
 ARG AITER_COMMIT="testx"
 RUN git clone ${SGL_REPO} \
     && cd sglang \
     && if [ "${SGL_BRANCH}" = ${SGL_DEFAULT} ]; then \

@@ -59,6 +59,7 @@ RUN git clone ${AITER_REPO} \
     && git submodule update --init --recursive \
     && PREBUILD_KERNELS=1 GPU_ARCHS=gfx942 python3 setup.py develop
+# Copy config files to support MI300X in virtualized environments (MI300X_VF). Symlinks will not be created in image build.
 RUN find /sgl-workspace/sglang/python/sglang/srt/layers/quantization/configs/ \
     /sgl-workspace/sglang/python/sglang/srt/layers/moe/fused_moe_triton/configs/ \
python/sglang/srt/layers/attention/aiter_backend.py (new file, 0 → 100644, view @ 0beea450)

This diff is collapsed in the page view and not shown here.
python/sglang/srt/layers/attention/aiter_decode_backend.py (new file, 0 → 100644, view @ 0beea450)
from __future__ import annotations

from typing import TYPE_CHECKING, Optional

import torch
import triton
import triton.language as tl

from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from sglang.srt.layers.dp_attention import get_attention_tp_size
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode

if TYPE_CHECKING:
    from sglang.srt.layers.radix_attention import RadixAttention
    from sglang.srt.model_executor.model_runner import ModelRunner
    from sglang.srt.speculative.spec_info import SpecInfo

try:
    from aiter import paged_attention_rocm
except ImportError:
    print(
        "aiter is AMD specific kernel library. Please make sure aiter is installed on your AMD device."
    )

from sglang.srt.layers.attention.triton_ops.extend_attention import extend_attention_fwd

_AITER_PARTITION_SIZE_ROCM = 256
class AiterDecodeAttnBackend(AttentionBackend):
    def __init__(
        self,
        model_runner: ModelRunner,
        skip_prefill: bool = False,
        kv_indptr_buf: Optional[torch.Tensor] = None,
    ):
        super().__init__()

        self.decode_attention_fwd = paged_attention_rocm
        self.extend_attention_fwd = extend_attention_fwd
        self.skip_prefill = skip_prefill

        max_bs = model_runner.req_to_token_pool.size

        if kv_indptr_buf is None:
            self.kv_indptr = torch.zeros(
                (max_bs + 1,), dtype=torch.int32, device=model_runner.device
            )
        else:
            self.kv_indptr = kv_indptr_buf

        self.req_to_token = model_runner.req_to_token_pool.req_to_token

        if not self.skip_prefill:
            self.qo_indptr = torch.zeros(
                (max_bs + 1,), dtype=torch.int32, device=model_runner.device
            )
            self.mask_indptr = torch.zeros(
                (max_bs + 1,), dtype=torch.int64, device=model_runner.device
            )

        self.num_draft_tokens = model_runner.server_args.speculative_num_draft_tokens

        # tp sharding on number of heads
        self.num_head = (
            model_runner.model_config.num_attention_heads // get_attention_tp_size()
        )
        self.head_dim = model_runner.model_config.head_dim

        # triton prefill initialization
        self.num_kv_splits = model_runner.server_args.triton_attention_num_kv_splits
        self.v_head_dim = model_runner.token_to_kv_pool.get_value_buffer(0).shape[-1]
        self.num_v_head = model_runner.token_to_kv_pool.get_value_buffer(0).shape[-2]

        self.forward_metadata = None
        self.max_context_len = model_runner.model_config.context_len
        self.device = model_runner.device
        self.kv_cache_dtype = model_runner.kv_cache_dtype
        self.q_dtype = model_runner.model_config.dtype

        # aiter decode initialization
        self.max_num_partitions = (
            self.max_context_len + _AITER_PARTITION_SIZE_ROCM - 1
        ) // _AITER_PARTITION_SIZE_ROCM

        nbyes_per_qo_elem = torch.finfo(torch.float32).bits // 8
        self.workspace_buffer = torch.empty(
            (max_bs * self.num_head * self.max_num_partitions * self.head_dim)
            * nbyes_per_qo_elem
            + 2 * (max_bs * self.num_head * self.max_num_partitions) * 4,
            dtype=torch.uint8,
            device=self.device,
        )

        self.scale = float(1.0 / (self.head_dim**0.5))
        self.k_scale = self.v_scale = torch.tensor([1.0], dtype=torch.float32).to(
            self.device
        )
        self.kv_last_page_lens = torch.ones((max_bs,), dtype=torch.int32).to(
            self.device
        )
    def init_forward_metadata(self, forward_batch: ForwardBatch):
        """Init auxiliary variables"""
        bs = forward_batch.batch_size
        kv_indptr = self.kv_indptr
        spec_info = forward_batch.spec_info

        if forward_batch.forward_mode.is_decode_or_idle():
            if spec_info is None:
                kv_indptr[1 : bs + 1] = torch.cumsum(forward_batch.seq_lens, dim=0)
                kv_indptr = kv_indptr[: bs + 1]
                kv_indices = torch.zeros(
                    forward_batch.seq_lens_sum, dtype=torch.int32, device=self.device
                )
                # prepare kv_indices and kv_indptr
                create_flashinfer_kv_indices_triton[(bs,)](
                    self.req_to_token,
                    forward_batch.req_pool_indices,
                    forward_batch.seq_lens,
                    kv_indptr,
                    None,
                    kv_indices,
                    self.req_to_token.stride(0),
                )
            else:
                kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
                bs = kv_indptr.shape[0] - 1

            attn_logits = None
            # accomodate forward_metadata format
            qo_indptr = None
            custom_mask = None
            mask_indptr = None
            max_extend_len = None
        elif forward_batch.forward_mode.is_target_verify():
            bs = len(forward_batch.req_pool_indices)
            qo_indptr = torch.arange(
                0,
                (1 + bs) * self.num_draft_tokens,
                step=self.num_draft_tokens,
                dtype=torch.int32,
                device=self.device,
            )
            kv_indptr[1 : bs + 1] = torch.cumsum(forward_batch.seq_lens, dim=0)
            kv_indptr = kv_indptr[: bs + 1]
            kv_indices = torch.zeros(
                kv_indptr[-1], dtype=torch.int32, device=self.device
            )
            create_flashinfer_kv_indices_triton[(bs,)](
                self.req_to_token,
                forward_batch.req_pool_indices,
                forward_batch.seq_lens,
                kv_indptr,
                None,
                kv_indices,
                self.req_to_token.stride(0),
            )

            custom_mask = spec_info.custom_mask
            seq_mask_len = self.num_draft_tokens * (
                forward_batch.seq_lens + self.num_draft_tokens
            )
            mask_indptr = self.mask_indptr
            mask_indptr[1 : bs + 1] = torch.cumsum(seq_mask_len[:bs], dim=0)
            mask_indptr = mask_indptr[: bs + 1]
            max_extend_len = self.num_draft_tokens
            attn_logits = None
        elif forward_batch.forward_mode.is_draft_extend():
            kv_indices, kv_indptr, qo_indptr, custom_mask = (
                spec_info.generate_attn_arg_prefill(
                    forward_batch.req_pool_indices,
                    forward_batch.seq_lens,
                    self.req_to_token,
                )
            )
            mask_indptr = None
            max_extend_len = torch.max(spec_info.accept_length).item()
            attn_logits = None
        else:
            kv_indptr[1 : bs + 1] = torch.cumsum(
                forward_batch.extend_prefix_lens, dim=0
            )
            kv_indptr = kv_indptr[: bs + 1]
            kv_indices = torch.zeros(
                forward_batch.extend_prefix_lens.sum().item(),
                dtype=torch.int32,
                device=self.device,
            )
            create_flashinfer_kv_indices_triton[(bs,)](
                self.req_to_token,
                forward_batch.req_pool_indices,
                forward_batch.extend_prefix_lens,
                kv_indptr,
                None,
                kv_indices,
                self.req_to_token.stride(0),
            )

            qo_indptr = self.qo_indptr
            qo_indptr[1 : bs + 1] = torch.cumsum(forward_batch.extend_seq_lens, dim=0)
            qo_indptr = qo_indptr[: bs + 1]
            custom_mask = None
            mask_indptr = None
            attn_logits = None
            max_extend_len = torch.max(forward_batch.extend_seq_lens).item()

        self.forward_metadata = (
            attn_logits,
            max_extend_len,
            kv_indptr,
            kv_indices,
            qo_indptr,
            custom_mask,
            mask_indptr,
        )
    def init_cuda_graph_state(
        self, max_bs: int, kv_indices_buf: Optional[torch.Tensor] = None
    ):
        self.cuda_graph_attn_logits = torch.zeros(
            (max_bs, self.num_head, self.num_kv_splits, self.v_head_dim + 1),
            dtype=torch.float32,
            device=self.device,
        )
        if kv_indices_buf is None:
            self.cuda_graph_kv_indices = torch.zeros(
                (max_bs * self.max_context_len),
                dtype=torch.int32,
                device=self.device,
            )
        else:
            self.cuda_graph_kv_indices = kv_indices_buf

        if not self.skip_prefill:
            self.cuda_graph_custom_mask = torch.zeros(
                (max_bs * self.max_context_len),
                dtype=torch.uint8,
                device=self.device,
            )
    def init_forward_metadata_capture_cuda_graph(
        self,
        bs: int,
        num_tokens: int,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        encoder_lens: Optional[torch.Tensor],
        forward_mode: ForwardMode,
        spec_info: Optional[SpecInfo],
    ):
        assert encoder_lens is None, "Not supported"

        if forward_mode.is_decode_or_idle():
            if spec_info is None:
                kv_indptr = self.kv_indptr
                kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0)
                kv_indptr = kv_indptr[: bs + 1]
                kv_indices = self.cuda_graph_kv_indices
                create_flashinfer_kv_indices_triton[(bs,)](
                    self.req_to_token,
                    req_pool_indices,
                    seq_lens,
                    kv_indptr,
                    None,
                    kv_indices,
                    self.req_to_token.stride(0),
                )
            else:
                kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices

            attn_logits = None
            max_extend_len = None
            qo_indptr = None
            custom_mask = None
            mask_indptr = None
        elif forward_mode.is_target_verify():
            qo_indptr = self.qo_indptr[: bs + 1]
            qo_indptr[: bs + 1] = torch.arange(
                0,
                (1 + bs) * self.num_draft_tokens,
                step=self.num_draft_tokens,
                dtype=torch.int32,
                device=self.device,
            )
            kv_indptr = self.kv_indptr[: bs + 1]
            kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0)
            kv_indices = self.cuda_graph_kv_indices
            create_flashinfer_kv_indices_triton[(bs,)](
                self.req_to_token,
                req_pool_indices,
                seq_lens,
                kv_indptr,
                None,
                kv_indices,
                self.req_to_token.stride(0),
            )

            custom_mask = self.cuda_graph_custom_mask
            seq_mask_len = self.num_draft_tokens * (seq_lens + self.num_draft_tokens)
            mask_indptr = self.mask_indptr[: bs + 1]
            mask_indptr[1 : bs + 1] = torch.cumsum(seq_mask_len, dim=0)
            max_extend_len = self.num_draft_tokens
            attn_logits = None
        else:
            raise ValueError(
                f"Invalid forward mode: {forward_mode=} for CUDA Graph capture."
            )

        self.forward_metadata = (
            attn_logits,
            max_extend_len,
            kv_indptr,
            kv_indices,
            qo_indptr,
            custom_mask,
            mask_indptr,
        )
    def init_forward_metadata_replay_cuda_graph(
        self,
        bs: int,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        seq_lens_sum: int,
        encoder_lens: Optional[torch.Tensor],
        forward_mode: ForwardMode,
        spec_info: Optional[SpecInfo],
        seq_lens_cpu: Optional[torch.Tensor],
    ):
        # NOTE: encoder_lens expected to be zeros or None
        if forward_mode.is_decode_or_idle():
            # Update kv_indptr, kv_indices
            kv_indptr = self.kv_indptr
            kv_indices = self.cuda_graph_kv_indices
            if spec_info is None:
                kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens[:bs], dim=0)
                kv_indptr = kv_indptr[: bs + 1]
                create_flashinfer_kv_indices_triton[(bs,)](
                    self.req_to_token,
                    req_pool_indices[:bs],
                    seq_lens[:bs],
                    kv_indptr,
                    None,
                    kv_indices,
                    self.req_to_token.stride(0),
                )
            else:
                kv_indptr[: spec_info.kv_indptr.shape[0]] = spec_info.kv_indptr
                kv_indices[: spec_info.kv_indices.shape[0]] = spec_info.kv_indices
        elif forward_mode.is_target_verify():
            # Update qo_indptr, kv_indptr, kv_indices, custom_mask, mask_indptr
            bs = len(req_pool_indices)
            qo_indptr = self.qo_indptr[: bs + 1]
            qo_indptr[: bs + 1] = torch.arange(
                0,
                (1 + bs) * self.num_draft_tokens,
                step=self.num_draft_tokens,
                dtype=torch.int32,
                device=self.device,
            )
            kv_indptr = self.kv_indptr[: bs + 1]
            kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0)
            kv_indices = self.cuda_graph_kv_indices
            create_flashinfer_kv_indices_triton[(bs,)](
                self.req_to_token,
                req_pool_indices,
                seq_lens,
                kv_indptr,
                None,
                kv_indices,
                self.req_to_token.stride(0),
            )
            custom_mask = self.cuda_graph_custom_mask
            custom_mask[: spec_info.custom_mask.shape[0]] = spec_info.custom_mask
            seq_mask_len = self.num_draft_tokens * (seq_lens + self.num_draft_tokens)
            mask_indptr = self.mask_indptr[: bs + 1]
            mask_indptr[1 : bs + 1] = torch.cumsum(seq_mask_len, dim=0)
        else:
            raise ValueError(
                f"Invalid forward mode: {forward_mode=} for CUDA Graph replay."
            )
    def get_cuda_graph_seq_len_fill_value(self):
        return 1
    def forward_extend(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        layer: RadixAttention,
        forward_batch: ForwardBatch,
        save_kv_cache=True,
    ):
        # TODO: reuse the buffer across layers
        if layer.qk_head_dim != layer.v_head_dim:
            o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
        else:
            o = torch.empty_like(q)

        if save_kv_cache:
            forward_batch.token_to_kv_pool.set_kv_buffer(
                layer, forward_batch.out_cache_loc, k, v
            )

        (
            _,
            max_extend_len,
            kv_indptr,
            kv_indices,
            qo_indptr,
            custom_mask,
            mask_indptr,
        ) = self.forward_metadata
        self.extend_attention_fwd(
            q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
            k.contiguous(),
            v.contiguous(),
            o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
            forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
            forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
            qo_indptr,
            kv_indptr,
            kv_indices,
            custom_mask,
            mask_indptr,
            max_extend_len,
            layer.scaling,
            layer.logit_cap,
        )
        return o
    def forward_decode(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        layer: RadixAttention,
        forward_batch: ForwardBatch,
        save_kv_cache=True,
    ):
        # During torch.compile, there is a bug in rotary_emb that causes the
        # output value to have a 3D tensor shape. This reshapes the output correctly.
        q = q.reshape(-1, layer.tp_q_head_num * layer.qk_head_dim)

        if layer.qk_head_dim != layer.v_head_dim:
            o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
        else:
            o = torch.empty_like(q)

        attn_logits, _, kv_indptr, kv_indices, _, _, _ = self.forward_metadata

        if save_kv_cache:
            forward_batch.token_to_kv_pool.set_kv_buffer(
                layer, forward_batch.out_cache_loc, k, v
            )

        self.decode_attention_fwd(
            o.view(-1, layer.tp_q_head_num, layer.qk_head_dim),  # (bs, head_num_q, head_dim_q)
            self.workspace_buffer,
            q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
            forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id).view(
                -1, 1, layer.tp_k_head_num, layer.qk_head_dim
            ),
            forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id).view(
                -1, 1, layer.tp_v_head_num, layer.v_head_dim
            ),
            self.scale,
            kv_indptr,
            kv_indices,
            self.kv_last_page_lens,
            1,
            self.max_num_partitions,
            None,
            "auto",
            "NHD",
            layer.logit_cap,
            self.k_scale,
            self.v_scale,
            None,
            _AITER_PARTITION_SIZE_ROCM,
        )
        return o
@triton.jit
def create_flashinfer_kv_indices_triton(
    req_to_token_ptr,  # [max_batch, max_context_len]
    req_pool_indices_ptr,
    page_kernel_lens_ptr,
    kv_indptr,
    kv_start_idx,
    kv_indices_ptr,
    req_to_token_ptr_stride: tl.constexpr,
):
    BLOCK_SIZE: tl.constexpr = 512
    pid = tl.program_id(axis=0)

    req_pool_index = tl.load(req_pool_indices_ptr + pid)
    kv_indices_offset = tl.load(kv_indptr + pid)

    kv_start = 0
    kv_end = 0
    if kv_start_idx:
        kv_start = tl.load(kv_start_idx + pid).to(tl.int32)
        kv_end = kv_start
    kv_end += tl.load(page_kernel_lens_ptr + pid).to(tl.int32)

    num_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE)
    for i in range(num_loop):
        offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
        mask = offset < kv_end - kv_start
        data = tl.load(
            req_to_token_ptr
            + req_pool_index * req_to_token_ptr_stride
            + kv_start
            + offset,
            mask=mask,
        )
        tl.store(kv_indices_ptr + kv_indices_offset + offset, data, mask=mask)
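
Note: the constructor above sizes a uint8 workspace for the aiter paged-attention decode path from the 256-token partition size. A minimal standalone sketch of that arithmetic follows; the concrete batch/head/context sizes are illustrative assumptions, not values from this commit, and the reading of the two terms (fp32 partial outputs plus two fp32 per-partition reduction buffers) is simply how the expression in __init__ is structured.

# Standalone sketch of the workspace sizing in AiterDecodeAttnBackend.__init__.
# All concrete numbers below are illustrative assumptions.
_AITER_PARTITION_SIZE_ROCM = 256

max_bs, num_head, head_dim, max_context_len = 128, 32, 128, 8192

# ceil(max_context_len / partition_size): partitions a single request can span
max_num_partitions = (
    max_context_len + _AITER_PARTITION_SIZE_ROCM - 1
) // _AITER_PARTITION_SIZE_ROCM  # -> 32

nbytes_per_qo_elem = 4  # torch.finfo(torch.float32).bits // 8
workspace_bytes = (
    max_bs * num_head * max_num_partitions * head_dim * nbytes_per_qo_elem
    + 2 * (max_bs * num_head * max_num_partitions) * 4
)
print(workspace_bytes)  # 68157440 bytes, about 65 MiB for these assumed sizes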
python/sglang/srt/model_executor/model_runner.py (view file @ 0beea450)

@@ -79,6 +79,12 @@ from sglang.srt.utils import (
 )
 from sglang.utils import get_exception_traceback

+is_hip_ = is_hip()
+
+if is_hip_:
+    from sglang.srt.layers.attention.aiter_backend import AiterAttnBackend
+    from sglang.srt.layers.attention.aiter_decode_backend import AiterDecodeAttnBackend
+
 logger = logging.getLogger(__name__)

@@ -641,7 +647,7 @@ class ModelRunner:
         if self.server_args.kv_cache_dtype == "auto":
             self.kv_cache_dtype = self.dtype
         elif self.server_args.kv_cache_dtype == "fp8_e5m2":
-            if is_hip():
+            if is_hip_:
                 # Using natively supported format
                 self.kv_cache_dtype = torch.float8_e5m2fnuz
             else:
                 self.kv_cache_dtype = torch.float8_e5m2
@@ -778,33 +784,59 @@ class ModelRunner:
     def init_attention_backend(self):
         """Init attention kernel backend."""
-        if self.server_args.attention_backend == "flashinfer":
-            # Init streams
-            if self.server_args.speculative_algorithm == "EAGLE":
-                self.plan_stream_for_flashinfer = torch.cuda.Stream()
-            self.attn_backend = FlashInferAttnBackend(self)
-        elif self.server_args.attention_backend == "triton":
-            assert self.sliding_window_size is None, (
-                "Window attention is not supported in the triton attention backend. "
-                "Please use `--attention-backend flashinfer`."
-            )
-            assert not self.model_config.is_encoder_decoder, (
-                "Cross attention is not supported in the triton attention backend. "
-                "Please use `--attention-backend flashinfer`."
-            )
-            if self.server_args.enable_double_sparsity:
-                self.attn_backend = DoubleSparseAttnBackend(self)
-            else:
-                self.attn_backend = TritonAttnBackend(self)
-        elif self.server_args.attention_backend == "torch_native":
-            self.attn_backend = TorchNativeAttnBackend(self)
-        elif self.server_args.attention_backend == "flashinfer_mla":
-            self.attn_backend = FlashInferMLAAttnBackend(self)
-        else:
-            raise ValueError(
-                f"Invalid attention backend: {self.server_args.attention_backend}"
-            )
+        if is_cuda():
+            if self.server_args.attention_backend == "flashinfer":
+                # Init streams
+                if self.server_args.speculative_algorithm == "EAGLE":
+                    self.plan_stream_for_flashinfer = torch.cuda.Stream()
+                self.attn_backend = FlashInferAttnBackend(self)
+            elif self.server_args.attention_backend == "triton":
+                assert self.sliding_window_size is None, (
+                    "Window attention is not supported in the triton attention backend. "
+                    "Please use `--attention-backend flashinfer`."
+                )
+                assert not self.model_config.is_encoder_decoder, (
+                    "Cross attention is not supported in the triton attention backend. "
+                    "Please use `--attention-backend flashinfer`."
+                )
+                if self.server_args.enable_double_sparsity:
+                    self.attn_backend = DoubleSparseAttnBackend(self)
+                else:
+                    self.attn_backend = TritonAttnBackend(self)
+            elif self.server_args.attention_backend == "torch_native":
+                self.attn_backend = TorchNativeAttnBackend(self)
+            elif self.server_args.attention_backend == "flashinfer_mla":
+                self.attn_backend = FlashInferMLAAttnBackend(self)
+            else:
+                raise ValueError(
+                    f"Invalid attention backend: {self.server_args.attention_backend}"
+                )
+        elif is_hip_:
+            # AMD hip supported attention backends
+            if self.server_args.attention_backend == "aiter":
+                self.attn_backend = AiterAttnBackend(self)
+            elif self.server_args.attention_backend == "aiter_decode":
+                self.attn_backend = AiterDecodeAttnBackend(self)
+            elif self.server_args.attention_backend == "triton":
+                assert self.sliding_window_size is None, (
+                    "Window attention is not supported in the triton attention backend. "
+                    "Please use `--attention-backend flashinfer`."
+                )
+                assert not self.model_config.is_encoder_decoder, (
+                    "Cross attention is not supported in the triton attention backend. "
+                    "Please use `--attention-backend flashinfer`."
+                )
+                if self.server_args.enable_double_sparsity:
+                    self.attn_backend = DoubleSparseAttnBackend(self)
+                else:
+                    self.attn_backend = TritonAttnBackend(self)
+            elif self.server_args.attention_backend == "torch_native":
+                self.attn_backend = TorchNativeAttnBackend(self)
+            else:
+                raise ValueError(
+                    f"Invalid attention backend: {self.server_args.attention_backend}"
+                )

     def init_double_sparsity_channel_config(self, selected_channel):
         selected_channel = "." + selected_channel + "_proj"
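
Note: after this change the valid platform/backend pairs are: CUDA keeps flashinfer, triton, torch_native, and flashinfer_mla, while HIP gains aiter and aiter_decode alongside triton and torch_native. The helper below is an illustrative restatement of the dispatch in init_attention_backend, not code from the commit.

# Illustrative summary of the branches in ModelRunner.init_attention_backend above.
SUPPORTED_ATTENTION_BACKENDS = {
    "cuda": {"flashinfer", "triton", "torch_native", "flashinfer_mla"},
    "hip": {"aiter", "aiter_decode", "triton", "torch_native"},
}


def check_attention_backend(platform: str, backend: str) -> None:
    # Mirrors the ValueError raised by the real dispatch for unsupported pairs.
    if backend not in SUPPORTED_ATTENTION_BACKENDS.get(platform, set()):
        raise ValueError(f"Invalid attention backend: {backend}")


check_attention_backend("hip", "aiter_decode")  # ok
# check_attention_backend("hip", "flashinfer")  # would raise ValueError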
python/sglang/srt/server_args.py (view file @ 0beea450)

@@ -710,13 +710,23 @@ class ServerArgs:
         )

         # Kernel backend
-        parser.add_argument(
-            "--attention-backend",
-            type=str,
-            choices=["flashinfer", "triton", "torch_native"],
-            default=ServerArgs.attention_backend,
-            help="Choose the kernels for attention layers.",
-        )
+        if is_hip():
+            parser.add_argument(
+                "--attention-backend",
+                type=str,
+                choices=["triton", "torch_native", "aiter", "aiter_decode"],
+                default=ServerArgs.attention_backend,
+                help="Choose the kernels for attention layers.",
+            )
+        else:
+            parser.add_argument(
+                "--attention-backend",
+                type=str,
+                choices=["flashinfer", "triton", "torch_native"],
+                default=ServerArgs.attention_backend,
+                help="Choose the kernels for attention layers.",
+            )
         parser.add_argument(
             "--sampling-backend",
             type=str,
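
Note: the hunk above registers a different set of --attention-backend choices depending on the platform. A self-contained sketch of the same pattern is shown below; the is_hip stub and the None default are assumptions for illustration (the real code uses sglang.srt.utils.is_hip and ServerArgs.attention_backend).

# Minimal sketch of platform-dependent argparse registration, mirroring the hunk above.
import argparse


def is_hip() -> bool:
    # Stub standing in for sglang.srt.utils.is_hip; assume an NVIDIA host here.
    return False


parser = argparse.ArgumentParser()
if is_hip():
    choices = ["triton", "torch_native", "aiter", "aiter_decode"]
else:
    choices = ["flashinfer", "triton", "torch_native"]
parser.add_argument(
    "--attention-backend",
    type=str,
    choices=choices,
    default=None,  # the real code uses ServerArgs.attention_backend
    help="Choose the kernels for attention layers.",
)

args = parser.parse_args(["--attention-backend", "triton"])
print(args.attention_backend)  # triton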
sgl-kernel/src/sgl-kernel/csrc/moe/moe_align_kernel.hip (new file, 0 → 100644, view @ 0beea450)
// !!! This is a file automatically generated by hipify!!!
#include "hip/hip_runtime.h"
/* Copyright 2025 SGLang Team. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <ATen/ATen.h>
#include <ATen/hip/HIPContext.h>
#include <ATen/hip/impl/HIPGuardImplMasqueradingAsCUDA.h>
#include <torch/extension.h>
#include <THH/THHAtomics.cuh>
#include "utils_hip.h"
#define WARP_SIZE 32
template <typename scalar_t>
__global__ void count_and_sort_expert_tokens_kernel(const scalar_t* __restrict__ topk_ids,
                                                    int32_t* __restrict__ sorted_token_ids,
                                                    int32_t* __restrict__ cumsum_buffer, size_t numel) {
  const size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
  const size_t stride = blockDim.x * gridDim.x;

  for (size_t i = tid; i < numel; i += stride) {
    int32_t expert_id = topk_ids[i];
    int32_t rank_post_pad = atomicAdd(&cumsum_buffer[expert_id], 1);
    sorted_token_ids[rank_post_pad] = i;
  }
}

template <typename scalar_t>
__global__ void moe_align_block_size_kernel(const scalar_t* __restrict__ topk_ids,
                                            int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ expert_ids,
                                            int32_t* __restrict__ total_tokens_post_pad, int32_t num_experts,
                                            int32_t block_size, size_t numel, int32_t* __restrict__ cumsum) {
  __shared__ int32_t shared_counts[WARP_SIZE][8];

  const int warp_id = threadIdx.x / WARP_SIZE;
  const int experts_per_warp = 8;
  const int my_expert_start = warp_id * experts_per_warp;

  for (int i = 0; i < experts_per_warp; ++i) {
    if (my_expert_start + i < num_experts) {
      shared_counts[warp_id][i] = 0;
    }
  }

  __syncthreads();

  const size_t tokens_per_thread = CEILDIV(numel, blockDim.x);
  const size_t start_idx = threadIdx.x * tokens_per_thread;

  for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
    int expert_id = topk_ids[i];
    int warp_idx = expert_id / experts_per_warp;
    int expert_offset = expert_id % experts_per_warp;
    atomicAdd(&shared_counts[warp_idx][expert_offset], 1);
  }

  __syncthreads();

  if (threadIdx.x == 0) {
    cumsum[0] = 0;
    for (int i = 1; i <= num_experts; ++i) {
      int expert_count = 0;
      int warp_idx = (i - 1) / experts_per_warp;
      int expert_offset = (i - 1) % experts_per_warp;
      expert_count = shared_counts[warp_idx][expert_offset];

      cumsum[i] = cumsum[i - 1] + CEILDIV(expert_count, block_size) * block_size;
    }
    *total_tokens_post_pad = cumsum[num_experts];
  }

  __syncthreads();

  if (threadIdx.x < num_experts) {
    for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1]; i += block_size) {
      expert_ids[i / block_size] = threadIdx.x;
    }
  }
}

void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, int64_t block_size,
                          torch::Tensor sorted_token_ids, torch::Tensor experts_ids, torch::Tensor num_tokens_post_pad,
                          torch::Tensor token_cnts_buffer, torch::Tensor cumsum_buffer) {
  const hipStream_t stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA();

  TORCH_CHECK(num_experts == 256, "moe_align_block_size kernel only support deepseek v3 now.");

  DISPATCH_INTEGRAL_TYPES(topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
    auto align_kernel = moe_align_block_size_kernel<scalar_t>;
    hipLaunchKernelGGL((align_kernel), dim3(1), dim3(1024), 0, stream,
                       topk_ids.data_ptr<scalar_t>(), sorted_token_ids.data_ptr<int32_t>(),
                       experts_ids.data_ptr<int32_t>(), num_tokens_post_pad.data_ptr<int32_t>(),
                       num_experts, block_size, topk_ids.numel(), cumsum_buffer.data_ptr<int32_t>());

    const int block_threads = 256;
    const int num_blocks = (topk_ids.numel() + block_threads - 1) / block_threads;
    const int max_blocks = 65535;
    const int actual_blocks = ::min(num_blocks, max_blocks);

    auto sort_kernel = count_and_sort_expert_tokens_kernel<scalar_t>;
    hipLaunchKernelGGL((sort_kernel), dim3(actual_blocks), dim3(block_threads), 0, stream,
                       topk_ids.data_ptr<scalar_t>(), sorted_token_ids.data_ptr<int32_t>(),
                       cumsum_buffer.data_ptr<int32_t>(), topk_ids.numel());
  });
}
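
Note: the heart of moe_align_block_size_kernel is a padded cumulative sum: each expert's token count is rounded up to a multiple of block_size, so every expert's segment of sorted_token_ids starts on a block boundary. The pure-Python reference below restates that arithmetic; the example counts and block size are illustrative, not test values from the commit.

# Pure-Python reference for the padded cumsum computed by moe_align_block_size_kernel.
def ceildiv(x: int, y: int) -> int:
    return (x + y - 1) // y  # mirrors the CEILDIV macro in utils_hip.h


def padded_cumsum(expert_counts: list[int], block_size: int) -> list[int]:
    cumsum = [0]
    for count in expert_counts:
        cumsum.append(cumsum[-1] + ceildiv(count, block_size) * block_size)
    return cumsum


counts = [5, 0, 13, 2]           # tokens routed to each of 4 experts
print(padded_cumsum(counts, 8))  # [0, 8, 8, 24, 32]; total_tokens_post_pad = 32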
sgl-kernel/src/sgl-kernel/include/utils_hip.h (new file, 0 → 100644, view @ 0beea450)
// !!! This is a file automatically generated by hipify!!!
/* Copyright 2025 SGLang Team. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#pragma once
#include <hip/hip_runtime.h>
#ifndef USE_ROCM
#include <pytorch_extension_utils.h>
#endif
#include <torch/extension.h>
#include <sstream>
struct cuda_error : public std::runtime_error {
  /**
   * @brief Constructs a `cuda_error` object with the given `message`.
   *
   * @param message The error char array used to construct `cuda_error`
   */
  cuda_error(const char* message) : std::runtime_error(message) {}
  /**
   * @brief Constructs a `cuda_error` object with the given `message` string.
   *
   * @param message The `std::string` used to construct `cuda_error`
   */
  cuda_error(std::string const& message) : cuda_error{message.c_str()} {}
};
#define CHECK_CUDA_SUCCESS(cmd) \
do { \
hipError_t e = cmd; \
if (e != hipSuccess) { \
std::stringstream _message; \
auto s = hipGetErrorString(e); \
_message << std::string(s) + "\n" << __FILE__ << ':' << __LINE__; \
throw cuda_error(_message.str()); \
} \
} while (0)
#define CHECK_IS_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_IS_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_CUDA_INPUT(x) \
CHECK_IS_CUDA(x); \
CHECK_IS_CONTIGUOUS(x)
inline int getSMVersion() {
  int device{-1};
  CHECK_CUDA_SUCCESS(hipGetDevice(&device));
  int sm_major = 0;
  int sm_minor = 0;
  CHECK_CUDA_SUCCESS(hipDeviceGetAttribute(&sm_major, hipDeviceAttributeComputeCapabilityMajor, device));
  CHECK_CUDA_SUCCESS(hipDeviceGetAttribute(&sm_minor, hipDeviceAttributeComputeCapabilityMinor, device));
  return sm_major * 10 + sm_minor;
}
#ifndef USE_ROCM
#define DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(pytorch_dtype, c_type, ...) \
[&]() -> bool { \
switch (pytorch_dtype) { \
case at::ScalarType::Float: { \
using c_type = float; \
return __VA_ARGS__(); \
} \
_DISPATCH_CASE_F16(c_type, __VA_ARGS__) \
_DISPATCH_CASE_BF16(c_type, __VA_ARGS__) \
default: \
std::ostringstream oss; \
oss << __PRETTY_FUNCTION__ << " failed to dispatch data type " << pytorch_dtype; \
TORCH_CHECK(false, oss.str()); \
return false; \
} \
}()
#endif
#define DISPATCH_CASE_INTEGRAL_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__)
#define DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH(TYPE, NAME, DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__))
#define CEILDIV(x, y) (((x) + (y)-1) / (y))