"docs/vscode:/vscode.git/clone" did not exist on "21e61eb3a9d16a46245bd284fea3aa19e66772f5"
Unverified commit 804d9f2e authored by Yubo Wang, committed by GitHub

Add unit test for page_size > 1 and MLA, and integration test for Flash Attention 3 (#4760)

parent a7c3f74b
@@ -548,8 +548,9 @@ class FlashAttentionBackend(AttentionBackend):
         # Use Flash Attention for prefill
         if not self.use_mla:
             # Do multi-head attention
-            kv_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
-            key_cache, value_cache = kv_cache[0], kv_cache[1]
+            key_cache, value_cache = forward_batch.token_to_kv_pool.get_kv_buffer(
+                layer.layer_id
+            )
             key_cache = key_cache.view(
                 -1, self.page_size, layer.tp_k_head_num, layer.head_dim
             )
@@ -592,7 +593,6 @@ class FlashAttentionBackend(AttentionBackend):
             c_kv_cache = c_kv.view(
                 -1, self.page_size, layer.tp_v_head_num, layer.v_head_dim
             )
-
             q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
             q_nope = q_all[:, :, : layer.v_head_dim]
             q_rope = q_all[:, :, layer.v_head_dim :]
@@ -659,8 +659,10 @@ class FlashAttentionBackend(AttentionBackend):
         if not self.use_mla:
             # Do multi-head attention
-            kv_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
-            key_cache, value_cache = kv_cache[0], kv_cache[1]
+            key_cache, value_cache = forward_batch.token_to_kv_pool.get_kv_buffer(
+                layer.layer_id
+            )
             key_cache = key_cache.view(
                 -1, self.page_size, layer.tp_k_head_num, layer.head_dim
             )
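Both hunks above rely on the paged KV-cache layout: get_kv_buffer() now returns the key and value caches as two separate tensors, and each is viewed as (num_pages, page_size, num_kv_heads, head_dim) before being handed to the FA3 kernel. Below is a minimal standalone sketch of that reshape, using made-up sizes rather than sglang's pool classes.

import torch

# Hypothetical sizes, for illustration only.
page_size = 4
num_pages = 8
num_kv_heads = 2
head_dim = 16

# Flat caches indexed by token slot, like the tensors returned by get_kv_buffer().
key_cache = torch.zeros(num_pages * page_size, num_kv_heads, head_dim)
value_cache = torch.zeros(num_pages * page_size, num_kv_heads, head_dim)

# View them page-first so whole pages can be addressed through a block table;
# with page_size == 1 this degenerates to one token per page.
key_cache = key_cache.view(-1, page_size, num_kv_heads, head_dim)
value_cache = value_cache.view(-1, page_size, num_kv_heads, head_dim)
assert key_cache.shape == (num_pages, page_size, num_kv_heads, head_dim)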
......
@@ -63,10 +63,6 @@ from sglang.srt.layers.quantization.modelopt_quant import ModelOptFp8Config
 from sglang.srt.layers.quantization.moe_wna16 import MoeWNA16Config
 from sglang.srt.layers.quantization.w8a8_fp8 import W8A8Fp8Config
 from sglang.srt.layers.quantization.w8a8_int8 import W8A8Int8Config
-from sglang.srt.layers.vocab_parallel_embedding import (
-    ParallelLMHead,
-    UnquantizedEmbeddingMethod,
-)

 # Base quantization methods that don't depend on vllm
 BASE_QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
@@ -176,6 +172,13 @@ def get_linear_quant_method(
     prefix: str,
     linear_method_cls: type,
 ):
+    # Move import here to avoid circular import. This is only used in monkey patching
+    # of vllm's QuantizationConfig.
+    from sglang.srt.layers.vocab_parallel_embedding import (
+        ParallelLMHead,
+        UnquantizedEmbeddingMethod,
+    )
+
     cloned_config = deepcopy(config)
     parallel_lm_head_quantized = (
         isinstance(layer, ParallelLMHead) and cloned_config.lm_head_quantized
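For context, this is the standard deferred-import pattern: the dependency is resolved at call time instead of module-import time, so both modules can finish loading regardless of which one is imported first. A minimal sketch follows, under the assumption (inferred from the comment in the hunk, not shown in this diff) that vocab_parallel_embedding in turn imports from the quantization code; the helper name is made up for illustration.

def lm_head_is_quantized(config, layer):
    # Deferred import: by the time this function runs, both modules are fully
    # initialized, so the module-level import cycle never materializes.
    from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead

    return isinstance(layer, ParallelLMHead) and getattr(config, "lm_head_quantized", False)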
......
import unittest
import torch
from sglang.srt.configs.model_config import AttentionArch
from sglang.srt.layers.attention.flashattention_backend import FlashAttentionBackend
from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBackend
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.mem_cache.memory_pool import MLATokenToKVPool
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.test.test_utils import CustomTestCase
class MockModelRunner:
def __init__(
self,
kv_lora_rank,
qk_rope_head_dim,
):
attention_arch = AttentionArch.MLA
self.device = "cuda"
self.dtype = torch.float16
context_len = 2048
self.model_config = type(
"ModelConfig",
(),
{
"context_len": context_len,
"attention_arch": attention_arch,
},
)
self.sliding_window_size = None
batch_size = 160
# Create a proper req_to_token_pool with the req_to_token attribute
self.req_to_token_pool = type(
"TokenPool",
(),
{
                # Maximum number of request slots tracked by the pool (max_bs)
"size": batch_size,
# Add req_to_token attribute
"req_to_token": torch.zeros(
batch_size, context_len, dtype=torch.int32, device=self.device
),
},
)
self.page_size = 1
max_total_num_tokens = batch_size * context_len
self.token_to_kv_pool = MLATokenToKVPool(
size=max_total_num_tokens,
page_size=self.page_size,
dtype=self.dtype,
kv_lora_rank=kv_lora_rank,
qk_rope_head_dim=qk_rope_head_dim,
layer_num=1, # only consider layer=1 for unit test
device=self.device,
enable_memory_saver=False,
)
class MockReqToTokenPool:
def __init__(self, batch_size, seq_len, device):
self.req_to_token = (
torch.arange(batch_size * seq_len, device=device)
.reshape(batch_size, seq_len)
.to(torch.int32)
)
@unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA")
class TestFlashAttentionMLABackend(CustomTestCase):
def setUp(self):
# Test parameters
self.batch_size = 2
self.seq_len = 360
self.num_heads = 2
self.device = "cuda"
self.dtype = torch.float16
self.kv_lora_rank = 512
self.q_lora_rank = 128
self.qk_rope_head_dim = 64
self.qk_head_dim = self.qk_rope_head_dim + self.kv_lora_rank
# Assume no rope scaling
self.scaling = self.qk_head_dim**-0.5
# Initialize model runner and backend
self._init_model_runner()
self.backend = FlashAttentionBackend(self.model_runner)
self.num_local_heads = 2
def _init_model_runner(self):
self.model_runner = MockModelRunner(
kv_lora_rank=self.kv_lora_rank,
qk_rope_head_dim=self.qk_rope_head_dim,
)
self.backend = FlashAttentionBackend(self.model_runner)
def _create_attention_layer(self):
"""Create attention layer for testing."""
self.attn_mqa = RadixAttention(
num_heads=self.num_local_heads,
head_dim=self.kv_lora_rank + self.qk_rope_head_dim,
scaling=self.scaling,
num_kv_heads=1,
layer_id=0,
v_head_dim=self.kv_lora_rank,
prefix="attn_mqa",
)
return self.attn_mqa
def _run_reference_forward(
self, mode, q, k, v, layer, forward_batch, expected_shape
):
"""Run reference forward pass using native backend."""
if mode == ForwardMode.EXTEND:
output = self.ref_backend.forward_extend(q, k, v, layer, forward_batch)
else: # ForwardMode.DECODE
output = self.ref_backend.forward_decode(q, k, v, layer, forward_batch)
return output.view(expected_shape)
def _verify_output(self, output, expected_shape):
"""Verify output tensor shape, dtype, and values."""
self.assertEqual(
output.shape,
expected_shape,
f"Expected shape {expected_shape}, got {output.shape}",
)
self.assertEqual(output.dtype, self.dtype)
self.assertEqual(output.device.type, "cuda")
self.assertEqual(
torch.isnan(output).sum().item(), 0, "Output contains NaN values"
)
def _create_forward_batch(self, mode, q_len=None, prefix_len=0):
"""Create a forward batch for testing based on mode and lengths."""
# Default to self.seq_len if not specified
q_len = q_len or self.seq_len
if mode == ForwardMode.EXTEND:
total_len = prefix_len + q_len
out_cache_start = prefix_len * self.batch_size
out_cache_end = total_len * self.batch_size
forward_batch = ForwardBatch(
batch_size=self.batch_size,
input_ids=torch.randint(
0, 100, (self.batch_size, q_len), device=self.device
),
out_cache_loc=torch.arange(
out_cache_start, out_cache_end, device=self.device
),
seq_lens_sum=self.batch_size * total_len,
forward_mode=mode,
req_pool_indices=torch.arange(self.batch_size, device=self.device),
seq_lens=torch.tensor(
[total_len] * self.batch_size, device=self.device
),
seq_lens_cpu=torch.tensor([total_len] * self.batch_size, device="cpu"),
extend_prefix_lens=torch.tensor(
[prefix_len] * self.batch_size, device=self.device
),
extend_prefix_lens_cpu=torch.tensor(
[prefix_len] * self.batch_size, device="cpu"
),
extend_seq_lens=torch.tensor(
[q_len] * self.batch_size, device=self.device
),
extend_seq_lens_cpu=torch.tensor(
[q_len] * self.batch_size, device="cpu"
),
attn_backend=self.backend,
)
else: # ForwardMode.DECODE
decode_len = q_len # typically 1 for decode mode
total_len = self.seq_len + decode_len
out_cache_start = self.batch_size * self.seq_len
out_cache_end = self.batch_size * total_len
forward_batch = ForwardBatch(
batch_size=self.batch_size,
input_ids=torch.randint(
0, 100, (self.batch_size, decode_len), device=self.device
),
out_cache_loc=torch.arange(
out_cache_start, out_cache_end, device=self.device
),
seq_lens_sum=self.batch_size * total_len,
forward_mode=mode,
req_pool_indices=torch.arange(self.batch_size, device=self.device),
seq_lens=torch.tensor(
[total_len] * self.batch_size, device=self.device
),
seq_lens_cpu=torch.tensor([total_len] * self.batch_size, device="cpu"),
attn_backend=self.backend,
)
# Add token pool from model runner to forward batch
forward_batch.req_to_token_pool = self.model_runner.req_to_token_pool
# Add KV cache from model runner to forward batch
forward_batch.token_to_kv_pool = self.model_runner.token_to_kv_pool
return forward_batch
def _setup_kv_cache(self, forward_batch, layer, cache_len):
"""Set up KV cache with prefix tokens."""
if cache_len <= 0:
return
# Create constant values for the prefix cache for easy debugging
latent_cache = torch.ones(
self.batch_size * cache_len,
1, # latent cache has only one head in MQA
self.kv_lora_rank + self.qk_rope_head_dim,
dtype=self.dtype,
device=self.device,
)
# Set the prefix KV cache
forward_batch.token_to_kv_pool.set_kv_buffer(
layer,
torch.arange(self.batch_size * cache_len, device=self.device),
latent_cache,
None,
)
def _run_attention_test(self, mode, q_len, prefix_len=0):
"""
Run an attention test with the specified parameters.
Args:
mode: ForwardMode.EXTEND or ForwardMode.DECODE
q_len: Length of the query sequence. For decode mode, q_len is 1.
prefix_len: Length of the prefix sequence for extend mode
"""
layer = self._create_attention_layer()
# Create forward batch and set up
forward_batch = self._create_forward_batch(mode, q_len, prefix_len)
# Create q, kv_compressed for testing
q_shape = (self.batch_size * q_len, self.num_heads, self.qk_head_dim)
kv_shape = (self.batch_size * q_len, self.qk_head_dim)
q = torch.randn(q_shape, dtype=self.dtype, device=self.device)
kv_compressed = torch.randn(kv_shape, dtype=self.dtype, device=self.device)
        # v is not used for MQA; all values are passed in through k
k = kv_compressed.unsqueeze(1)
v = torch.randn((1), dtype=self.dtype, device=self.device)
self._setup_kv_cache(forward_batch, layer, prefix_len)
self.backend.init_forward_metadata(forward_batch)
expected_shape = (
self.batch_size * q_len,
self.num_heads * self.kv_lora_rank,
)
if mode == ForwardMode.EXTEND:
output = self.backend.forward_extend(q, k, v, layer, forward_batch)
else:
output = self.backend.forward_decode(q, k, v, layer, forward_batch)
self._verify_output(output, expected_shape)
return output
def test_forward_extend(self):
"""Test the standard extend operation."""
self._run_attention_test(ForwardMode.EXTEND, q_len=self.seq_len)
def test_forward_decode(self):
"""Test the decode operation with cached tokens."""
self._run_attention_test(ForwardMode.DECODE, q_len=1)
def test_forward_extend_with_prefix(self):
"""Test extending from cached prefix tokens."""
prefix_len = self.seq_len // 2
extend_len = self.seq_len - prefix_len
self._run_attention_test(
ForwardMode.EXTEND, q_len=extend_len, prefix_len=prefix_len
)
if __name__ == "__main__":
unittest.main()
@@ -28,6 +28,7 @@ suites = {
     TestFile("test_chunked_prefill.py", 336),
     TestFile("test_eagle_infer.py", 500),
     TestFile("test_ebnf_constrained.py"),
+    TestFile("test_fa3.py", 5),
     TestFile("test_fp8_kernel.py", 8),
     TestFile("test_embedding_openai_server.py", 36),
     TestFile("test_hidden_states.py", 55),
......
import unittest
from types import SimpleNamespace
import requests
import torch
from sglang.srt.utils import get_device_sm, kill_process_tree
from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
from sglang.test.test_utils import (
DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
DEFAULT_MLA_MODEL_NAME_FOR_TEST,
DEFAULT_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
popen_launch_server,
)
"""
Integration test for python/sglang/srt/layers/attention/flashattention_backend.py
"""
# Change to your own model if the default test model is not public.
MODEL_USED_FOR_TEST = DEFAULT_MODEL_NAME_FOR_TEST
MODEL_USED_FOR_TEST_MLA = DEFAULT_MLA_MODEL_NAME_FOR_TEST
# Setting data path to None uses default data path in few_shot_gsm8k eval test.
DATA_PATH = None
@unittest.skipIf(get_device_sm() < 90, "Test requires CUDA SM 90 or higher")
class BaseFlashAttentionTest(unittest.TestCase):
"""Base class for FlashAttention tests to reduce code duplication."""
model = MODEL_USED_FOR_TEST
base_url = DEFAULT_URL_FOR_TEST
accuracy_threshold = 0.62
@classmethod
def get_server_args(cls):
"""Return the arguments for the server launch. Override in subclasses."""
args = [
"--trust-remote-code",
"--enable-torch-compile",
"--attention-backend",
"fa3",
]
return args
@classmethod
def setUpClass(cls):
cls.process = popen_launch_server(
cls.model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=cls.get_server_args(),
)
@classmethod
def tearDownClass(cls):
kill_process_tree(cls.process.pid)
def test_gsm8k(self):
args = SimpleNamespace(
num_shots=5,
num_questions=200,
max_new_tokens=512,
parallel=128,
host="http://127.0.0.1",
port=int(self.base_url.split(":")[-1]),
data_path=DATA_PATH,
)
metrics = run_eval_few_shot_gsm8k(args)
print(metrics)
# Use the appropriate metric key based on the test class
metric_key = "accuracy"
self.assertGreater(metrics[metric_key], self.accuracy_threshold)
class TestFlashAttention3(BaseFlashAttentionTest):
"""Test FlashAttention3 with MLA model and CUDA graph enabled."""
@classmethod
def get_server_args(cls):
args = super().get_server_args()
args.extend(
[
"--cuda-graph-max-bs",
"2",
]
)
return args
class TestFlashAttention3DisableCudaGraph(BaseFlashAttentionTest):
"""Test FlashAttention3 with CUDA graph disabled."""
@classmethod
def get_server_args(cls):
args = super().get_server_args()
args.extend(
[
"--disable-cuda-graph",
]
)
return args
class TestFlashAttention3MLA(BaseFlashAttentionTest):
"""Test FlashAttention3 with MLA."""
model = MODEL_USED_FOR_TEST_MLA
@classmethod
def get_server_args(cls):
args = super().get_server_args()
args.extend(
[
"--cuda-graph-max-bs",
"2",
]
)
return args
class TestFlashAttention3SpeculativeDecode(BaseFlashAttentionTest):
"""Test FlashAttention3 with speculative decode enabled."""
model = DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST
@classmethod
def get_server_args(cls):
args = super().get_server_args()
args.extend(
[
"--cuda-graph-max-bs",
"2",
"--speculative-algorithm",
"EAGLE3",
"--speculative-draft",
DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
"--speculative-num-steps",
"3",
"--speculative-eagle-topk",
"1",
"--speculative-num-draft-tokens",
"3",
"--dtype",
"float16",
]
)
return args
def test_gsm8k(self):
"""
        Override test_gsm8k to additionally check the average speculative accept length.
"""
requests.get(self.base_url + "/flush_cache")
args = SimpleNamespace(
num_shots=5,
data_path=DATA_PATH,
num_questions=200,
max_new_tokens=512,
parallel=128,
host="http://127.0.0.1",
port=int(self.base_url.split(":")[-1]),
)
metrics = run_eval_few_shot_gsm8k(args)
print(metrics)
self.assertGreater(metrics["accuracy"], 0.60)
server_info = requests.get(self.base_url + "/get_server_info")
avg_spec_accept_length = server_info.json()["avg_spec_accept_length"]
print(f"{avg_spec_accept_length=}")
self.assertGreater(avg_spec_accept_length, 1.5)
if __name__ == "__main__":
unittest.main()
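For reference, the server launch that BaseFlashAttentionTest automates via popen_launch_server roughly corresponds to the sketch below; the entry point name, model path, and port are assumptions and should be adapted to your environment.

import subprocess

# Hypothetical manual launch mirroring get_server_args() in the base test class.
cmd = [
    "python", "-m", "sglang.launch_server",   # assumed entry point
    "--model-path", "YOUR_MODEL_PATH",        # placeholder
    "--port", "30000",                        # placeholder
    "--trust-remote-code",
    "--enable-torch-compile",
    "--attention-backend", "fa3",
]
server = subprocess.Popen(cmd)
# ... run the GSM8K eval against the launched server, then shut it down:
server.terminate()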