change / sglang / Commits / 6a7528e6

Commit 6a7528e6 (unverified), authored Aug 01, 2025 by Trevor Morris, committed via GitHub on Aug 01, 2025

[bugfix] Fix page size for create_flashmla_kv_indices_triton() for cutlass mla (#8685)
Parent: 2ae95d17

Showing 2 changed files with 9 additions and 9 deletions (+9, -9):
  python/sglang/srt/layers/attention/cutlass_mla_backend.py  +3 -3
  python/sglang/srt/layers/attention/trtllm_mla_backend.py   +6 -6
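For context on the fix: every call into create_flashmla_kv_indices_triton() is changed so that the paging parameters are bound by keyword rather than by position. The kernel's exact signature is not shown in this diff; assuming it declares NUM_PAGE_PER_BLOCK before PAGED_SIZE as trailing parameters with defaults, the old cutlass call, which appended a single positional PAGE_SIZE, would bind the page size to NUM_PAGE_PER_BLOCK and leave PAGED_SIZE at its default. A minimal, Triton-free Python sketch of that pitfall (the stand-in function name and default values are hypothetical):

# Hypothetical stand-in for create_flashmla_kv_indices_triton: only the
# assumed trailing-parameter order (NUM_PAGE_PER_BLOCK before PAGED_SIZE)
# matters for this illustration.
def create_kv_indices_sketch(kv_indices_stride, NUM_PAGE_PER_BLOCK=64, PAGED_SIZE=64):
    return {"NUM_PAGE_PER_BLOCK": NUM_PAGE_PER_BLOCK, "PAGED_SIZE": PAGED_SIZE}

PAGE_SIZE = 128

# Old cutlass-style call: the page size is positional, so it lands in
# NUM_PAGE_PER_BLOCK while PAGED_SIZE silently keeps its default.
buggy = create_kv_indices_sketch(1, PAGE_SIZE)
assert buggy["PAGED_SIZE"] == 64 and buggy["NUM_PAGE_PER_BLOCK"] == 128

# Call shape after this commit: the page size is bound by keyword.
fixed = create_kv_indices_sketch(1, PAGED_SIZE=PAGE_SIZE)
assert fixed["PAGED_SIZE"] == 128 and fixed["NUM_PAGE_PER_BLOCK"] == 64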
python/sglang/srt/layers/attention/cutlass_mla_backend.py

@@ -102,7 +102,7 @@ class CutlassMLABackend(FlashInferMLAAttnBackend):
             block_kv_indices,
             self.req_to_token.stride(0),
             max_seqlen_pad,
-            PAGE_SIZE,
+            PAGED_SIZE=PAGE_SIZE,
         )
         workspace_size = cutlass_mla_get_workspace_size(
             max_seqlen_pad * PAGE_SIZE, bs, num_kv_splits=1
@@ -165,7 +165,7 @@ class CutlassMLABackend(FlashInferMLAAttnBackend):
             self.cuda_graph_kv_indices,
             self.req_to_token.stride(0),
             self.cuda_graph_kv_indices.stride(0),
-            PAGE_SIZE,
+            PAGED_SIZE=PAGE_SIZE,
         )
         self.forward_metadata = CutlassMLADecodeMetadata(
             self.cuda_graph_mla_workspace,
@@ -206,7 +206,7 @@ class CutlassMLABackend(FlashInferMLAAttnBackend):
                 self.cuda_graph_kv_indices,
                 self.req_to_token.stride(0),
                 self.cuda_graph_kv_indices.stride(0),
-                PAGE_SIZE,
+                PAGED_SIZE=PAGE_SIZE,
             )
         else:
             super().init_forward_metadata_replay_cuda_graph(
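All three cutlass hunks make the same one-line change. As a general Python aside (not something this commit does, and the diff does not show whether Triton's @triton.jit supports it), declaring such trailing configuration parameters keyword-only would turn this kind of positional slip into an immediate error instead of a silent default:

# General Python pattern, not taken from sglang: the bare * makes the trailing
# configuration parameters keyword-only, so a stray positional value fails loudly.
def create_kv_indices_strict(kv_indices_stride, *, NUM_PAGE_PER_BLOCK=64, PAGED_SIZE=64):
    return NUM_PAGE_PER_BLOCK, PAGED_SIZE

try:
    create_kv_indices_strict(1, 128)  # positional page size is rejected
except TypeError as err:
    print("rejected:", err)

print(create_kv_indices_strict(1, PAGED_SIZE=128))  # (64, 128)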
python/sglang/srt/layers/attention/trtllm_mla_backend.py

@@ -147,8 +147,8 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
             block_kv_indices,
             self.req_to_token.stride(0),
             max_blocks,
-            TRITON_PAD_NUM_PAGE_PER_BLOCK,
-            self.page_size,
+            NUM_PAGE_PER_BLOCK=TRITON_PAD_NUM_PAGE_PER_BLOCK,
+            PAGED_SIZE=self.page_size,
         )

         return block_kv_indices
@@ -204,8 +204,8 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
             block_kv_indices,
             self.req_to_token.stride(0),
             max_seqlen_pad,
-            TRITON_PAD_NUM_PAGE_PER_BLOCK,
-            self.page_size,
+            NUM_PAGE_PER_BLOCK=TRITON_PAD_NUM_PAGE_PER_BLOCK,
+            PAGED_SIZE=self.page_size,
         )

         metadata = TRTLLMMLADecodeMetadata(self.cuda_graph_workspace, block_kv_indices)
@@ -248,8 +248,8 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
             metadata.block_kv_indices,
             self.req_to_token.stride(0),
             metadata.block_kv_indices.shape[1],
-            TRITON_PAD_NUM_PAGE_PER_BLOCK,
-            self.page_size,
+            NUM_PAGE_PER_BLOCK=TRITON_PAD_NUM_PAGE_PER_BLOCK,
+            PAGED_SIZE=self.page_size,
         )

     def get_cuda_graph_seq_len_fill_value(self) -> int:
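The TRT-LLM MLA call sites pass two trailing values, TRITON_PAD_NUM_PAGE_PER_BLOCK and self.page_size. Whether the old positional order happened to match the kernel is not visible from this diff; binding both by keyword, as the commit does, makes each call independent of the kernel's parameter order. A small sketch of that call shape (the stand-in function and its defaults are hypothetical):

# Hypothetical stand-in: only the keyword binding of the two trailing values
# is the point; the parameter order inside the real Triton kernel is assumed.
def kernel_sketch(kv_indices_stride, NUM_PAGE_PER_BLOCK=64, PAGED_SIZE=64):
    return {"NUM_PAGE_PER_BLOCK": NUM_PAGE_PER_BLOCK, "PAGED_SIZE": PAGED_SIZE}

TRITON_PAD_NUM_PAGE_PER_BLOCK = 128
page_size = 32

# Call shape after this commit: both values named, so the result is the same
# no matter how the kernel happens to order these parameters.
out = kernel_sketch(
    1,
    NUM_PAGE_PER_BLOCK=TRITON_PAD_NUM_PAGE_PER_BLOCK,
    PAGED_SIZE=page_size,
)
assert out == {"NUM_PAGE_PER_BLOCK": 128, "PAGED_SIZE": 32}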