Unverified Commit c23eda85 authored by yinghui, committed by GitHub

Fix incorrect KV indices creation when page_size=32 in TRTLLM MLA backend (#11985)

parent 138ff231
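
For context, a minimal sketch (illustration only, not part of the patch) of the arithmetic behind this fix, using only constants visible in the hunks below: the Triton page-table builder covers 4096 tokens per block (`_FLASHMLA_CREATE_KV_BLOCK_SIZE`), while the number of page indices written per block was previously hard-coded to 64 (`TRITON_PAD_NUM_PAGE_PER_BLOCK`).

```python
# Illustration only; constants are taken from the diff below.
KV_BLOCK_SIZE = 4096          # tokens covered per Triton block (_FLASHMLA_CREATE_KV_BLOCK_SIZE)
OLD_NUM_PAGE_PER_BLOCK = 64   # previously hard-coded TRITON_PAD_NUM_PAGE_PER_BLOCK

for page_size in (32, 64):
    pages_spanned = KV_BLOCK_SIZE // page_size  # pages one 4096-token block actually covers
    print(f"page_size={page_size}: block spans {pages_spanned} pages, "
          f"old constant wrote {OLD_NUM_PAGE_PER_BLOCK}")
# page_size=64 -> 64 pages, matching the old constant
# page_size=32 -> 128 pages, so the per-block page count must be derived from page_size
```

The patch derives that count from the page size in both the Python helper (`get_num_page_per_block_flashmla`) and the kernel's `NUM_PAGE_PER_BLOCK` constexpr.
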
@@ -9,19 +9,12 @@ and uses BatchMLAPaged wrapper for decoding.
More details can be found in https://docs.flashinfer.ai/api/mla.html
"""
import os
from dataclasses import dataclass
from functools import partial
from typing import TYPE_CHECKING, Callable, Optional, Union
import torch
if os.environ["SGLANG_ENABLE_TORCH_COMPILE"] == "1":
import logging
torch._logging.set_logs(dynamo=logging.ERROR)
torch._dynamo.config.suppress_errors = True
from sglang.srt.environ import envs
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from sglang.srt.layers.attention.flashinfer_backend import (
@@ -45,6 +38,12 @@ if TYPE_CHECKING:
from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.speculative.spec_info import SpecInput
if envs.SGLANG_ENABLE_TORCH_COMPILE.get():
import logging
torch._logging.set_logs(dynamo=logging.ERROR)
torch._dynamo.config.suppress_errors = True
if is_flashinfer_available():
from flashinfer import (
BatchMLAPagedAttentionWrapper,
......
@@ -17,8 +17,8 @@ from sglang.srt.layers.attention.flashinfer_mla_backend import (
FlashInferMLAMultiStepDraftBackend,
)
from sglang.srt.layers.attention.utils import (
TRITON_PAD_NUM_PAGE_PER_BLOCK,
create_flashmla_kv_indices_triton,
get_num_page_per_block_flashmla,
)
from sglang.srt.layers.dp_attention import get_attention_tp_size
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
@@ -295,9 +295,10 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
# Apply dual constraints (take LCM to satisfy both):
# 1. TRT-LLM: block_num % (128 / page_size) == 0
# 2. Triton: page table builder uses 64-index bursts, needs multiple of 64
# 2. Triton: page table builder writes (4096 // page_size)-index bursts, needs a multiple of that
trtllm_constraint = TRTLLM_BLOCK_CONSTRAINT // self.page_size
constraint_lcm = math.lcm(trtllm_constraint, TRITON_PAD_NUM_PAGE_PER_BLOCK)
triton_constraint = get_num_page_per_block_flashmla(self.page_size)
constraint_lcm = math.lcm(trtllm_constraint, triton_constraint)
if blocks % constraint_lcm != 0:
blocks = triton.cdiv(blocks, constraint_lcm) * constraint_lcm
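
Illustration only, not part of the patch: a hedged sketch of the padded block count computed above. It assumes `TRTLLM_BLOCK_CONSTRAINT == 128` (consistent with the `128 / page_size` comment and with the test change below) and the 4096-token Triton block size from `utils.py`; the helper name `pad_block_count` is ours.

```python
import math

def pad_block_count(blocks: int, page_size: int, kv_block_size: int = 4096) -> int:
    # 1. TRT-LLM kernel: block count must be a multiple of 128 // page_size.
    trtllm_constraint = 128 // page_size
    # 2. Triton page-table builder: writes kv_block_size // page_size indices per block.
    triton_constraint = kv_block_size // page_size
    # Round up to the least common multiple of both constraints (same as triton.cdiv).
    constraint_lcm = math.lcm(trtllm_constraint, triton_constraint)
    return math.ceil(blocks / constraint_lcm) * constraint_lcm

# page_size=32: lcm(4, 128) = 128, so e.g. 100 raw blocks pad to 128
# page_size=64: lcm(2, 64)  = 64,  so e.g. 100 raw blocks pad to 128 as well
```
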
@@ -336,7 +337,6 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
block_kv_indices,
self.req_to_token.stride(0),
max_blocks,
NUM_PAGE_PER_BLOCK=TRITON_PAD_NUM_PAGE_PER_BLOCK,
PAGED_SIZE=self.page_size,
)
@@ -417,7 +417,6 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
block_kv_indices,
self.req_to_token.stride(0),
max_blocks_per_seq,
NUM_PAGE_PER_BLOCK=TRITON_PAD_NUM_PAGE_PER_BLOCK,
PAGED_SIZE=self.page_size,
)
@@ -504,7 +503,6 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
metadata.block_kv_indices,
self.req_to_token.stride(0),
metadata.block_kv_indices.shape[1],
NUM_PAGE_PER_BLOCK=TRITON_PAD_NUM_PAGE_PER_BLOCK,
PAGED_SIZE=self.page_size,
)
......
import triton
import triton.language as tl
# Keep this in sync with the Triton kernel inside `create_flashmla_kv_indices_triton`.
# Number of pages that the kernel writes per iteration.
# Exposed here so other Python modules can import it instead of hard-coding 64.
TRITON_PAD_NUM_PAGE_PER_BLOCK = 64
_FLASHMLA_CREATE_KV_BLOCK_SIZE = 4096
FLASHMLA_CREATE_KV_BLOCK_SIZE_TRITON = tl.constexpr(_FLASHMLA_CREATE_KV_BLOCK_SIZE)
@triton.jit
@@ -46,6 +44,11 @@ def create_flashinfer_kv_indices_triton(
tl.store(kv_indices_ptr + kv_indices_offset + offset, data, mask=mask)
def get_num_page_per_block_flashmla(page_size: int = 64) -> int:
num_page_per_block = _FLASHMLA_CREATE_KV_BLOCK_SIZE // page_size
return num_page_per_block
@triton.jit
def create_flashmla_kv_indices_triton(
req_to_token_ptr, # [max_batch, max_context_len]
@@ -55,10 +58,11 @@ def create_flashmla_kv_indices_triton(
kv_indices_ptr,
req_to_token_ptr_stride: tl.constexpr,
kv_indices_ptr_stride: tl.constexpr,
NUM_PAGE_PER_BLOCK: tl.constexpr = TRITON_PAD_NUM_PAGE_PER_BLOCK,
PAGED_SIZE: tl.constexpr = 64,
):
BLOCK_SIZE: tl.constexpr = 4096
NUM_PAGE_PER_BLOCK: tl.constexpr = (
FLASHMLA_CREATE_KV_BLOCK_SIZE_TRITON // PAGED_SIZE
)
pid = tl.program_id(axis=0)
# find the req pool idx; this maps the batch position to its row in req_to_token
@@ -73,7 +77,7 @@ def create_flashmla_kv_indices_triton(
kv_end += tl.load(page_kernel_lens_ptr + pid).to(tl.int32)
num_paged = tl.cdiv(kv_end - kv_start, PAGED_SIZE)
num_pages_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE)
num_pages_loop = tl.cdiv(kv_end - kv_start, FLASHMLA_CREATE_KV_BLOCK_SIZE_TRITON)
for i in range(num_pages_loop):
# index into req_to_token_ptr needs to be int64
......
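
Illustration only, not part of the patch: the Python helper and the kernel's constexpr are now derived from the same 4096-token block size, which is the crux of the page_size=32 fix. A self-contained mirror of that relationship, using the values from the `utils.py` hunks above:

```python
_FLASHMLA_CREATE_KV_BLOCK_SIZE = 4096  # tokens per Triton block, as in the hunk above

def get_num_page_per_block_flashmla(page_size: int = 64) -> int:
    # Mirrors the helper added above; equals the kernel's derived NUM_PAGE_PER_BLOCK constexpr.
    return _FLASHMLA_CREATE_KV_BLOCK_SIZE // page_size

assert get_num_page_per_block_flashmla(64) == 64   # matches the old hard-coded value
assert get_num_page_per_block_flashmla(32) == 128  # the page_size=32 case this commit fixes
```

Call sites in the backend (the three hunks above) therefore drop the explicit `NUM_PAGE_PER_BLOCK=` keyword and pass only `PAGED_SIZE=self.page_size`.
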
@@ -16,10 +16,15 @@ from sglang.srt.layers.attention.trtllm_mla_backend import (
TRTLLMMLABackend,
TRTLLMMLADecodeMetadata,
)
from sglang.srt.layers.attention.utils import TRITON_PAD_NUM_PAGE_PER_BLOCK
from sglang.srt.layers.attention.utils import get_num_page_per_block_flashmla
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.mem_cache.memory_pool import MLATokenToKVPool
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.server_args import (
ServerArgs,
get_global_server_args,
set_global_server_args_for_scheduler,
)
from sglang.srt.utils import is_flashinfer_available
from sglang.test.test_utils import CustomTestCase
@@ -104,15 +109,15 @@ TEST_CASES = {
"page_size": 32,
"description": "Single FP16 vs reference",
},
{
"name": "single_fp8",
"batch_size": 1,
"max_seq_len": 64,
"page_size": 64,
"tolerance": 1e-1,
"kv_cache_dtype": torch.float8_e4m3fn,
"description": "Single FP8 vs reference",
},
# {
# "name": "single_fp8",
# "batch_size": 1,
# "max_seq_len": 64,
# "page_size": 64,
# "tolerance": 1e-1,
# "kv_cache_dtype": torch.float8_e4m3fn,
# "description": "Single FP8 vs reference",
# },
{
"name": "batch_fp16",
"batch_size": 32,
@@ -120,15 +125,15 @@ TEST_CASES = {
"page_size": 32,
"description": "Batch FP16 vs reference",
},
{
"name": "batch_fp8",
"batch_size": 32,
"max_seq_len": 64,
"page_size": 64,
"tolerance": 1e-1,
"kv_cache_dtype": torch.float8_e4m3fn,
"description": "Batch FP8 vs reference",
},
# {
# "name": "batch_fp8",
# "batch_size": 32,
# "max_seq_len": 64,
# "page_size": 64,
# "tolerance": 1e-1,
# "kv_cache_dtype": torch.float8_e4m3fn,
# "description": "Batch FP8 vs reference",
# },
],
"page_size_consistency": [
# Only 32 and 64 supported for now in flashinfer TRTLLM-GEN MLA kernel
@@ -213,13 +218,7 @@ class MockModelRunner:
self.page_size = config["page_size"]
# Server args stub - needed by attention backends
self.server_args = type(
"ServerArgs",
(),
{
"enable_dp_attention": False, # Default value for testing
},
)
self.server_args = get_global_server_args()
# Model-config stub with MLA attributes
self.model_config = type(
@@ -320,6 +319,17 @@ def compare_outputs(trtllm_out, reference_out, tolerance=1e-2):
class TestTRTLLMMLA(CustomTestCase):
"""Test suite for TRTLLM MLA backend with centralized configuration."""
@classmethod
def setUpClass(cls):
"""Set up global server args for testing."""
server_args = ServerArgs(model_path="dummy")
server_args.enable_dp_attention = False
set_global_server_args_for_scheduler(server_args)
@classmethod
def tearDownClass(cls):
pass
def _merge_config(self, test_case):
"""Merge test case with default configuration."""
config = DEFAULT_CONFIG.copy()
@@ -841,25 +851,17 @@ class TestTRTLLMMLA(CustomTestCase):
backend.init_forward_metadata(fb)
# Verify metadata exists
self.assertIsNotNone(backend.forward_metadata)
self.assertIsInstance(backend.forward_metadata, TRTLLMMLADecodeMetadata)
self.assertIsNotNone(backend.forward_decode_metadata)
self.assertIsInstance(
backend.forward_decode_metadata, TRTLLMMLADecodeMetadata
)
# Test metadata structure
metadata = backend.forward_metadata
self.assertIsNotNone(
metadata.workspace, "Workspace should be allocated"
)
metadata = backend.forward_decode_metadata
self.assertIsNotNone(
metadata.block_kv_indices, "Block KV indices should be created"
)
# Test workspace properties
self.assertEqual(metadata.workspace.device.type, "cuda")
self.assertEqual(metadata.workspace.dtype, torch.uint8)
self.assertGreater(
metadata.workspace.numel(), 0, "Workspace should have non-zero size"
)
# Test block KV indices properties
self.assertEqual(metadata.block_kv_indices.device.type, "cuda")
self.assertEqual(metadata.block_kv_indices.dtype, torch.int32)
@@ -915,9 +917,10 @@ class TestTRTLLMMLA(CustomTestCase):
# Should satisfy TRT-LLM and Triton constraints
trtllm_constraint = 128 // scenario["page_size"]
constraint_lcm = math.lcm(
trtllm_constraint, TRITON_PAD_NUM_PAGE_PER_BLOCK
triton_constraint = get_num_page_per_block_flashmla(
scenario["page_size"]
)
constraint_lcm = math.lcm(trtllm_constraint, triton_constraint)
self.assertEqual(
calculated_blocks % constraint_lcm,
0,
@@ -965,7 +968,7 @@ class TestTRTLLMMLA(CustomTestCase):
# Initialize metadata
backend.init_forward_metadata(fb)
metadata = backend.forward_metadata
metadata = backend.forward_decode_metadata
# Verify KV indices structure
block_kv_indices = metadata.block_kv_indices
@@ -1016,7 +1019,6 @@ class TestTRTLLMMLA(CustomTestCase):
# Verify CUDA graph buffers are allocated
self.assertIsNotNone(backend.decode_cuda_graph_kv_indices)
self.assertIsNotNone(backend.decode_cuda_graph_workspace)
# Test capture metadata
seq_lens = torch.full(
@@ -1038,7 +1040,6 @@ class TestTRTLLMMLA(CustomTestCase):
self.assertIn(batch_size, backend.decode_cuda_graph_metadata)
capture_metadata = backend.decode_cuda_graph_metadata[batch_size]
self.assertIsNotNone(capture_metadata.workspace)
self.assertIsNotNone(capture_metadata.block_kv_indices)
# Test replay with different sequence lengths
@@ -1061,11 +1062,8 @@ class TestTRTLLMMLA(CustomTestCase):
)
# Verify replay updated the metadata
replay_metadata = backend.forward_metadata
replay_metadata = backend.forward_decode_metadata
self.assertIsNotNone(replay_metadata)
self.assertEqual(
replay_metadata.workspace.data_ptr(), capture_metadata.workspace.data_ptr()
)
def test_metadata_consistency_across_calls(self):
"""Test metadata consistency across multiple forward calls."""
@@ -1083,7 +1081,7 @@ class TestTRTLLMMLA(CustomTestCase):
config["batch_size"], seq_lens_1, backend, model_runner, config
)
backend.init_forward_metadata(fb_1)
metadata_1 = backend.forward_metadata
metadata_1 = backend.forward_decode_metadata
# Second call with same sequence lengths
seq_lens_2 = torch.tensor([32, 48], device=config["device"])
@@ -1091,10 +1089,9 @@ class TestTRTLLMMLA(CustomTestCase):
config["batch_size"], seq_lens_2, backend, model_runner, config
)
backend.init_forward_metadata(fb_2)
metadata_2 = backend.forward_metadata
metadata_2 = backend.forward_decode_metadata
# Metadata structure should be consistent
self.assertEqual(metadata_1.workspace.shape, metadata_2.workspace.shape)
self.assertEqual(
metadata_1.block_kv_indices.shape, metadata_2.block_kv_indices.shape
)
@@ -1105,10 +1102,9 @@ class TestTRTLLMMLA(CustomTestCase):
config["batch_size"], seq_lens_3, backend, model_runner, config
)
backend.init_forward_metadata(fb_3)
metadata_3 = backend.forward_metadata
metadata_3 = backend.forward_decode_metadata
# Should still have valid structure
self.assertIsNotNone(metadata_3.workspace)
self.assertIsNotNone(metadata_3.block_kv_indices)
self.assertEqual(metadata_3.block_kv_indices.shape[0], config["batch_size"])
......