zhaoyu6 / sglang / Commits / c23eda85

Commit c23eda85 (Unverified)
Fix incorrect KV indices creation when page_size=32 in TRTLLM MLA backend (#11985)

Authored Oct 22, 2025 by yinghui; committed by GitHub on Oct 22, 2025.
Parent: 138ff231
Showing 4 changed files, with 70 additions and 73 deletions (+70, -73):

  python/sglang/srt/layers/attention/flashinfer_mla_backend.py   (+6, -7)
  python/sglang/srt/layers/attention/trtllm_mla_backend.py       (+4, -6)
  python/sglang/srt/layers/attention/utils.py                    (+11, -7)
  python/sglang/test/attention/test_trtllm_mla_backend.py        (+49, -53)
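For orientation before the per-file diffs: the bug is an arithmetic mismatch in the Triton page-table kernel, which wrote a fixed 64 page indices per iteration while stepping through sequences in hard-coded 4096-token blocks; reading the kernel hunks below, that fixed value appears to cover only half of each 4096-token step when page_size=32. A minimal plain-Python sketch of the mismatch, using the constants visible in the diffs (not code from the repository):

    # Before the fix: pages written per iteration were fixed at 64
    # (TRITON_PAD_NUM_PAGE_PER_BLOCK), while the loop advanced 4096 tokens at a time.
    page_size = 32
    pages_per_iteration = 64                        # old fixed value, implicitly assumes page_size == 64
    tokens_covered = pages_per_iteration * page_size    # 2048
    loop_stride = 4096                              # old hard-coded BLOCK_SIZE in the kernel
    print(tokens_covered == loop_stride)            # False: each 4096-token step is only half covered

    # After the fix: pages per iteration are derived from the page size, so the two always agree.
    pages_per_iteration = 4096 // page_size         # 128 when page_size == 32
    print(pages_per_iteration * page_size == loop_stride)  # True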
python/sglang/srt/layers/attention/flashinfer_mla_backend.py

@@ -9,19 +9,12 @@ and uses BatchMLAPaged wrapper for decoding.
 More details can be found in https://docs.flashinfer.ai/api/mla.html
 """
 
-import os
 from dataclasses import dataclass
 from functools import partial
 from typing import TYPE_CHECKING, Callable, Optional, Union
 
 import torch
 
-if os.environ["SGLANG_ENABLE_TORCH_COMPILE"] == "1":
-    import logging
-
-    torch._logging.set_logs(dynamo=logging.ERROR)
-    torch._dynamo.config.suppress_errors = True
-
 from sglang.srt.environ import envs
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
 from sglang.srt.layers.attention.flashinfer_backend import (
...
@@ -45,6 +38,12 @@ if TYPE_CHECKING:
     from sglang.srt.model_executor.model_runner import ModelRunner
     from sglang.srt.speculative.spec_info import SpecInput
 
+if envs.SGLANG_ENABLE_TORCH_COMPILE.get():
+    import logging
+
+    torch._logging.set_logs(dynamo=logging.ERROR)
+    torch._dynamo.config.suppress_errors = True
+
 if is_flashinfer_available():
     from flashinfer import (
         BatchMLAPagedAttentionWrapper,
...
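Besides moving the check below the sglang imports, the hunk above swaps a raw os.environ lookup, which raises KeyError when the variable is unset, for the project's envs.SGLANG_ENABLE_TORCH_COMPILE.get() accessor. A stdlib-only illustration of the difference (not the repository's code):

    import os

    # Old pattern: raises KeyError if SGLANG_ENABLE_TORCH_COMPILE is not set at all.
    # enabled = os.environ["SGLANG_ENABLE_TORCH_COMPILE"] == "1"

    # Defensive stdlib equivalent; the commit instead centralizes this in the envs accessor.
    enabled = os.environ.get("SGLANG_ENABLE_TORCH_COMPILE", "0") == "1"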
python/sglang/srt/layers/attention/trtllm_mla_backend.py

@@ -17,8 +17,8 @@ from sglang.srt.layers.attention.flashinfer_mla_backend import (
     FlashInferMLAMultiStepDraftBackend,
 )
 from sglang.srt.layers.attention.utils import (
-    TRITON_PAD_NUM_PAGE_PER_BLOCK,
     create_flashmla_kv_indices_triton,
+    get_num_page_per_block_flashmla,
 )
 from sglang.srt.layers.dp_attention import get_attention_tp_size
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
...
@@ -295,9 +295,10 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
         # Apply dual constraints (take LCM to satisfy both):
         # 1. TRT-LLM: block_num % (128 / page_size) == 0
-        # 2. Triton: page table builder uses 64-index bursts, needs multiple of 64
+        # 2. Triton: number of pages per block
         trtllm_constraint = TRTLLM_BLOCK_CONSTRAINT // self.page_size
-        constraint_lcm = math.lcm(trtllm_constraint, TRITON_PAD_NUM_PAGE_PER_BLOCK)
+        triton_constraint = get_num_page_per_block_flashmla(self.page_size)
+        constraint_lcm = math.lcm(trtllm_constraint, triton_constraint)
 
         if blocks % constraint_lcm != 0:
             blocks = triton.cdiv(blocks, constraint_lcm) * constraint_lcm
...
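A quick check of the new dual-constraint padding for the two supported page sizes, as plain Python mirroring the hunk above (assuming TRTLLM_BLOCK_CONSTRAINT == 128, as the comment states):

    import math

    def padded_block_count(blocks: int, page_size: int) -> int:
        trtllm_constraint = 128 // page_size        # TRT-LLM: block_num % (128 / page_size) == 0
        triton_constraint = 4096 // page_size       # what get_num_page_per_block_flashmla(page_size) returns
        constraint_lcm = math.lcm(trtllm_constraint, triton_constraint)
        if blocks % constraint_lcm != 0:
            blocks = -(-blocks // constraint_lcm) * constraint_lcm   # ceiling division, like triton.cdiv
        return blocks

    print(padded_block_count(100, page_size=64))    # lcm(2, 64)  = 64  -> padded to 128
    print(padded_block_count(100, page_size=32))    # lcm(4, 128) = 128 -> padded to 128
    # Before the fix the Triton term was the fixed TRITON_PAD_NUM_PAGE_PER_BLOCK = 64,
    # giving lcm(4, 64) = 64 for page_size=32 instead of 128.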
@@ -336,7 +337,6 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
             block_kv_indices,
             self.req_to_token.stride(0),
             max_blocks,
-            NUM_PAGE_PER_BLOCK=TRITON_PAD_NUM_PAGE_PER_BLOCK,
             PAGED_SIZE=self.page_size,
         )
...
@@ -417,7 +417,6 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
             block_kv_indices,
             self.req_to_token.stride(0),
             max_blocks_per_seq,
-            NUM_PAGE_PER_BLOCK=TRITON_PAD_NUM_PAGE_PER_BLOCK,
             PAGED_SIZE=self.page_size,
         )
...
@@ -504,7 +503,6 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
             metadata.block_kv_indices,
             self.req_to_token.stride(0),
             metadata.block_kv_indices.shape[1],
-            NUM_PAGE_PER_BLOCK=TRITON_PAD_NUM_PAGE_PER_BLOCK,
             PAGED_SIZE=self.page_size,
         )
...
python/sglang/srt/layers/attention/utils.py

 import triton
 import triton.language as tl
 
+# Keep this in sync with the Triton kernel inside `create_flashmla_kv_indices_triton`.
+_FLASHMLA_CREATE_KV_BLOCK_SIZE = 4096
+FLASHMLA_CREATE_KV_BLOCK_SIZE_TRITON = tl.constexpr(_FLASHMLA_CREATE_KV_BLOCK_SIZE)
+
 # Number of pages that the kernel writes per iteration.
 # Exposed here so other Python modules can import it instead of hard-coding 64.
 TRITON_PAD_NUM_PAGE_PER_BLOCK = 64
 
 
 @triton.jit
...
@@ -46,6 +44,11 @@ def create_flashinfer_kv_indices_triton(
         tl.store(kv_indices_ptr + kv_indices_offset + offset, data, mask=mask)
 
 
+def get_num_page_per_block_flashmla(page_size: int = 64) -> int:
+    num_page_per_block = _FLASHMLA_CREATE_KV_BLOCK_SIZE // page_size
+    return num_page_per_block
+
+
 @triton.jit
 def create_flashmla_kv_indices_triton(
     req_to_token_ptr,  # [max_batch, max_context_len]
...
@@ -55,10 +58,11 @@ def create_flashmla_kv_indices_triton(
     kv_indices_ptr,
     req_to_token_ptr_stride: tl.constexpr,
     kv_indices_ptr_stride: tl.constexpr,
-    NUM_PAGE_PER_BLOCK: tl.constexpr = TRITON_PAD_NUM_PAGE_PER_BLOCK,
     PAGED_SIZE: tl.constexpr = 64,
 ):
-    BLOCK_SIZE: tl.constexpr = 4096
+    NUM_PAGE_PER_BLOCK: tl.constexpr = (FLASHMLA_CREATE_KV_BLOCK_SIZE_TRITON // PAGED_SIZE)
     pid = tl.program_id(axis=0)
 
     # find the req pool idx, this is for batch to token
...
@@ -73,7 +77,7 @@ def create_flashmla_kv_indices_triton(
     kv_end += tl.load(page_kernel_lens_ptr + pid).to(tl.int32)
 
     num_paged = tl.cdiv(kv_end - kv_start, PAGED_SIZE)
-    num_pages_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE)
+    num_pages_loop = tl.cdiv(kv_end - kv_start, FLASHMLA_CREATE_KV_BLOCK_SIZE_TRITON)
     for i in range(num_pages_loop):
         # index into req_to_token_ptr needs to be int64
...
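A small sanity check of the helper and the constant it is built on, mirroring the definitions above: the number of pages written per kernel iteration now always covers exactly one 4096-token block, which is what the num_pages_loop stride assumes.

    _FLASHMLA_CREATE_KV_BLOCK_SIZE = 4096             # tokens handled per kernel iteration

    def get_num_page_per_block_flashmla(page_size: int = 64) -> int:
        return _FLASHMLA_CREATE_KV_BLOCK_SIZE // page_size

    for page_size in (32, 64):
        num_page_per_block = get_num_page_per_block_flashmla(page_size)
        assert num_page_per_block * page_size == _FLASHMLA_CREATE_KV_BLOCK_SIZE
        print(page_size, num_page_per_block)          # 32 -> 128, 64 -> 64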
python/sglang/test/attention/test_trtllm_mla_backend.py

@@ -16,10 +16,15 @@ from sglang.srt.layers.attention.trtllm_mla_backend import (
     TRTLLMMLABackend,
     TRTLLMMLADecodeMetadata,
 )
-from sglang.srt.layers.attention.utils import TRITON_PAD_NUM_PAGE_PER_BLOCK
+from sglang.srt.layers.attention.utils import get_num_page_per_block_flashmla
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.mem_cache.memory_pool import MLATokenToKVPool
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+from sglang.srt.server_args import (
+    ServerArgs,
+    get_global_server_args,
+    set_global_server_args_for_scheduler,
+)
 from sglang.srt.utils import is_flashinfer_available
 from sglang.test.test_utils import CustomTestCase
...
@@ -104,15 +109,15 @@ TEST_CASES = {
             "page_size": 32,
             "description": "Single FP16 vs reference",
         },
-        {
-            "name": "single_fp8",
-            "batch_size": 1,
-            "max_seq_len": 64,
-            "page_size": 64,
-            "tolerance": 1e-1,
-            "kv_cache_dtype": torch.float8_e4m3fn,
-            "description": "Single FP8 vs reference",
-        },
+        # {
+        #     "name": "single_fp8",
+        #     "batch_size": 1,
+        #     "max_seq_len": 64,
+        #     "page_size": 64,
+        #     "tolerance": 1e-1,
+        #     "kv_cache_dtype": torch.float8_e4m3fn,
+        #     "description": "Single FP8 vs reference",
+        # },
         {
             "name": "batch_fp16",
             "batch_size": 32,
...
@@ -120,15 +125,15 @@ TEST_CASES = {
             "page_size": 32,
             "description": "Batch FP16 vs reference",
         },
-        {
-            "name": "batch_fp8",
-            "batch_size": 32,
-            "max_seq_len": 64,
-            "page_size": 64,
-            "tolerance": 1e-1,
-            "kv_cache_dtype": torch.float8_e4m3fn,
-            "description": "Batch FP8 vs reference",
-        },
+        # {
+        #     "name": "batch_fp8",
+        #     "batch_size": 32,
+        #     "max_seq_len": 64,
+        #     "page_size": 64,
+        #     "tolerance": 1e-1,
+        #     "kv_cache_dtype": torch.float8_e4m3fn,
+        #     "description": "Batch FP8 vs reference",
+        # },
     ],
     "page_size_consistency": [
         # Only 32 and 64 supported for now in flashinfer TRTLLM-GEN MLA kernel
...
@@ -213,13 +218,7 @@ class MockModelRunner:
         self.page_size = config["page_size"]
         # Server args stub - needed by attention backends
-        self.server_args = type(
-            "ServerArgs",
-            (),
-            {
-                "enable_dp_attention": False,  # Default value for testing
-            },
-        )
+        self.server_args = get_global_server_args()
         # Model-config stub with MLA attributes
         self.model_config = type(
...
@@ -320,6 +319,17 @@ def compare_outputs(trtllm_out, reference_out, tolerance=1e-2):
 class TestTRTLLMMLA(CustomTestCase):
     """Test suite for TRTLLM MLA backend with centralized configuration."""
 
+    @classmethod
+    def setUpClass(cls):
+        """Set up global server args for testing."""
+        server_args = ServerArgs(model_path="dummy")
+        server_args.enable_dp_attention = False
+        set_global_server_args_for_scheduler(server_args)
+
+    @classmethod
+    def tearDownClass(cls):
+        pass
+
     def _merge_config(self, test_case):
         """Merge test case with default configuration."""
         config = DEFAULT_CONFIG.copy()
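The two hunks above are connected: MockModelRunner no longer fabricates a ServerArgs stand-in with type(); it reads the process-global server args that setUpClass now installs. A sketch of that flow, using only the names imported in this file:

    # setUpClass installs real server args globally ...
    server_args = ServerArgs(model_path="dummy")
    server_args.enable_dp_attention = False
    set_global_server_args_for_scheduler(server_args)

    # ... and MockModelRunner later retrieves the same object instead of building a stub.
    assert get_global_server_args().enable_dp_attention is False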
...
@@ -841,25 +851,17 @@ class TestTRTLLMMLA(CustomTestCase):
         backend.init_forward_metadata(fb)
 
         # Verify metadata exists
-        self.assertIsNotNone(backend.forward_metadata)
-        self.assertIsInstance(backend.forward_metadata, TRTLLMMLADecodeMetadata)
+        self.assertIsNotNone(backend.forward_decode_metadata)
+        self.assertIsInstance(backend.forward_decode_metadata, TRTLLMMLADecodeMetadata)
 
         # Test metadata structure
-        metadata = backend.forward_metadata
-        self.assertIsNotNone(metadata.workspace, "Workspace should be allocated")
+        metadata = backend.forward_decode_metadata
         self.assertIsNotNone(
             metadata.block_kv_indices, "Block KV indices should be created"
         )
 
-        # Test workspace properties
-        self.assertEqual(metadata.workspace.device.type, "cuda")
-        self.assertEqual(metadata.workspace.dtype, torch.uint8)
-        self.assertGreater(
-            metadata.workspace.numel(), 0, "Workspace should have non-zero size"
-        )
-
         # Test block KV indices properties
         self.assertEqual(metadata.block_kv_indices.device.type, "cuda")
         self.assertEqual(metadata.block_kv_indices.dtype, torch.int32)
...
@@ -915,9 +917,10 @@ class TestTRTLLMMLA(CustomTestCase):
         # Should satisfy TRT-LLM and Triton constraints
         trtllm_constraint = 128 // scenario["page_size"]
-        constraint_lcm = math.lcm(trtllm_constraint, TRITON_PAD_NUM_PAGE_PER_BLOCK)
+        triton_constraint = get_num_page_per_block_flashmla(scenario["page_size"])
+        constraint_lcm = math.lcm(trtllm_constraint, triton_constraint)
         self.assertEqual(
             calculated_blocks % constraint_lcm,
             0,
...
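To see why the expected alignment in this test changed, compare the old and new Triton constraint at page_size=32 (a plain-Python sketch; calculated_blocks in the test stands in for whatever block count the backend produced):

    import math

    page_size = 32
    trtllm_constraint = 128 // page_size                  # 4

    old_triton_constraint = 64                            # TRITON_PAD_NUM_PAGE_PER_BLOCK
    new_triton_constraint = 4096 // page_size             # get_num_page_per_block_flashmla(32) == 128

    print(math.lcm(trtllm_constraint, old_triton_constraint))   # 64: alignment the old test expected
    print(math.lcm(trtllm_constraint, new_triton_constraint))   # 128: alignment the test now checks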
@@ -965,7 +968,7 @@ class TestTRTLLMMLA(CustomTestCase):
         # Initialize metadata
         backend.init_forward_metadata(fb)
-        metadata = backend.forward_metadata
+        metadata = backend.forward_decode_metadata
 
         # Verify KV indices structure
         block_kv_indices = metadata.block_kv_indices
...
@@ -1016,7 +1019,6 @@ class TestTRTLLMMLA(CustomTestCase):
         # Verify CUDA graph buffers are allocated
         self.assertIsNotNone(backend.decode_cuda_graph_kv_indices)
-        self.assertIsNotNone(backend.decode_cuda_graph_workspace)
 
         # Test capture metadata
         seq_lens = torch.full(
...
@@ -1038,7 +1040,6 @@ class TestTRTLLMMLA(CustomTestCase):
         self.assertIn(batch_size, backend.decode_cuda_graph_metadata)
         capture_metadata = backend.decode_cuda_graph_metadata[batch_size]
-        self.assertIsNotNone(capture_metadata.workspace)
         self.assertIsNotNone(capture_metadata.block_kv_indices)
 
         # Test replay with different sequence lengths
...
@@ -1061,11 +1062,8 @@ class TestTRTLLMMLA(CustomTestCase):
         )
 
         # Verify replay updated the metadata
-        replay_metadata = backend.forward_metadata
+        replay_metadata = backend.forward_decode_metadata
         self.assertIsNotNone(replay_metadata)
-        self.assertEqual(
-            replay_metadata.workspace.data_ptr(), capture_metadata.workspace.data_ptr()
-        )
 
     def test_metadata_consistency_across_calls(self):
         """Test metadata consistency across multiple forward calls."""
...
@@ -1083,7 +1081,7 @@ class TestTRTLLMMLA(CustomTestCase):
             config["batch_size"], seq_lens_1, backend, model_runner, config
         )
         backend.init_forward_metadata(fb_1)
-        metadata_1 = backend.forward_metadata
+        metadata_1 = backend.forward_decode_metadata
 
         # Second call with same sequence lengths
         seq_lens_2 = torch.tensor([32, 48], device=config["device"])
...
@@ -1091,10 +1089,9 @@ class TestTRTLLMMLA(CustomTestCase):
             config["batch_size"], seq_lens_2, backend, model_runner, config
         )
         backend.init_forward_metadata(fb_2)
-        metadata_2 = backend.forward_metadata
+        metadata_2 = backend.forward_decode_metadata
 
         # Metadata structure should be consistent
-        self.assertEqual(metadata_1.workspace.shape, metadata_2.workspace.shape)
         self.assertEqual(
             metadata_1.block_kv_indices.shape, metadata_2.block_kv_indices.shape
         )
...
@@ -1105,10 +1102,9 @@ class TestTRTLLMMLA(CustomTestCase):
             config["batch_size"], seq_lens_3, backend, model_runner, config
         )
         backend.init_forward_metadata(fb_3)
-        metadata_3 = backend.forward_metadata
+        metadata_3 = backend.forward_decode_metadata
 
         # Should still have valid structure
-        self.assertIsNotNone(metadata_3.workspace)
         self.assertIsNotNone(metadata_3.block_kv_indices)
         self.assertEqual(metadata_3.block_kv_indices.shape[0], config["batch_size"])
...