sglang · Commit 804d9f2e (unverified)
Authored Apr 07, 2025 by Yubo Wang; committed via GitHub on Apr 07, 2025

Add unit test on page_size > 1 and mla and integration test for Flash Attention 3 (#4760)

Parent: a7c3f74b
Showing 6 changed files with 739 additions and 230 deletions.
Changed files:
- python/sglang/srt/layers/attention/flashattention_backend.py (+7, -5)
- python/sglang/srt/layers/quantization/__init__.py (+7, -4)
- python/sglang/test/attention/test_flashattn_backend.py (+259, -221)
- python/sglang/test/attention/test_flashattn_mla_backend.py (+285, -0)
- test/srt/run_suite.py (+1, -0)
- test/srt/test_fa3.py (+180, -0)
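The two new test files can be exercised directly once this change is checked out. A minimal sketch of a local run, assuming sglang is installed and a CUDA GPU is available (the unittest invocation is an illustration, not part of this commit):

import unittest

# Load the new MLA unit test by dotted module name. python/sglang is the package
# root, so python/sglang/test/attention/test_flashattn_mla_backend.py is assumed
# to be importable under the name below.
loader = unittest.defaultTestLoader
suite = loader.loadTestsFromName("sglang.test.attention.test_flashattn_mla_backend")
unittest.TextTestRunner(verbosity=2).run(suite)

The integration test in test/srt/test_fa3.py additionally requires an SM90 (Hopper) GPU and launches a full server via popen_launch_server; it is also registered in test/srt/run_suite.py (see the hunk further below).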
python/sglang/srt/layers/attention/flashattention_backend.py
@@ -548,8 +548,9 @@ class FlashAttentionBackend(AttentionBackend):
         # Use Flash Attention for prefill
         if not self.use_mla:
             # Do multi-head attention
-            kv_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
-            key_cache, value_cache = kv_cache[0], kv_cache[1]
+            key_cache, value_cache = forward_batch.token_to_kv_pool.get_kv_buffer(
+                layer.layer_id
+            )
             key_cache = key_cache.view(
                 -1, self.page_size, layer.tp_k_head_num, layer.head_dim
             )
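The reshape in this hunk is the heart of page_size > 1 support: the flat, token-indexed KV buffer is viewed as pages of page_size token slots, so the attention kernel can gather whole pages at a time. A standalone sketch of the idea, with illustrative shapes and names rather than the backend's real buffers:

import torch

# Illustrative dimensions only.
num_tokens, page_size, num_kv_heads, head_dim = 32, 4, 2, 8

# A flat, token-indexed key buffer, as a token_to_kv_pool-style allocator hands out.
key_buffer = torch.randn(num_tokens, num_kv_heads, head_dim)

# Viewing it as (num_pages, page_size, heads, head_dim) adds a page dimension
# without copying any data.
key_pages = key_buffer.view(-1, page_size, num_kv_heads, head_dim)
assert key_pages.shape == (num_tokens // page_size, page_size, num_kv_heads, head_dim)

# A flat token slot t lives at page t // page_size, offset t % page_size.
t = 13
assert torch.equal(key_pages[t // page_size, t % page_size], key_buffer[t])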
@@ -592,7 +593,6 @@ class FlashAttentionBackend(AttentionBackend):
             c_kv_cache = c_kv.view(
                 -1, self.page_size, layer.tp_v_head_num, layer.v_head_dim
             )
             q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
             q_nope = q_all[:, :, : layer.v_head_dim]
             q_rope = q_all[:, :, layer.v_head_dim :]
@@ -659,8 +659,10 @@ class FlashAttentionBackend(AttentionBackend):
         if not self.use_mla:
             # Do multi-head attention
-            kv_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
-            key_cache, value_cache = kv_cache[0], kv_cache[1]
+            key_cache, value_cache = forward_batch.token_to_kv_pool.get_kv_buffer(
+                layer.layer_id
+            )
+            key_cache = key_cache.view(
+                -1, self.page_size, layer.tp_k_head_num, layer.head_dim
+            )
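The decode path receives the same paged view. Once page_size > 1, the per-request token slots recorded in req_to_token must be translated into page indices before pages can be gathered. A small illustrative sketch of that translation, assuming (as the unit tests below do) that each request's slots are contiguous and page-aligned; names and shapes are illustrative, not the backend's actual metadata code:

import torch

page_size = 4
# req_to_token[i, j] = flat KV slot of the j-th token of request i.
req_to_token = torch.arange(2 * 16, dtype=torch.int32).reshape(2, 16)

# Taking every page_size-th slot and dividing by page_size yields, per request,
# the ordered list of pages its tokens occupy.
page_table = req_to_token[:, ::page_size] // page_size
print(page_table)
# tensor([[0, 1, 2, 3],
#         [4, 5, 6, 7]], dtype=torch.int32)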
python/sglang/srt/layers/quantization/__init__.py
@@ -63,10 +63,6 @@ from sglang.srt.layers.quantization.modelopt_quant import ModelOptFp8Config
 from sglang.srt.layers.quantization.moe_wna16 import MoeWNA16Config
 from sglang.srt.layers.quantization.w8a8_fp8 import W8A8Fp8Config
 from sglang.srt.layers.quantization.w8a8_int8 import W8A8Int8Config
-from sglang.srt.layers.vocab_parallel_embedding import (
-    ParallelLMHead,
-    UnquantizedEmbeddingMethod,
-)

 # Base quantization methods that don't depend on vllm
 BASE_QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
@@ -176,6 +172,13 @@ def get_linear_quant_method(
     prefix: str,
     linear_method_cls: type,
 ):
+    # Move import here to avoid circular import. This is only used in monkey patching
+    # of vllm's QuantizationConfig.
+    from sglang.srt.layers.vocab_parallel_embedding import (
+        ParallelLMHead,
+        UnquantizedEmbeddingMethod,
+    )
+
     cloned_config = deepcopy(config)
     parallel_lm_head_quantized = (
         isinstance(layer, ParallelLMHead) and cloned_config.lm_head_quantized
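The comment in the added lines states the motivation: a module-level import of ParallelLMHead would re-create a circular import, so the import now happens inside the function and is only resolved at call time. A minimal two-file sketch of the same pattern, with hypothetical module names unrelated to sglang's real layout:

# registry.py (hypothetical)
# from models import Head   # a module-level import here would close the cycle

def get_method(layer):
    # Deferred import: by the time get_method() is called, `models` has finished
    # initializing, so the registry <-> models cycle never triggers at import time.
    from models import Head

    return "lm_head" if isinstance(layer, Head) else "linear"

# models.py (hypothetical)
# import registry            # the other half of the cycle
# class Head: ...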
python/sglang/test/attention/test_flashattn_backend.py
This diff is collapsed.
python/sglang/test/attention/test_flashattn_mla_backend.py
(new file, mode 100644)
import unittest

import torch

from sglang.srt.configs.model_config import AttentionArch
from sglang.srt.layers.attention.flashattention_backend import FlashAttentionBackend
from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBackend
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.mem_cache.memory_pool import MLATokenToKVPool
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.test.test_utils import CustomTestCase


class MockModelRunner:
    def __init__(
        self,
        kv_lora_rank,
        qk_rope_head_dim,
    ):
        attention_arch = AttentionArch.MLA
        self.device = "cuda"
        self.dtype = torch.float16
        context_len = 2048
        self.model_config = type(
            "ModelConfig",
            (),
            {
                "context_len": context_len,
                "attention_arch": attention_arch,
            },
        )
        self.sliding_window_size = None
        batch_size = 160
        # Create a proper req_to_token_pool with the req_to_token attribute
        self.req_to_token_pool = type(
            "TokenPool",
            (),
            {
                # A typical max_bs * max_context_len for cuda graph decode
                "size": batch_size,
                # Add req_to_token attribute
                "req_to_token": torch.zeros(
                    batch_size, context_len, dtype=torch.int32, device=self.device
                ),
            },
        )
        self.page_size = 1
        max_total_num_tokens = batch_size * context_len
        self.token_to_kv_pool = MLATokenToKVPool(
            size=max_total_num_tokens,
            page_size=self.page_size,
            dtype=self.dtype,
            kv_lora_rank=kv_lora_rank,
            qk_rope_head_dim=qk_rope_head_dim,
            layer_num=1,  # only consider layer=1 for unit test
            device=self.device,
            enable_memory_saver=False,
        )


class MockReqToTokenPool:
    def __init__(self, batch_size, seq_len, device):
        self.req_to_token = (
            torch.arange(batch_size * seq_len, device=device)
            .reshape(batch_size, seq_len)
            .to(torch.int32)
        )


@unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA")
class TestFlashAttentionMLABackend(CustomTestCase):
    def setUp(self):
        # Test parameters
        self.batch_size = 2
        self.seq_len = 360
        self.num_heads = 2
        self.device = "cuda"
        self.dtype = torch.float16
        self.kv_lora_rank = 512
        self.q_lora_rank = 128
        self.qk_rope_head_dim = 64
        self.qk_head_dim = self.qk_rope_head_dim + self.kv_lora_rank
        # Assume no rope scaling
        self.scaling = self.qk_head_dim**-0.5
        # Initialize model runner and backend
        self._init_model_runner()
        self.backend = FlashAttentionBackend(self.model_runner)
        self.num_local_heads = 2

    def _init_model_runner(self):
        self.model_runner = MockModelRunner(
            kv_lora_rank=self.kv_lora_rank,
            qk_rope_head_dim=self.qk_rope_head_dim,
        )
        self.backend = FlashAttentionBackend(self.model_runner)

    def _create_attention_layer(self):
        """Create attention layer for testing."""
        self.attn_mqa = RadixAttention(
            num_heads=self.num_local_heads,
            head_dim=self.kv_lora_rank + self.qk_rope_head_dim,
            scaling=self.scaling,
            num_kv_heads=1,
            layer_id=0,
            v_head_dim=self.kv_lora_rank,
            prefix="attn_mqa",
        )
        return self.attn_mqa

    def _run_reference_forward(
        self, mode, q, k, v, layer, forward_batch, expected_shape
    ):
        """Run reference forward pass using native backend."""
        if mode == ForwardMode.EXTEND:
            output = self.ref_backend.forward_extend(q, k, v, layer, forward_batch)
        else:  # ForwardMode.DECODE
            output = self.ref_backend.forward_decode(q, k, v, layer, forward_batch)
        return output.view(expected_shape)

    def _verify_output(self, output, expected_shape):
        """Verify output tensor shape, dtype, and values."""
        self.assertEqual(
            output.shape,
            expected_shape,
            f"Expected shape {expected_shape}, got {output.shape}",
        )
        self.assertEqual(output.dtype, self.dtype)
        self.assertEqual(output.device.type, "cuda")
        self.assertEqual(
            torch.isnan(output).sum().item(), 0, "Output contains NaN values"
        )

    def _create_forward_batch(self, mode, q_len=None, prefix_len=0):
        """Create a forward batch for testing based on mode and lengths."""
        # Default to self.seq_len if not specified
        q_len = q_len or self.seq_len

        if mode == ForwardMode.EXTEND:
            total_len = prefix_len + q_len
            out_cache_start = prefix_len * self.batch_size
            out_cache_end = total_len * self.batch_size

            forward_batch = ForwardBatch(
                batch_size=self.batch_size,
                input_ids=torch.randint(
                    0, 100, (self.batch_size, q_len), device=self.device
                ),
                out_cache_loc=torch.arange(
                    out_cache_start, out_cache_end, device=self.device
                ),
                seq_lens_sum=self.batch_size * total_len,
                forward_mode=mode,
                req_pool_indices=torch.arange(self.batch_size, device=self.device),
                seq_lens=torch.tensor(
                    [total_len] * self.batch_size, device=self.device
                ),
                seq_lens_cpu=torch.tensor([total_len] * self.batch_size, device="cpu"),
                extend_prefix_lens=torch.tensor(
                    [prefix_len] * self.batch_size, device=self.device
                ),
                extend_prefix_lens_cpu=torch.tensor(
                    [prefix_len] * self.batch_size, device="cpu"
                ),
                extend_seq_lens=torch.tensor(
                    [q_len] * self.batch_size, device=self.device
                ),
                extend_seq_lens_cpu=torch.tensor(
                    [q_len] * self.batch_size, device="cpu"
                ),
                attn_backend=self.backend,
            )
        else:  # ForwardMode.DECODE
            decode_len = q_len  # typically 1 for decode mode
            total_len = self.seq_len + decode_len
            out_cache_start = self.batch_size * self.seq_len
            out_cache_end = self.batch_size * total_len

            forward_batch = ForwardBatch(
                batch_size=self.batch_size,
                input_ids=torch.randint(
                    0, 100, (self.batch_size, decode_len), device=self.device
                ),
                out_cache_loc=torch.arange(
                    out_cache_start, out_cache_end, device=self.device
                ),
                seq_lens_sum=self.batch_size * total_len,
                forward_mode=mode,
                req_pool_indices=torch.arange(self.batch_size, device=self.device),
                seq_lens=torch.tensor(
                    [total_len] * self.batch_size, device=self.device
                ),
                seq_lens_cpu=torch.tensor([total_len] * self.batch_size, device="cpu"),
                attn_backend=self.backend,
            )

        # Add token pool from model runner to forward batch
        forward_batch.req_to_token_pool = self.model_runner.req_to_token_pool
        # Add KV cache from model runner to forward batch
        forward_batch.token_to_kv_pool = self.model_runner.token_to_kv_pool

        return forward_batch

    def _setup_kv_cache(self, forward_batch, layer, cache_len):
        """Set up KV cache with prefix tokens."""
        if cache_len <= 0:
            return

        # Create constant values for the prefix cache for easy debugging
        latent_cache = torch.ones(
            self.batch_size * cache_len,
            1,  # latent cache has only one head in MQA
            self.kv_lora_rank + self.qk_rope_head_dim,
            dtype=self.dtype,
            device=self.device,
        )

        # Set the prefix KV cache
        forward_batch.token_to_kv_pool.set_kv_buffer(
            layer,
            torch.arange(self.batch_size * cache_len, device=self.device),
            latent_cache,
            None,
        )

    def _run_attention_test(self, mode, q_len, prefix_len=0):
        """
        Run an attention test with the specified parameters.
        Args:
            mode: ForwardMode.EXTEND or ForwardMode.DECODE
            q_len: Length of the query sequence. For decode mode, q_len is 1.
            prefix_len: Length of the prefix sequence for extend mode
        """
        layer = self._create_attention_layer()

        # Create forward batch and set up
        forward_batch = self._create_forward_batch(mode, q_len, prefix_len)

        # Create q, kv_compressed for testing
        q_shape = (self.batch_size * q_len, self.num_heads, self.qk_head_dim)
        kv_shape = (self.batch_size * q_len, self.qk_head_dim)
        q = torch.randn(q_shape, dtype=self.dtype, device=self.device)
        kv_compressed = torch.randn(kv_shape, dtype=self.dtype, device=self.device)
        # v is not used for mqa, all values passed in through k
        k = kv_compressed.unsqueeze(1)
        v = torch.randn((1), dtype=self.dtype, device=self.device)

        self._setup_kv_cache(forward_batch, layer, prefix_len)

        self.backend.init_forward_metadata(forward_batch)

        expected_shape = (
            self.batch_size * q_len,
            self.num_heads * self.kv_lora_rank,
        )

        if mode == ForwardMode.EXTEND:
            output = self.backend.forward_extend(q, k, v, layer, forward_batch)
        else:
            output = self.backend.forward_decode(q, k, v, layer, forward_batch)

        self._verify_output(output, expected_shape)
        return output

    def test_forward_extend(self):
        """Test the standard extend operation."""
        self._run_attention_test(ForwardMode.EXTEND, q_len=self.seq_len)

    def test_forward_decode(self):
        """Test the decode operation with cached tokens."""
        self._run_attention_test(ForwardMode.DECODE, q_len=1)

    def test_forward_extend_with_prefix(self):
        """Test extending from cached prefix tokens."""
        prefix_len = self.seq_len // 2
        extend_len = self.seq_len - prefix_len
        self._run_attention_test(
            ForwardMode.EXTEND, q_len=extend_len, prefix_len=prefix_len
        )


if __name__ == "__main__":
    unittest.main()
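For reference, the tensor shapes exercised by the extend test follow directly from the constants in setUp above; a small worked check (plain Python, no GPU required):

# Worked shape check for the constants used in TestFlashAttentionMLABackend.
batch_size, seq_len, num_heads = 2, 360, 2
kv_lora_rank, qk_rope_head_dim = 512, 64

qk_head_dim = kv_lora_rank + qk_rope_head_dim            # 576: latent KV + rope part
q_shape = (batch_size * seq_len, num_heads, qk_head_dim)
expected_output_shape = (batch_size * seq_len, num_heads * kv_lora_rank)

assert q_shape == (720, 2, 576)
assert expected_output_shape == (720, 1024)              # kv_lora_rank dims per head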
test/srt/run_suite.py
@@ -28,6 +28,7 @@ suites = {
         TestFile("test_chunked_prefill.py", 336),
         TestFile("test_eagle_infer.py", 500),
         TestFile("test_ebnf_constrained.py"),
+        TestFile("test_fa3.py", 5),
         TestFile("test_fp8_kernel.py", 8),
         TestFile("test_embedding_openai_server.py", 36),
         TestFile("test_hidden_states.py", 55),
test/srt/test_fa3.py
(new file, mode 100644)
import unittest
from types import SimpleNamespace

import requests
import torch

from sglang.srt.utils import get_device_sm, kill_process_tree
from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
from sglang.test.test_utils import (
    DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
    DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
    DEFAULT_MLA_MODEL_NAME_FOR_TEST,
    DEFAULT_MODEL_NAME_FOR_TEST,
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
    popen_launch_server,
)

"""
Integration test for python/sglang/srt/layers/attention/flashattention_backend.py
"""

# Change to your own model if testing model is not public.
MODEL_USED_FOR_TEST = DEFAULT_MODEL_NAME_FOR_TEST
MODEL_USED_FOR_TEST_MLA = DEFAULT_MLA_MODEL_NAME_FOR_TEST
# Setting data path to None uses default data path in few_shot_gsm8k eval test.
DATA_PATH = None


@unittest.skipIf(get_device_sm() < 90, "Test requires CUDA SM 90 or higher")
class BaseFlashAttentionTest(unittest.TestCase):
    """Base class for FlashAttention tests to reduce code duplication."""

    model = MODEL_USED_FOR_TEST
    base_url = DEFAULT_URL_FOR_TEST
    accuracy_threshold = 0.62

    @classmethod
    def get_server_args(cls):
        """Return the arguments for the server launch. Override in subclasses."""
        args = [
            "--trust-remote-code",
            "--enable-torch-compile",
            "--attention-backend",
            "fa3",
        ]
        return args

    @classmethod
    def setUpClass(cls):
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=cls.get_server_args(),
        )

    @classmethod
    def tearDownClass(cls):
        kill_process_tree(cls.process.pid)

    def test_gsm8k(self):
        args = SimpleNamespace(
            num_shots=5,
            num_questions=200,
            max_new_tokens=512,
            parallel=128,
            host="http://127.0.0.1",
            port=int(self.base_url.split(":")[-1]),
            data_path=DATA_PATH,
        )
        metrics = run_eval_few_shot_gsm8k(args)
        print(metrics)

        # Use the appropriate metric key based on the test class
        metric_key = "accuracy"
        self.assertGreater(metrics[metric_key], self.accuracy_threshold)


class TestFlashAttention3(BaseFlashAttentionTest):
    """Test FlashAttention3 with MLA model and CUDA graph enabled."""

    @classmethod
    def get_server_args(cls):
        args = super().get_server_args()
        args.extend(
            [
                "--cuda-graph-max-bs",
                "2",
            ]
        )
        return args


class TestFlashAttention3DisableCudaGraph(BaseFlashAttentionTest):
    """Test FlashAttention3 with CUDA graph disabled."""

    @classmethod
    def get_server_args(cls):
        args = super().get_server_args()
        args.extend(
            [
                "--disable-cuda-graph",
            ]
        )
        return args


class TestFlashAttention3MLA(BaseFlashAttentionTest):
    """Test FlashAttention3 with MLA."""

    model = MODEL_USED_FOR_TEST_MLA

    @classmethod
    def get_server_args(cls):
        args = super().get_server_args()
        args.extend(
            [
                "--cuda-graph-max-bs",
                "2",
            ]
        )
        return args


class TestFlashAttention3SpeculativeDecode(BaseFlashAttentionTest):
    """Test FlashAttention3 with speculative decode enabled."""

    model = DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST

    @classmethod
    def get_server_args(cls):
        args = super().get_server_args()
        args.extend(
            [
                "--cuda-graph-max-bs",
                "2",
                "--speculative-algorithm",
                "EAGLE3",
                "--speculative-draft",
                DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
                "--speculative-num-steps",
                "3",
                "--speculative-eagle-topk",
                "1",
                "--speculative-num-draft-tokens",
                "3",
                "--dtype",
                "float16",
            ]
        )
        return args

    def test_gsm8k(self):
        """
        Override the test_gsm8k to further test for average speculative accept length.
        """
        requests.get(self.base_url + "/flush_cache")

        args = SimpleNamespace(
            num_shots=5,
            data_path=DATA_PATH,
            num_questions=200,
            max_new_tokens=512,
            parallel=128,
            host="http://127.0.0.1",
            port=int(self.base_url.split(":")[-1]),
        )
        metrics = run_eval_few_shot_gsm8k(args)
        print(metrics)
        self.assertGreater(metrics["accuracy"], 0.60)

        server_info = requests.get(self.base_url + "/get_server_info")
        avg_spec_accept_length = server_info.json()["avg_spec_accept_length"]
        print(f"{avg_spec_accept_length=}")
        self.assertGreater(avg_spec_accept_length, 1.5)


if __name__ == "__main__":
    unittest.main()
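The speculative-decode test relies on two HTTP endpoints that it calls itself, /flush_cache and /get_server_info. A standalone sketch of the same queries against an already-running server; the port is an assumption for illustration:

import requests

# Assumes an sglang server launched as in the tests above is listening here.
base_url = "http://127.0.0.1:30000"  # hypothetical port

# Reset the cache before measuring, as the test does.
requests.get(base_url + "/flush_cache")

# After running a workload, read the running average of accepted draft tokens
# per speculative step; the test asserts this exceeds 1.5.
info = requests.get(base_url + "/get_server_info").json()
print(info["avg_spec_accept_length"])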