sglang · Commits · 804d9f2e

Unverified commit 804d9f2e, authored Apr 07, 2025 by Yubo Wang, committed by GitHub on Apr 07, 2025.
Parent commit: a7c3f74b

Add unit test on page_size > 1 and mla and integration test for Flash Attention 3 (#4760)

Showing 6 changed files with 739 additions and 230 deletions (+739, -230).
Changed files:

  python/sglang/srt/layers/attention/flashattention_backend.py    +7    -5
  python/sglang/srt/layers/quantization/__init__.py                +7    -4
  python/sglang/test/attention/test_flashattn_backend.py         +259  -221
  python/sglang/test/attention/test_flashattn_mla_backend.py     +285    -0
  test/srt/run_suite.py                                            +1    -0
  test/srt/test_fa3.py                                           +180    -0
python/sglang/srt/layers/attention/flashattention_backend.py  (+7, -5)

@@ -548,8 +548,9 @@ class FlashAttentionBackend(AttentionBackend):
         # Use Flash Attention for prefill
         if not self.use_mla:
             # Do multi-head attention
-            key_cache, value_cache = forward_batch.token_to_kv_pool.get_kv_buffer(
-                layer.layer_id
-            )
+            kv_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
+            key_cache, value_cache = kv_cache[0], kv_cache[1]
             key_cache = key_cache.view(
                 -1, self.page_size, layer.tp_k_head_num, layer.head_dim
             )
...
@@ -592,7 +593,6 @@ class FlashAttentionBackend(AttentionBackend):
             c_kv_cache = c_kv.view(
                 -1, self.page_size, layer.tp_v_head_num, layer.v_head_dim
             )
             q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
             q_nope = q_all[:, :, : layer.v_head_dim]
             q_rope = q_all[:, :, layer.v_head_dim :]
...
@@ -659,8 +659,10 @@ class FlashAttentionBackend(AttentionBackend):
         if not self.use_mla:
             # Do multi-head attention
-            key_cache, value_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
+            kv_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
+            key_cache, value_cache = kv_cache[0], kv_cache[1]
             key_cache = key_cache.view(
                 -1, self.page_size, layer.tp_k_head_num, layer.head_dim
             )
...
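To make the reshape above concrete, here is a minimal, self-contained sketch (toy shapes chosen for illustration, not the backend's real buffers) of what viewing a flat per-token KV buffer as fixed-size pages does:

```python
import torch

# Toy dimensions; in the backend these come from the layer and server config.
num_pages, page_size, num_kv_heads, head_dim = 4, 16, 2, 8

# A flat per-token key buffer, as a token-to-KV pool might hand it out.
key_cache = torch.zeros(num_pages * page_size, num_kv_heads, head_dim)

# The view used above: group tokens into pages so the kernel can index the
# cache by (page_id, slot_in_page) instead of raw token position.
paged_key_cache = key_cache.view(-1, page_size, num_kv_heads, head_dim)
print(paged_key_cache.shape)  # torch.Size([4, 16, 2, 8])
```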
python/sglang/srt/layers/quantization/__init__.py  (+7, -4)

@@ -63,10 +63,6 @@ from sglang.srt.layers.quantization.modelopt_quant import ModelOptFp8Config
 from sglang.srt.layers.quantization.moe_wna16 import MoeWNA16Config
 from sglang.srt.layers.quantization.w8a8_fp8 import W8A8Fp8Config
 from sglang.srt.layers.quantization.w8a8_int8 import W8A8Int8Config
-from sglang.srt.layers.vocab_parallel_embedding import (
-    ParallelLMHead,
-    UnquantizedEmbeddingMethod,
-)

 # Base quantization methods that don't depend on vllm
 BASE_QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
...
@@ -176,6 +172,13 @@ def get_linear_quant_method(
     prefix: str,
     linear_method_cls: type,
 ):
+    # Move import here to avoid circular import. This is only used in monkey patching
+    # of vllm's QuantizationConfig.
+    from sglang.srt.layers.vocab_parallel_embedding import (
+        ParallelLMHead,
+        UnquantizedEmbeddingMethod,
+    )
+
     cloned_config = deepcopy(config)
     parallel_lm_head_quantized = (
         isinstance(layer, ParallelLMHead) and cloned_config.lm_head_quantized
...
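The change above is the function-local-import pattern for breaking a circular dependency at module load time. A small hedged sketch of the idea, using a standard-library module as a stand-in for the deferred sglang module (the function name here is illustrative, not the real one):

```python
def quant_method_sketch(layer_name: str) -> str:
    # Function-local import: the module is loaded only when this function runs,
    # not when this file is imported, so a cycle between the two modules at
    # import time is avoided. The commit applies the same trick to
    # sglang.srt.layers.vocab_parallel_embedding inside get_linear_quant_method.
    import json  # stand-in for the deferred module in this toy example

    return json.dumps({"layer": layer_name})


print(quant_method_sketch("lm_head"))
```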
python/sglang/test/attention/test_flashattn_backend.py  (+259, -221)

@@ -2,60 +2,109 @@ import unittest
 import torch

+from sglang.srt.configs.model_config import AttentionArch
 from sglang.srt.layers.attention.flashattention_backend import FlashAttentionBackend
+from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBackend
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.mem_cache.memory_pool import MHATokenToKVPool
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+from sglang.srt.model_executor.model_runner import ServerArgs
 from sglang.test.test_utils import CustomTestCase


 class MockModelRunner:
-    model_config = type(
-        "ModelConfig", (), {"context_len": 2048, "is_multimodal": False}
-    )
-    sliding_window_size = None
-
-    def __init__(self, device="cuda"):
-        self.device = device
-        # Create a proper req_to_token_pool with the req_to_token attribute
-        self.req_to_token_pool = type(
-            "TokenPool",
-            (),
-            {
-                "size": 160,  # a typical max_bs * max_context_len for cuda graph decode
-                "req_to_token": torch.zeros(
-                    160, 2048, dtype=torch.int32, device=device
-                ),  # Add req_to_token attribute
-            },
-        )
-
-
-class MockReqToTokenPool:
-    def __init__(self, batch_size, seq_len, device):
-        self.req_to_token = (
-            torch.arange(batch_size * seq_len, device=device)
-            .reshape(batch_size, seq_len)
-            .to(torch.int32)
-        )
+    def __init__(self, page_size=1, num_heads=2, head_dim=8):
+        self.device = "cuda"
+        self.dtype = torch.float16
+        attention_arch = AttentionArch.MHA
+        # Max batch size for the test.
+        max_batch_size = 160
+        # Total tokens (prefix + extend + decode) in the test should not exceed this length.
+        max_context_len = 2048
+        self.model_config = type(
+            "ModelConfig",
+            (),
+            {
+                "context_len": max_context_len,
+                "is_multimodal": False,
+                "attention_arch": attention_arch,
+            },
+        )
+        self.sliding_window_size = None
+        # Create a large enough req_to_token_pool to fit the test usage.
+        self.req_to_token_pool = type(
+            "TokenPool",
+            (),
+            {
+                # A typical max_bs * max_context_len for cuda graph decode
+                "size": max_batch_size,
+                # Add req_to_token attribute
+                "req_to_token": torch.zeros(
+                    max_batch_size, max_context_len, dtype=torch.int32, device=self.device
+                ),
+            },
+        )
+        self.page_size = page_size
+        max_total_num_tokens = max_batch_size * max_context_len
+        self.token_to_kv_pool = MHATokenToKVPool(
+            size=max_total_num_tokens,
+            page_size=page_size,
+            dtype=self.dtype,
+            head_num=num_heads,
+            head_dim=head_dim,
+            layer_num=1,  # only consider layer=1 for unit test
+            device=self.device,
+            enable_memory_saver=False,
+        )
+        # Required by torch native backend
+        self.server_args = ServerArgs(model_path="fake_model_path")


 @unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA")
 class TestFlashAttentionBackend(CustomTestCase):
     def setUp(self):
-        """Set up test fixtures before each test method."""
-        self.model_runner = MockModelRunner()
-        self.backend = FlashAttentionBackend(self.model_runner)
-        # Test parameters
+        # Common test parameters
         self.batch_size = 2
-        self.seq_len = 4
+        self.seq_len = 256
         self.num_heads = 2
         self.head_dim = 8
         self.device = "cuda"
         self.dtype = torch.float16

+    def _init_model_runner(self, page_size=1):
+        self.model_runner = MockModelRunner(
+            page_size=page_size,
+            num_heads=self.num_heads,
+            head_dim=self.head_dim,
+        )
+        self.backend = FlashAttentionBackend(self.model_runner)
+        self.ref_backend = TorchNativeAttnBackend(self.model_runner)
+        self.model_runner.model_config.num_attention_heads = self.num_heads
+
+    def _mock_write_to_req_to_token_pool(self, batch_size, seq_len, page_size):
+        # if page_size > 1, the token pool stores the index to the page,
+        # so we need to multiply the index by page_size.
+        self.req_to_token = (
+            torch.arange(0, batch_size, dtype=torch.int32, device=self.device)[:, None]
+            * seq_len
+            + torch.arange(0, seq_len, dtype=torch.int32, device=self.device)[None, :]
+            + page_size
+        )
+        self.model_runner.req_to_token_pool.req_to_token[:batch_size, :seq_len] = (
+            self.req_to_token
+        )
+
     def _create_attention_layer(self):
-        """Helper method to create an attention layer."""
+        """Create attention layer for testing."""
         return RadixAttention(
             num_heads=self.num_heads,
             head_dim=self.head_dim,
...
@@ -64,47 +113,27 @@ class TestFlashAttentionBackend(CustomTestCase):
             layer_id=0,
         )

-    def _create_kv_pool(self, size):
-        """Helper method to create a KV pool."""
-        return MHATokenToKVPool(
-            size=size,
-            page_size=1,  # only consider page=1 for unit test
-            dtype=self.dtype,
-            head_num=self.num_heads,
-            head_dim=self.head_dim,
-            layer_num=1,  # only consider layer=1 for unit test
-            device=self.device,
-            enable_memory_saver=False,
-        )
-
     def _create_qkv_tensors(self, tokens_len):
-        """Helper method to create q, k, v tensors."""
-        shape = (tokens_len, self.num_heads, self.head_dim)
+        """Create q, k, v tensors for testing."""
         return (
-            torch.randn(shape, dtype=self.dtype, device=self.device),
-            torch.randn(shape, dtype=self.dtype, device=self.device),
-            torch.randn(shape, dtype=self.dtype, device=self.device),
+            torch.randn(
+                tokens_len, self.num_heads, self.head_dim,
+                dtype=self.dtype, device=self.device,
+            ),
+            torch.randn(
+                tokens_len, self.num_heads, self.head_dim,
+                dtype=self.dtype, device=self.device,
+            ),
+            torch.randn(
+                tokens_len, self.num_heads, self.head_dim,
+                dtype=self.dtype, device=self.device,
+            ),
         )

-    def _verify_output(self, output, expected_shape):
-        """Helper method to verify output."""
+    def _run_reference_forward(self, mode, q, k, v, layer, forward_batch, expected_shape):
+        """Run reference forward pass using native backend."""
+        if mode == ForwardMode.EXTEND:
+            output = self.ref_backend.forward_extend(q, k, v, layer, forward_batch)
+        else:  # ForwardMode.DECODE
+            output = self.ref_backend.forward_decode(q, k, v, layer, forward_batch)
+        return output.view(expected_shape)
+
+    def _verify_output(self, output, expected_shape, output_ref=None):
+        """Verify output tensor shape, dtype, and values."""
         self.assertEqual(
             output.shape,
             expected_shape,
...

The remaining hunks (@@ -116,161 +145,110 @@, @@ -278,7 +256,7 @@, @@ -290,22 +268,82 @@) remove the old inline bodies of test_forward_extend, test_forward_decode, and test_forward_extend_with_prefix (which built ForwardBatch objects, MockReqToTokenPool instances, and per-test KV pools by hand) and replace them with shared helpers. The reconstructed new code:

            torch.isnan(output).sum().item(), 0, "Output contains NaN values"
        )
        if output_ref is not None:
            if not torch.allclose(output, output_ref, atol=1e-1, rtol=0.0):
                # Check where the values differ beyond the given tolerances
                diff_mask = ~torch.isclose(output, output_ref, atol=1e-1, rtol=0.0)
                # Find the first index where the difference occurs
                if diff_mask.any():
                    first_mismatch_idx = diff_mask.nonzero()[0]
                    print("First mismatch at index:", tuple(first_mismatch_idx.tolist()))
                    print("output:", output[tuple(first_mismatch_idx.tolist())])
                    print("output_ref:", output_ref[tuple(first_mismatch_idx.tolist())])
                raise AssertionError(
                    "Attention output is not close to the torch native backend output"
                )

    def _create_forward_batch(self, mode, q_len=None, prefix_len=0, page_size=1):
        """Create a forward batch for testing based on mode and lengths."""
        self._init_model_runner(page_size=page_size)

        # Default to self.seq_len if not specified
        q_len = q_len or self.seq_len

        if mode == ForwardMode.EXTEND:
            total_len = prefix_len + q_len
            out_cache_start = prefix_len * self.batch_size
            out_cache_end = total_len * self.batch_size
            forward_batch = ForwardBatch(
                batch_size=self.batch_size,
                input_ids=torch.randint(0, 100, (self.batch_size, q_len), device=self.device),
                out_cache_loc=torch.arange(out_cache_start, out_cache_end, device=self.device),
                seq_lens_sum=self.batch_size * total_len,
                forward_mode=mode,
                req_pool_indices=torch.arange(self.batch_size, device=self.device),
                seq_lens=torch.tensor([total_len] * self.batch_size, device=self.device),
                seq_lens_cpu=torch.tensor([total_len] * self.batch_size, device="cpu"),
                extend_prefix_lens=torch.tensor([prefix_len] * self.batch_size, device=self.device),
                extend_prefix_lens_cpu=torch.tensor([prefix_len] * self.batch_size, device="cpu"),
                extend_seq_lens=torch.tensor([q_len] * self.batch_size, device=self.device),
                extend_seq_lens_cpu=torch.tensor([q_len] * self.batch_size, device="cpu"),
                attn_backend=self.backend,
            )
        else:  # ForwardMode.DECODE
            decode_len = q_len  # Assuming 1 for decode testing
            total_len = self.seq_len + decode_len
            if mode == ForwardMode.DECODE and page_size > 1:
                # Get next page_size multiple of self.seq_len
                out_cache_start = (self.batch_size * self.seq_len // page_size + 1) * page_size
                # out_cache_end is the start of the next block
                out_cache_end = out_cache_start + decode_len * page_size
            else:
                out_cache_start = self.batch_size * self.seq_len
                out_cache_end = self.batch_size * total_len
            forward_batch = ForwardBatch(
                batch_size=self.batch_size,
                input_ids=torch.randint(0, 100, (self.batch_size, decode_len), device=self.device),
                out_cache_loc=torch.tensor([out_cache_start, out_cache_end], device=self.device),
                seq_lens_sum=self.batch_size * total_len,
                forward_mode=mode,
                req_pool_indices=torch.arange(self.batch_size, device=self.device),
                seq_lens=torch.tensor([total_len] * self.batch_size, device=self.device),
                seq_lens_cpu=torch.tensor([total_len] * self.batch_size, device="cpu"),
                attn_backend=self.backend,
            )

        # Add token pool
        forward_batch.req_to_token_pool = self.model_runner.req_to_token_pool
        # Write current batch's req_to_token to req_to_token_pool
        self._mock_write_to_req_to_token_pool(self.batch_size, total_len, page_size)
        # Add kv pool for this forward batch
        forward_batch.token_to_kv_pool = self.model_runner.token_to_kv_pool
        return forward_batch

    def _setup_kv_cache(self, forward_batch, layer, cache_len):
        # Create constant values for the prefix cache for easy debugging
        cache_k = torch.ones(
            self.batch_size * cache_len,
            self.num_heads,
            self.head_dim,
            dtype=self.dtype,
            ...
        )
        cache_v = (
            torch.ones(
                self.batch_size * cache_len,
                self.num_heads,
                self.head_dim,
                dtype=self.dtype,
                ...
        )
        # Set the prefix KV cache
        forward_batch.token_to_kv_pool.set_kv_buffer(
            layer,
            torch.arange(self.batch_size * cache_len, device=self.device),
            cache_k,
            cache_v,
            layer.k_scale,
            layer.v_scale,
        )

    def _run_attention_test(self, mode, q_len, prefix_len=0, page_size=1):
        """
        Run an attention test with the specified parameters.
        Args:
            mode: ForwardMode.EXTEND or ForwardMode.DECODE
            q_len: Length of the query sequence. For decode mode, q_len is 1.
            prefix_len: Length of the prefix sequence for extend mode
            page_size: Page size for the KV cache
        """
        layer = self._create_attention_layer()
        # Create forward batch and set up
        forward_batch = self._create_forward_batch(mode, q_len, prefix_len, page_size)
        # Create QKV tensors for the input
        q, k, v = self._create_qkv_tensors(self.batch_size * q_len)
        # KV cache for prefixed extend is prefix_len
        # KV cache for decode is same as seq_len
        # No KV cache for extend without prefix
        if mode == ForwardMode.EXTEND:
            if prefix_len > 0:
                self._setup_kv_cache(forward_batch, layer, prefix_len)
        else:
            self._setup_kv_cache(forward_batch, layer, self.seq_len)

        self.backend.init_forward_metadata(forward_batch)

        if mode == ForwardMode.EXTEND:
            expected_shape = (self.batch_size * q_len, self.num_heads * self.head_dim)
            output = self.backend.forward_extend(q, k, v, layer, forward_batch)
        else:
            expected_shape = (self.batch_size, self.num_heads * self.head_dim)
            output = self.backend.forward_decode(q, k, v, layer, forward_batch)

        output_ref = self._run_reference_forward(mode, q, k, v, layer, forward_batch, expected_shape)
        self._verify_output(output, expected_shape, output_ref)
        return output

    def test_forward_extend(self):
        """Test the standard extend operation."""
        self._run_attention_test(ForwardMode.EXTEND, q_len=self.seq_len)

    def test_forward_decode(self):
        """Test the decode operation with cached tokens."""
        self._run_attention_test(ForwardMode.DECODE, q_len=1)

    def test_forward_extend_with_prefix(self):
        """Test extending from cached prefix tokens."""
        prefix_len = self.seq_len // 2
        extend_len = self.seq_len - prefix_len
        self._run_attention_test(ForwardMode.EXTEND, q_len=extend_len, prefix_len=prefix_len)

    def test_forward_extend_with_page_size_greater_than_1(self):
        """Test extending from cached prefix tokens with page size greater than 1."""
        self._run_attention_test(ForwardMode.EXTEND, q_len=self.seq_len, page_size=64)

    def test_forward_decode_with_page_size_greater_than_1(self):
        """Test decode operation with page size greater than 1."""
        self._run_attention_test(ForwardMode.DECODE, q_len=1, page_size=64)


if __name__ == "__main__":
...
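The core of the new verification logic is a tolerance-based comparison against the torch-native reference backend that reports the first mismatching element. A standalone sketch of that check (same tolerances as the test, but generic tensors and a hypothetical helper name):

```python
import torch


def assert_close_with_report(output, output_ref, atol=1e-1, rtol=0.0):
    """Tolerate small numeric drift between kernels, but print the first
    offending element before failing."""
    if torch.allclose(output, output_ref, atol=atol, rtol=rtol):
        return
    # Mask of elements that differ beyond the given tolerances.
    diff_mask = ~torch.isclose(output, output_ref, atol=atol, rtol=rtol)
    idx = tuple(diff_mask.nonzero()[0].tolist())
    print("First mismatch at index:", idx)
    print("output:", output[idx], "output_ref:", output_ref[idx])
    raise AssertionError("Attention output is not close to the reference output")


# Toy usage with deliberately diverging tensors.
a = torch.zeros(2, 4)
b = a.clone()
b[1, 2] = 1.0
try:
    assert_close_with_report(a, b)
except AssertionError as err:
    print(err)
```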
python/sglang/test/attention/test_flashattn_mla_backend.py  (new file, +285)

import unittest

import torch

from sglang.srt.configs.model_config import AttentionArch
from sglang.srt.layers.attention.flashattention_backend import FlashAttentionBackend
from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBackend
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.mem_cache.memory_pool import MLATokenToKVPool
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.test.test_utils import CustomTestCase


class MockModelRunner:
    def __init__(self, kv_lora_rank, qk_rope_head_dim):
        attention_arch = AttentionArch.MLA
        self.device = "cuda"
        self.dtype = torch.float16
        context_len = 2048
        self.model_config = type(
            "ModelConfig",
            (),
            {
                "context_len": context_len,
                "attention_arch": attention_arch,
            },
        )
        self.sliding_window_size = None
        batch_size = 160
        # Create a proper req_to_token_pool with the req_to_token attribute
        self.req_to_token_pool = type(
            "TokenPool",
            (),
            {
                # A typical max_bs * max_context_len for cuda graph decode
                "size": batch_size,
                # Add req_to_token attribute
                "req_to_token": torch.zeros(
                    batch_size, context_len, dtype=torch.int32, device=self.device
                ),
            },
        )
        self.page_size = 1
        max_total_num_tokens = batch_size * context_len
        self.token_to_kv_pool = MLATokenToKVPool(
            size=max_total_num_tokens,
            page_size=self.page_size,
            dtype=self.dtype,
            kv_lora_rank=kv_lora_rank,
            qk_rope_head_dim=qk_rope_head_dim,
            layer_num=1,  # only consider layer=1 for unit test
            device=self.device,
            enable_memory_saver=False,
        )


class MockReqToTokenPool:
    def __init__(self, batch_size, seq_len, device):
        self.req_to_token = (
            torch.arange(batch_size * seq_len, device=device)
            .reshape(batch_size, seq_len)
            .to(torch.int32)
        )


@unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA")
class TestFlashAttentionMLABackend(CustomTestCase):
    def setUp(self):
        # Test parameters
        self.batch_size = 2
        self.seq_len = 360
        self.num_heads = 2
        self.device = "cuda"
        self.dtype = torch.float16
        self.kv_lora_rank = 512
        self.q_lora_rank = 128
        self.qk_rope_head_dim = 64
        self.qk_head_dim = self.qk_rope_head_dim + self.kv_lora_rank
        # Assume no rope scaling
        self.scaling = self.qk_head_dim**-0.5
        # Initialize model runner and backend
        self._init_model_runner()
        self.backend = FlashAttentionBackend(self.model_runner)
        self.num_local_heads = 2

    def _init_model_runner(self):
        self.model_runner = MockModelRunner(
            kv_lora_rank=self.kv_lora_rank,
            qk_rope_head_dim=self.qk_rope_head_dim,
        )
        self.backend = FlashAttentionBackend(self.model_runner)

    def _create_attention_layer(self):
        """Create attention layer for testing."""
        self.attn_mqa = RadixAttention(
            num_heads=self.num_local_heads,
            head_dim=self.kv_lora_rank + self.qk_rope_head_dim,
            scaling=self.scaling,
            num_kv_heads=1,
            layer_id=0,
            v_head_dim=self.kv_lora_rank,
            prefix="attn_mqa",
        )
        return self.attn_mqa

    def _run_reference_forward(self, mode, q, k, v, layer, forward_batch, expected_shape):
        """Run reference forward pass using native backend."""
        if mode == ForwardMode.EXTEND:
            output = self.ref_backend.forward_extend(q, k, v, layer, forward_batch)
        else:  # ForwardMode.DECODE
            output = self.ref_backend.forward_decode(q, k, v, layer, forward_batch)
        return output.view(expected_shape)

    def _verify_output(self, output, expected_shape):
        """Verify output tensor shape, dtype, and values."""
        self.assertEqual(
            output.shape,
            expected_shape,
            f"Expected shape {expected_shape}, got {output.shape}",
        )
        self.assertEqual(output.dtype, self.dtype)
        self.assertEqual(output.device.type, "cuda")
        self.assertEqual(
            torch.isnan(output).sum().item(), 0, "Output contains NaN values"
        )

    def _create_forward_batch(self, mode, q_len=None, prefix_len=0):
        """Create a forward batch for testing based on mode and lengths."""
        # Default to self.seq_len if not specified
        q_len = q_len or self.seq_len

        if mode == ForwardMode.EXTEND:
            total_len = prefix_len + q_len
            out_cache_start = prefix_len * self.batch_size
            out_cache_end = total_len * self.batch_size
            forward_batch = ForwardBatch(
                batch_size=self.batch_size,
                input_ids=torch.randint(0, 100, (self.batch_size, q_len), device=self.device),
                out_cache_loc=torch.arange(out_cache_start, out_cache_end, device=self.device),
                seq_lens_sum=self.batch_size * total_len,
                forward_mode=mode,
                req_pool_indices=torch.arange(self.batch_size, device=self.device),
                seq_lens=torch.tensor([total_len] * self.batch_size, device=self.device),
                seq_lens_cpu=torch.tensor([total_len] * self.batch_size, device="cpu"),
                extend_prefix_lens=torch.tensor([prefix_len] * self.batch_size, device=self.device),
                extend_prefix_lens_cpu=torch.tensor([prefix_len] * self.batch_size, device="cpu"),
                extend_seq_lens=torch.tensor([q_len] * self.batch_size, device=self.device),
                extend_seq_lens_cpu=torch.tensor([q_len] * self.batch_size, device="cpu"),
                attn_backend=self.backend,
            )
        else:  # ForwardMode.DECODE
            decode_len = q_len  # typically 1 for decode mode
            total_len = self.seq_len + decode_len
            out_cache_start = self.batch_size * self.seq_len
            out_cache_end = self.batch_size * total_len
            forward_batch = ForwardBatch(
                batch_size=self.batch_size,
                input_ids=torch.randint(0, 100, (self.batch_size, decode_len), device=self.device),
                out_cache_loc=torch.arange(out_cache_start, out_cache_end, device=self.device),
                seq_lens_sum=self.batch_size * total_len,
                forward_mode=mode,
                req_pool_indices=torch.arange(self.batch_size, device=self.device),
                seq_lens=torch.tensor([total_len] * self.batch_size, device=self.device),
                seq_lens_cpu=torch.tensor([total_len] * self.batch_size, device="cpu"),
                attn_backend=self.backend,
            )

        # Add token pool from model runner to forward batch
        forward_batch.req_to_token_pool = self.model_runner.req_to_token_pool
        # Add KV cache from model runner to forward batch
        forward_batch.token_to_kv_pool = self.model_runner.token_to_kv_pool
        return forward_batch

    def _setup_kv_cache(self, forward_batch, layer, cache_len):
        """Set up KV cache with prefix tokens."""
        if cache_len <= 0:
            return

        # Create constant values for the prefix cache for easy debugging
        latent_cache = torch.ones(
            self.batch_size * cache_len,
            1,  # latent cache has only one head in MQA
            self.kv_lora_rank + self.qk_rope_head_dim,
            dtype=self.dtype,
            device=self.device,
        )

        # Set the prefix KV cache
        forward_batch.token_to_kv_pool.set_kv_buffer(
            layer,
            torch.arange(self.batch_size * cache_len, device=self.device),
            latent_cache,
            None,
        )

    def _run_attention_test(self, mode, q_len, prefix_len=0):
        """
        Run an attention test with the specified parameters.
        Args:
            mode: ForwardMode.EXTEND or ForwardMode.DECODE
            q_len: Length of the query sequence. For decode mode, q_len is 1.
            prefix_len: Length of the prefix sequence for extend mode
        """
        layer = self._create_attention_layer()

        # Create forward batch and set up
        forward_batch = self._create_forward_batch(mode, q_len, prefix_len)

        # Create q, kv_compressed for testing
        q_shape = (self.batch_size * q_len, self.num_heads, self.qk_head_dim)
        kv_shape = (self.batch_size * q_len, self.qk_head_dim)
        q = torch.randn(q_shape, dtype=self.dtype, device=self.device)
        kv_compressed = torch.randn(kv_shape, dtype=self.dtype, device=self.device)
        # v is not used for mqa, all values passed in through k
        k = kv_compressed.unsqueeze(1)
        v = torch.randn((1), dtype=self.dtype, device=self.device)

        self._setup_kv_cache(forward_batch, layer, prefix_len)

        self.backend.init_forward_metadata(forward_batch)

        expected_shape = (
            self.batch_size * q_len,
            self.num_heads * self.kv_lora_rank,
        )

        if mode == ForwardMode.EXTEND:
            output = self.backend.forward_extend(q, k, v, layer, forward_batch)
        else:
            output = self.backend.forward_decode(q, k, v, layer, forward_batch)

        self._verify_output(output, expected_shape)
        return output

    def test_forward_extend(self):
        """Test the standard extend operation."""
        self._run_attention_test(ForwardMode.EXTEND, q_len=self.seq_len)

    def test_forward_decode(self):
        """Test the decode operation with cached tokens."""
        self._run_attention_test(ForwardMode.DECODE, q_len=1)

    def test_forward_extend_with_prefix(self):
        """Test extending from cached prefix tokens."""
        prefix_len = self.seq_len // 2
        extend_len = self.seq_len - prefix_len
        self._run_attention_test(ForwardMode.EXTEND, q_len=extend_len, prefix_len=prefix_len)


if __name__ == "__main__":
    unittest.main()
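For orientation, the MLA test's head-dimension bookkeeping can be summarized in a few lines (values copied from setUp above; the comments are interpretation, not sglang source):

```python
# Head-dimension arithmetic used by the MLA test.
kv_lora_rank = 512     # width of the compressed latent ("nope") part
qk_rope_head_dim = 64  # rotary part appended to the latent
qk_head_dim = kv_lora_rank + qk_rope_head_dim  # 576: per-head query/key width
scaling = qk_head_dim ** -0.5                  # softmax scaling, no rope scaling assumed

batch_size, q_len, num_heads = 2, 360, 2
# The backend returns the value projection in the latent space, so the expected
# flattened output width is num_heads * kv_lora_rank, not num_heads * qk_head_dim.
expected_shape = (batch_size * q_len, num_heads * kv_lora_rank)
print(qk_head_dim, round(scaling, 5), expected_shape)
```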
test/srt/run_suite.py  (+1, -0)

@@ -28,6 +28,7 @@ suites = {
         TestFile("test_chunked_prefill.py", 336),
         TestFile("test_eagle_infer.py", 500),
         TestFile("test_ebnf_constrained.py"),
+        TestFile("test_fa3.py", 5),
         TestFile("test_fp8_kernel.py", 8),
         TestFile("test_embedding_openai_server.py", 36),
         TestFile("test_hidden_states.py", 55),
...
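run_suite.py registers each test file together with an estimated runtime. A hypothetical, simplified registry illustrating the shape of the entry added here (TestFile's real signature, default, and the suite key may differ in sglang):

```python
from dataclasses import dataclass


@dataclass
class TestFile:
    name: str
    estimated_seconds: int = 60  # assumed default for illustration only


# Hypothetical suite name; the real suites dict lives in test/srt/run_suite.py.
suites = {
    "example-suite": [
        TestFile("test_chunked_prefill.py", 336),
        TestFile("test_fa3.py", 5),  # the new FA3 integration test
    ]
}

total = sum(t.estimated_seconds for t in suites["example-suite"])
print(f"estimated suite runtime: {total}s")
```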
test/srt/test_fa3.py  (new file, +180)

import unittest
from types import SimpleNamespace

import requests
import torch

from sglang.srt.utils import get_device_sm, kill_process_tree
from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
from sglang.test.test_utils import (
    DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
    DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
    DEFAULT_MLA_MODEL_NAME_FOR_TEST,
    DEFAULT_MODEL_NAME_FOR_TEST,
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
    popen_launch_server,
)

"""
Integration test for python/sglang/srt/layers/attention/flashattention_backend.py
"""

# Change to your own model if testing model is not public.
MODEL_USED_FOR_TEST = DEFAULT_MODEL_NAME_FOR_TEST
MODEL_USED_FOR_TEST_MLA = DEFAULT_MLA_MODEL_NAME_FOR_TEST
# Setting data path to None uses default data path in few_shot_gsm8k eval test.
DATA_PATH = None


@unittest.skipIf(get_device_sm() < 90, "Test requires CUDA SM 90 or higher")
class BaseFlashAttentionTest(unittest.TestCase):
    """Base class for FlashAttention tests to reduce code duplication."""

    model = MODEL_USED_FOR_TEST
    base_url = DEFAULT_URL_FOR_TEST
    accuracy_threshold = 0.62

    @classmethod
    def get_server_args(cls):
        """Return the arguments for the server launch. Override in subclasses."""
        args = [
            "--trust-remote-code",
            "--enable-torch-compile",
            "--attention-backend",
            "fa3",
        ]
        return args

    @classmethod
    def setUpClass(cls):
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=cls.get_server_args(),
        )

    @classmethod
    def tearDownClass(cls):
        kill_process_tree(cls.process.pid)

    def test_gsm8k(self):
        args = SimpleNamespace(
            num_shots=5,
            num_questions=200,
            max_new_tokens=512,
            parallel=128,
            host="http://127.0.0.1",
            port=int(self.base_url.split(":")[-1]),
            data_path=DATA_PATH,
        )
        metrics = run_eval_few_shot_gsm8k(args)
        print(metrics)

        # Use the appropriate metric key based on the test class
        metric_key = "accuracy"
        self.assertGreater(metrics[metric_key], self.accuracy_threshold)


class TestFlashAttention3(BaseFlashAttentionTest):
    """Test FlashAttention3 with MLA model and CUDA graph enabled."""

    @classmethod
    def get_server_args(cls):
        args = super().get_server_args()
        args.extend(["--cuda-graph-max-bs", "2"])
        return args


class TestFlashAttention3DisableCudaGraph(BaseFlashAttentionTest):
    """Test FlashAttention3 with CUDA graph disabled."""

    @classmethod
    def get_server_args(cls):
        args = super().get_server_args()
        args.extend(["--disable-cuda-graph"])
        return args


class TestFlashAttention3MLA(BaseFlashAttentionTest):
    """Test FlashAttention3 with MLA."""

    model = MODEL_USED_FOR_TEST_MLA

    @classmethod
    def get_server_args(cls):
        args = super().get_server_args()
        args.extend(["--cuda-graph-max-bs", "2"])
        return args


class TestFlashAttention3SpeculativeDecode(BaseFlashAttentionTest):
    """Test FlashAttention3 with speculative decode enabled."""

    model = DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST

    @classmethod
    def get_server_args(cls):
        args = super().get_server_args()
        args.extend(
            [
                "--cuda-graph-max-bs",
                "2",
                "--speculative-algorithm",
                "EAGLE3",
                "--speculative-draft",
                DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
                "--speculative-num-steps",
                "3",
                "--speculative-eagle-topk",
                "1",
                "--speculative-num-draft-tokens",
                "3",
                "--dtype",
                "float16",
            ]
        )
        return args

    def test_gsm8k(self):
        """
        Override the test_gsm8k to further test for average speculative accept length.
        """
        requests.get(self.base_url + "/flush_cache")

        args = SimpleNamespace(
            num_shots=5,
            data_path=DATA_PATH,
            num_questions=200,
            max_new_tokens=512,
            parallel=128,
            host="http://127.0.0.1",
            port=int(self.base_url.split(":")[-1]),
        )
        metrics = run_eval_few_shot_gsm8k(args)
        print(metrics)
        self.assertGreater(metrics["accuracy"], 0.60)

        server_info = requests.get(self.base_url + "/get_server_info")
        avg_spec_accept_length = server_info.json()["avg_spec_accept_length"]
        print(f"{avg_spec_accept_length=}")
        self.assertGreater(avg_spec_accept_length, 1.5)


if __name__ == "__main__":
    unittest.main()
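The integration tests share server flags through a classmethod that subclasses extend. A minimal sketch of that pattern (illustrative class names, not the sglang test classes):

```python
import unittest


class BaseServerTest(unittest.TestCase):
    """A base class owns the common CLI flags; variants only add their own."""

    @classmethod
    def get_server_args(cls):
        return ["--trust-remote-code", "--attention-backend", "fa3"]


class CudaGraphCappedTest(BaseServerTest):
    @classmethod
    def get_server_args(cls):
        # Start from the base flags and append the variant-specific ones.
        args = super().get_server_args()
        args.extend(["--cuda-graph-max-bs", "2"])
        return args


if __name__ == "__main__":
    print(CudaGraphCappedTest.get_server_args())
```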