sglang · commit c23eda85 (unverified)
Authored Oct 22, 2025 by yinghui; committed via GitHub on Oct 22, 2025. Parent: 138ff231.

Fix incorrect KV indices creation when page_size=32 in TRTLLM MLA backend (#11985)
Showing 4 changed files with 70 additions and 73 deletions:
  python/sglang/srt/layers/attention/flashinfer_mla_backend.py  (+6, -7)
  python/sglang/srt/layers/attention/trtllm_mla_backend.py      (+4, -6)
  python/sglang/srt/layers/attention/utils.py                   (+11, -7)
  python/sglang/test/attention/test_trtllm_mla_backend.py       (+49, -53)
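Reading the diffs below, the root cause appears to be a hard-coded pages-per-iteration constant: the Triton page-table builder walks req_to_token in fixed blocks of 4096 token slots, so each iteration must emit 4096 / page_size page indices, yet callers passed the constant TRITON_PAD_NUM_PAGE_PER_BLOCK = 64, which is only correct for page_size=64. With page_size=32 an iteration then covers only 64 * 32 = 2048 tokens while the loop assumes 4096, so the generated KV indices were wrong. The commit derives the count from the block size instead. A minimal, purely illustrative sketch of that derivation (not code from the commit):

# Illustrative sketch only: why 64 pages per iteration is wrong for page_size=32.
BLOCK_SIZE_TOKENS = 4096   # token slots one builder iteration is expected to cover
HARD_CODED_PAGES = 64      # old NUM_PAGE_PER_BLOCK default

for page_size in (32, 64):
    pages_needed = BLOCK_SIZE_TOKENS // page_size  # what the fix derives in-kernel
    print(f"page_size={page_size}: need {pages_needed} pages per iteration; "
          f"the hard-coded value covers {HARD_CODED_PAGES * page_size} of {BLOCK_SIZE_TOKENS} tokens")
# page_size=64 -> 64 pages  (the old constant happened to be right)
# page_size=32 -> 128 pages (the old constant emitted only half of each block)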
python/sglang/srt/layers/attention/flashinfer_mla_backend.py

@@ -9,19 +9,12 @@ and uses BatchMLAPaged wrapper for decoding.
 More details can be found in https://docs.flashinfer.ai/api/mla.html
 """
 
-import os
 from dataclasses import dataclass
 from functools import partial
 from typing import TYPE_CHECKING, Callable, Optional, Union
 
 import torch
 
-if os.environ["SGLANG_ENABLE_TORCH_COMPILE"] == "1":
-    import logging
-
-    torch._logging.set_logs(dynamo=logging.ERROR)
-    torch._dynamo.config.suppress_errors = True
-
 from sglang.srt.environ import envs
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
 from sglang.srt.layers.attention.flashinfer_backend import (
@@ -45,6 +38,12 @@ if TYPE_CHECKING:
     from sglang.srt.model_executor.model_runner import ModelRunner
     from sglang.srt.speculative.spec_info import SpecInput
 
+if envs.SGLANG_ENABLE_TORCH_COMPILE.get():
+    import logging
+
+    torch._logging.set_logs(dynamo=logging.ERROR)
+    torch._dynamo.config.suppress_errors = True
+
 if is_flashinfer_available():
     from flashinfer import (
         BatchMLAPagedAttentionWrapper,
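Unrelated to the KV-indices fix itself, this file also replaces the raw os.environ lookup with the typed envs accessor and moves the torch.compile logging tweak below the sglang imports. A small, hypothetical comparison of the two lookup styles (the exact default used by envs is an assumption here):

import os

# Old style: raises KeyError whenever SGLANG_ENABLE_TORCH_COMPILE is not set at all.
# enabled = os.environ["SGLANG_ENABLE_TORCH_COMPILE"] == "1"

# Behaviourally closer to the new envs.SGLANG_ENABLE_TORCH_COMPILE.get() call,
# assuming the accessor falls back to a disabled default when the variable is unset.
enabled = os.environ.get("SGLANG_ENABLE_TORCH_COMPILE", "0") == "1"
print(enabled)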
python/sglang/srt/layers/attention/trtllm_mla_backend.py

@@ -17,8 +17,8 @@ from sglang.srt.layers.attention.flashinfer_mla_backend import (
     FlashInferMLAMultiStepDraftBackend,
 )
 from sglang.srt.layers.attention.utils import (
-    TRITON_PAD_NUM_PAGE_PER_BLOCK,
     create_flashmla_kv_indices_triton,
+    get_num_page_per_block_flashmla,
 )
 from sglang.srt.layers.dp_attention import get_attention_tp_size
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
@@ -295,9 +295,10 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
         # Apply dual constraints (take LCM to satisfy both):
         # 1. TRT-LLM: block_num % (128 / page_size) == 0
-        # 2. Triton: page table builder uses 64-index bursts, needs multiple of 64
+        # 2. Triton: number of pages per block
         trtllm_constraint = TRTLLM_BLOCK_CONSTRAINT // self.page_size
-        constraint_lcm = math.lcm(trtllm_constraint, TRITON_PAD_NUM_PAGE_PER_BLOCK)
+        triton_constraint = get_num_page_per_block_flashmla(self.page_size)
+        constraint_lcm = math.lcm(trtllm_constraint, triton_constraint)
 
         if blocks % constraint_lcm != 0:
             blocks = triton.cdiv(blocks, constraint_lcm) * constraint_lcm
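A worked example of the padding logic above (illustrative only; TRTLLM_BLOCK_CONSTRAINT = 128 and the 4096-token builder block are taken from the surrounding diffs):

import math

TRTLLM_BLOCK_CONSTRAINT = 128   # TRT-LLM kernel requirement
BUILDER_BLOCK_TOKENS = 4096     # Triton page-table builder block, in tokens

def padded_blocks(blocks: int, page_size: int) -> int:
    # Mirrors the hunk above: round the block count up to the LCM of both constraints.
    trtllm_constraint = TRTLLM_BLOCK_CONSTRAINT // page_size
    triton_constraint = BUILDER_BLOCK_TOKENS // page_size  # get_num_page_per_block_flashmla
    constraint_lcm = math.lcm(trtllm_constraint, triton_constraint)
    return -(-blocks // constraint_lcm) * constraint_lcm   # ceil-divide, then multiply back

print(padded_blocks(10, page_size=64))  # -> 64   (lcm(2, 64))
print(padded_blocks(10, page_size=32))  # -> 128  (lcm(4, 128); the old code padded only to 64)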
@@ -336,7 +337,6 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
             block_kv_indices,
             self.req_to_token.stride(0),
             max_blocks,
-            NUM_PAGE_PER_BLOCK=TRITON_PAD_NUM_PAGE_PER_BLOCK,
             PAGED_SIZE=self.page_size,
         )
@@ -417,7 +417,6 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
             block_kv_indices,
             self.req_to_token.stride(0),
             max_blocks_per_seq,
-            NUM_PAGE_PER_BLOCK=TRITON_PAD_NUM_PAGE_PER_BLOCK,
             PAGED_SIZE=self.page_size,
         )
@@ -504,7 +503,6 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
             metadata.block_kv_indices,
             self.req_to_token.stride(0),
             metadata.block_kv_indices.shape[1],
-            NUM_PAGE_PER_BLOCK=TRITON_PAD_NUM_PAGE_PER_BLOCK,
             PAGED_SIZE=self.page_size,
         )
python/sglang/srt/layers/attention/utils.py

 import triton
 import triton.language as tl
 
-# Exposed here so other Python modules can import it instead of hard-coding 64.
-TRITON_PAD_NUM_PAGE_PER_BLOCK = 64
+# Keep this in sync with the Triton kernel inside `create_flashmla_kv_indices_triton`.
+_FLASHMLA_CREATE_KV_BLOCK_SIZE = 4096
+# The kernel derives its per-iteration page count from this block size.
+FLASHMLA_CREATE_KV_BLOCK_SIZE_TRITON = tl.constexpr(_FLASHMLA_CREATE_KV_BLOCK_SIZE)
 
 
 @triton.jit
@@ -46,6 +44,11 @@ def create_flashinfer_kv_indices_triton(
     tl.store(kv_indices_ptr + kv_indices_offset + offset, data, mask=mask)
 
 
+def get_num_page_per_block_flashmla(page_size: int = 64) -> int:
+    num_page_per_block = _FLASHMLA_CREATE_KV_BLOCK_SIZE // page_size
+    return num_page_per_block
+
+
 @triton.jit
 def create_flashmla_kv_indices_triton(
     req_to_token_ptr,  # [max_batch, max_context_len]
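A quick sanity check of the new helper (hypothetical snippet, assuming the module is importable in your environment); the derived count only coincides with the old hard-coded 64 when page_size is 64:

from sglang.srt.layers.attention.utils import get_num_page_per_block_flashmla

assert get_num_page_per_block_flashmla(64) == 64    # 4096 // 64, same as the old constant
assert get_num_page_per_block_flashmla(32) == 128   # 4096 // 32, twice the old constant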
@@ -55,10 +58,11 @@ def create_flashmla_kv_indices_triton(
     kv_indices_ptr,
     req_to_token_ptr_stride: tl.constexpr,
     kv_indices_ptr_stride: tl.constexpr,
-    NUM_PAGE_PER_BLOCK: tl.constexpr = TRITON_PAD_NUM_PAGE_PER_BLOCK,
     PAGED_SIZE: tl.constexpr = 64,
 ):
-    BLOCK_SIZE: tl.constexpr = 4096
+    NUM_PAGE_PER_BLOCK: tl.constexpr = (
+        FLASHMLA_CREATE_KV_BLOCK_SIZE_TRITON // PAGED_SIZE
+    )
     pid = tl.program_id(axis=0)
 
     # find the req pool idx, this is for batch to token
@@ -73,7 +77,7 @@ def create_flashmla_kv_indices_triton(
     kv_end += tl.load(page_kernel_lens_ptr + pid).to(tl.int32)
 
     num_paged = tl.cdiv(kv_end - kv_start, PAGED_SIZE)
-    num_pages_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE)
+    num_pages_loop = tl.cdiv(kv_end - kv_start, FLASHMLA_CREATE_KV_BLOCK_SIZE_TRITON)
     for i in range(num_pages_loop):
         # index into req_to_token_ptr needs to be int64
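For readers following the kernel changes, here is a rough pure-Python rendering of what one program of create_flashmla_kv_indices_triton is meant to produce for a single sequence. It is a sketch written for this review, not code from the repository, and it assumes req_to_token holds token-slot indices that are page-aligned in the KV pool and that unused entries of the output row stay at -1:

import math

def reference_block_kv_indices(req_to_token_row, seq_len, page_size, out_width):
    """Sketch: page indices for one sequence, with untouched slots left as -1."""
    out = [-1] * out_width
    for p in range(math.ceil(seq_len / page_size)):
        # First token slot of page p; integer division by page_size gives the page index.
        out[p] = req_to_token_row[p * page_size] // page_size
    return out

# Toy layout (hypothetical): this request's tokens live in pool pages 5 and 7.
row = list(range(5 * 32, 6 * 32)) + list(range(7 * 32, 8 * 32))
print(reference_block_kv_indices(row, seq_len=40, page_size=32, out_width=4))
# -> [5, 7, -1, -1]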
python/sglang/test/attention/test_trtllm_mla_backend.py

@@ -16,10 +16,15 @@ from sglang.srt.layers.attention.trtllm_mla_backend import (
     TRTLLMMLABackend,
     TRTLLMMLADecodeMetadata,
 )
-from sglang.srt.layers.attention.utils import TRITON_PAD_NUM_PAGE_PER_BLOCK
+from sglang.srt.layers.attention.utils import get_num_page_per_block_flashmla
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.mem_cache.memory_pool import MLATokenToKVPool
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+from sglang.srt.server_args import (
+    ServerArgs,
+    get_global_server_args,
+    set_global_server_args_for_scheduler,
+)
 from sglang.srt.utils import is_flashinfer_available
 from sglang.test.test_utils import CustomTestCase
@@ -104,15 +109,15 @@ TEST_CASES = {
             "page_size": 32,
             "description": "Single FP16 vs reference",
         },
-        {
-            "name": "single_fp8",
-            "batch_size": 1,
-            "max_seq_len": 64,
-            "page_size": 64,
-            "tolerance": 1e-1,
-            "kv_cache_dtype": torch.float8_e4m3fn,
-            "description": "Single FP8 vs reference",
-        },
+        # {
+        #     "name": "single_fp8",
+        #     "batch_size": 1,
+        #     "max_seq_len": 64,
+        #     "page_size": 64,
+        #     "tolerance": 1e-1,
+        #     "kv_cache_dtype": torch.float8_e4m3fn,
+        #     "description": "Single FP8 vs reference",
+        # },
         {
             "name": "batch_fp16",
             "batch_size": 32,
@@ -120,15 +125,15 @@ TEST_CASES = {
             "page_size": 32,
             "description": "Batch FP16 vs reference",
         },
-        {
-            "name": "batch_fp8",
-            "batch_size": 32,
-            "max_seq_len": 64,
-            "page_size": 64,
-            "tolerance": 1e-1,
-            "kv_cache_dtype": torch.float8_e4m3fn,
-            "description": "Batch FP8 vs reference",
-        },
+        # {
+        #     "name": "batch_fp8",
+        #     "batch_size": 32,
+        #     "max_seq_len": 64,
+        #     "page_size": 64,
+        #     "tolerance": 1e-1,
+        #     "kv_cache_dtype": torch.float8_e4m3fn,
+        #     "description": "Batch FP8 vs reference",
+        # },
     ],
     "page_size_consistency": [
         # Only 32 and 64 supported for now in flashinfer TRTLLM-GEN MLA kernel
@@ -213,13 +218,7 @@ class MockModelRunner:
         self.page_size = config["page_size"]
 
         # Server args stub - needed by attention backends
-        self.server_args = type(
-            "ServerArgs",
-            (),
-            {
-                "enable_dp_attention": False,  # Default value for testing
-            },
-        )
+        self.server_args = get_global_server_args()
 
         # Model-config stub with MLA attributes
         self.model_config = type(
@@ -320,6 +319,17 @@ def compare_outputs(trtllm_out, reference_out, tolerance=1e-2):
 class TestTRTLLMMLA(CustomTestCase):
     """Test suite for TRTLLM MLA backend with centralized configuration."""
 
+    @classmethod
+    def setUpClass(cls):
+        """Set up global server args for testing."""
+        server_args = ServerArgs(model_path="dummy")
+        server_args.enable_dp_attention = False
+        set_global_server_args_for_scheduler(server_args)
+
+    @classmethod
+    def tearDownClass(cls):
+        pass
+
     def _merge_config(self, test_case):
         """Merge test case with default configuration."""
         config = DEFAULT_CONFIG.copy()
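The two hunks above work together: setUpClass registers a real ServerArgs object globally, and MockModelRunner now reads it back instead of fabricating a stub type. A condensed sketch of that flow, reusing only names that appear in the hunks:

from sglang.srt.server_args import (
    ServerArgs,
    get_global_server_args,
    set_global_server_args_for_scheduler,
)

server_args = ServerArgs(model_path="dummy")
server_args.enable_dp_attention = False
set_global_server_args_for_scheduler(server_args)

# Any later consumer (here, the test's mock model runner) sees the same settings.
assert get_global_server_args().enable_dp_attention is False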
@@ -841,25 +851,17 @@ class TestTRTLLMMLA(CustomTestCase):
         backend.init_forward_metadata(fb)
 
         # Verify metadata exists
-        self.assertIsNotNone(backend.forward_metadata)
-        self.assertIsInstance(backend.forward_metadata, TRTLLMMLADecodeMetadata)
+        self.assertIsNotNone(backend.forward_decode_metadata)
+        self.assertIsInstance(backend.forward_decode_metadata, TRTLLMMLADecodeMetadata)
 
         # Test metadata structure
-        metadata = backend.forward_metadata
-        self.assertIsNotNone(metadata.workspace, "Workspace should be allocated")
+        metadata = backend.forward_decode_metadata
         self.assertIsNotNone(
             metadata.block_kv_indices, "Block KV indices should be created"
         )
 
-        # Test workspace properties
-        self.assertEqual(metadata.workspace.device.type, "cuda")
-        self.assertEqual(metadata.workspace.dtype, torch.uint8)
-        self.assertGreater(
-            metadata.workspace.numel(), 0, "Workspace should have non-zero size"
-        )
-
         # Test block KV indices properties
         self.assertEqual(metadata.block_kv_indices.device.type, "cuda")
         self.assertEqual(metadata.block_kv_indices.dtype, torch.int32)
@@ -915,9 +917,10 @@ class TestTRTLLMMLA(CustomTestCase):
         # Should satisfy TRT-LLM and Triton constraints
         trtllm_constraint = 128 // scenario["page_size"]
-        constraint_lcm = math.lcm(
-            trtllm_constraint, TRITON_PAD_NUM_PAGE_PER_BLOCK
-        )
+        triton_constraint = get_num_page_per_block_flashmla(
+            scenario["page_size"]
+        )
+        constraint_lcm = math.lcm(trtllm_constraint, triton_constraint)
 
         self.assertEqual(
             calculated_blocks % constraint_lcm,
             0,
@@ -965,7 +968,7 @@ class TestTRTLLMMLA(CustomTestCase):
         # Initialize metadata
         backend.init_forward_metadata(fb)
-        metadata = backend.forward_metadata
+        metadata = backend.forward_decode_metadata
 
         # Verify KV indices structure
         block_kv_indices = metadata.block_kv_indices
@@ -1016,7 +1019,6 @@ class TestTRTLLMMLA(CustomTestCase):
         # Verify CUDA graph buffers are allocated
         self.assertIsNotNone(backend.decode_cuda_graph_kv_indices)
-        self.assertIsNotNone(backend.decode_cuda_graph_workspace)
 
         # Test capture metadata
         seq_lens = torch.full(
@@ -1038,7 +1040,6 @@ class TestTRTLLMMLA(CustomTestCase):
         self.assertIn(batch_size, backend.decode_cuda_graph_metadata)
         capture_metadata = backend.decode_cuda_graph_metadata[batch_size]
 
-        self.assertIsNotNone(capture_metadata.workspace)
         self.assertIsNotNone(capture_metadata.block_kv_indices)
 
         # Test replay with different sequence lengths
@@ -1061,11 +1062,8 @@ class TestTRTLLMMLA(CustomTestCase):
         )
 
         # Verify replay updated the metadata
-        replay_metadata = backend.forward_metadata
+        replay_metadata = backend.forward_decode_metadata
         self.assertIsNotNone(replay_metadata)
-        self.assertEqual(
-            replay_metadata.workspace.data_ptr(), capture_metadata.workspace.data_ptr()
-        )
 
     def test_metadata_consistency_across_calls(self):
         """Test metadata consistency across multiple forward calls."""
@@ -1083,7 +1081,7 @@ class TestTRTLLMMLA(CustomTestCase):
             config["batch_size"], seq_lens_1, backend, model_runner, config
         )
         backend.init_forward_metadata(fb_1)
-        metadata_1 = backend.forward_metadata
+        metadata_1 = backend.forward_decode_metadata
 
         # Second call with same sequence lengths
         seq_lens_2 = torch.tensor([32, 48], device=config["device"])
@@ -1091,10 +1089,9 @@ class TestTRTLLMMLA(CustomTestCase):
             config["batch_size"], seq_lens_2, backend, model_runner, config
         )
         backend.init_forward_metadata(fb_2)
-        metadata_2 = backend.forward_metadata
+        metadata_2 = backend.forward_decode_metadata
 
         # Metadata structure should be consistent
-        self.assertEqual(metadata_1.workspace.shape, metadata_2.workspace.shape)
         self.assertEqual(
             metadata_1.block_kv_indices.shape, metadata_2.block_kv_indices.shape
         )
@@ -1105,10 +1102,9 @@ class TestTRTLLMMLA(CustomTestCase):
             config["batch_size"], seq_lens_3, backend, model_runner, config
         )
         backend.init_forward_metadata(fb_3)
-        metadata_3 = backend.forward_metadata
+        metadata_3 = backend.forward_decode_metadata
 
         # Should still have valid structure
-        self.assertIsNotNone(metadata_3.workspace)
         self.assertIsNotNone(metadata_3.block_kv_indices)
         self.assertEqual(metadata_3.block_kv_indices.shape[0], config["batch_size"])