Commit 3bface15 (Unverified)
Authored Apr 17, 2025 by woodx; committed by GitHub on Apr 17, 2025

Feat/support encoder model (like bert) (#4887)

Parent: 6fb29ffd
Showing 8 changed files with 593 additions and 3 deletions:

    python/sglang/srt/layers/attention/torch_native_backend.py          +6   -1
    python/sglang/srt/layers/attention/triton_backend.py                +6   -0
    python/sglang/srt/layers/attention/triton_ops/extend_attention.py  +13   -2
    python/sglang/srt/layers/radix_attention.py                        +15   -0
    python/sglang/srt/models/bert.py                                  +398   -0
    python/sglang/test/runners.py                                       +4   -0
    test/srt/models/test_encoder_embedding_models.py                  +149   -0
    test/srt/test_triton_attention_kernels.py                           +2   -0
python/sglang/srt/layers/attention/torch_native_backend.py  (+6 / -1)

@@ -6,6 +6,7 @@ import torch
 from torch.nn.functional import scaled_dot_product_attention

 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
+from sglang.srt.layers.radix_attention import AttentionType
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch

 if TYPE_CHECKING:
@@ -202,6 +203,10 @@ class TorchNativeAttnBackend(AttentionBackend):
         q_ = q.view(-1, layer.tp_q_head_num, layer.qk_head_dim)
         o_ = o.view(-1, layer.tp_q_head_num, layer.v_head_dim)

+        causal = True
+        if layer.is_cross_attention or layer.attn_type == AttentionType.ENCODER_ONLY:
+            causal = False
+
         self._run_sdpa_forward_extend(
             q_,
             o_,
@@ -214,7 +219,7 @@ class TorchNativeAttnBackend(AttentionBackend):
             forward_batch.extend_seq_lens,
             scaling=layer.scaling,
             enable_gqa=use_gqa,
-            causal=not layer.is_cross_attention,
+            causal=causal,
         )
         return o
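The torch-native backend routes extend-mode attention through torch.nn.functional.scaled_dot_product_attention, so encoder-only support here reduces to flipping its causal flag. A minimal standalone sketch (not sglang code) of what that flag changes:

import torch
import torch.nn.functional as F

# Toy shapes: batch=1, heads=2, seq_len=4, head_dim=8.
q, k, v = (torch.randn(1, 2, 4, 8) for _ in range(3))

# Decoder-style (causal=True): position i attends only to positions <= i.
out_causal = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# Encoder-only / BERT-style (causal=False): every position attends to the full sequence,
# which is what the backend now selects for AttentionType.ENCODER_ONLY.
out_bidir = F.scaled_dot_product_attention(q, k, v, is_causal=False)

# The last position sees the whole sequence in both cases; earlier positions do not.
print(torch.allclose(out_causal[..., -1, :], out_bidir[..., -1, :], atol=1e-6))  # True
print(torch.allclose(out_causal, out_bidir, atol=1e-6))                          # False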
python/sglang/srt/layers/attention/triton_backend.py  (+6 / -0)

@@ -10,6 +10,7 @@ import triton.language as tl
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
 from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
 from sglang.srt.layers.dp_attention import get_attention_tp_size
+from sglang.srt.layers.radix_attention import AttentionType
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.utils import get_bool_env_var, get_device_core_count
@@ -528,6 +529,10 @@ class TritonAttnBackend(AttentionBackend):
                 layer, forward_batch.out_cache_loc, k, v
             )

+        causal = True
+        if layer.attn_type == AttentionType.ENCODER_ONLY:
+            causal = False
+
         self.extend_attention_fwd(
             q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
             k.contiguous(),
@@ -539,6 +544,7 @@ class TritonAttnBackend(AttentionBackend):
             self.forward_metadata.kv_indptr,
             self.forward_metadata.kv_indices,
             self.forward_metadata.custom_mask,
+            causal,
             self.forward_metadata.mask_indptr,
             self.forward_metadata.max_extend_len,
             layer.scaling,
python/sglang/srt/layers/attention/triton_ops/extend_attention.py  (+13 / -2)

@@ -74,6 +74,7 @@ def _fwd_kernel(
     BLOCK_M: tl.constexpr,
     BLOCK_N: tl.constexpr,
     USE_CUSTOM_MASK: tl.constexpr,
+    IS_CAUSAL: tl.constexpr,
     SKIP_PREFIX_CUSTOM_MASK: tl.constexpr,
     STORE_TRANSPOSE: tl.constexpr,
 ):
@@ -129,6 +130,7 @@ def _fwd_kernel(
     for start_n in range(0, cur_seq_len_prefix, BLOCK_N):
         start_n = tl.multiple_of(start_n, BLOCK_N)
         mask_n = (start_n + offs_n) < cur_seq_len_prefix
         offs_kv_loc = tl.load(
             kv_indices + cur_seq_kv_start_idx + start_n + offs_n, mask=mask_n, other=0
         )
@@ -196,7 +198,11 @@ def _fwd_kernel(
     # stage 2: compute the triangle part

-    cur_block_m_end = tl.minimum(cur_seq_len_extend, (cur_block_m + 1) * BLOCK_M)
+    cur_block_m_end = (
+        cur_seq_len_extend
+        if not IS_CAUSAL
+        else tl.minimum(cur_seq_len_extend, (cur_block_m + 1) * BLOCK_M)
+    )
     for start_n in range(0, cur_block_m_end, BLOCK_N):
         start_n = tl.multiple_of(start_n, BLOCK_N)
         mask_n = (start_n + offs_n) < cur_block_m_end
@@ -243,12 +249,15 @@ def _fwd_kernel(
             )
             custom_mask &= mask_m[:, None] & mask_n[None, :]
             qk = tl.where(custom_mask, qk, float("-inf"))
-        else:
+        elif IS_CAUSAL:
             mask_causual = (cur_block_m * BLOCK_M + offs_m[:, None]) >= (
                 start_n + offs_n[None, :]
             )
             mask_causual &= mask_m[:, None] & mask_n[None, :]
             qk = tl.where(mask_causual, qk, float("-inf"))
+        else:
+            mask_non_causal = mask_m[:, None] & mask_n[None, :]
+            qk = tl.where(mask_non_causal, qk, float("-inf"))

         n_e_max = tl.maximum(tl.max(qk, 1), e_max)
         re_scale = tl.exp(e_max - n_e_max)
@@ -299,6 +308,7 @@ def extend_attention_fwd(
     kv_indptr,
     kv_indices,
     custom_mask,
+    is_causal,
     mask_indptr,
     max_len_extend,
     sm_scale=None,
@@ -411,6 +421,7 @@ def extend_attention_fwd(
         Lq=Lq,
         Lv=Lv,
         USE_CUSTOM_MASK=USE_CUSTOM_MASK,
+        IS_CAUSAL=is_causal,
         SKIP_PREFIX_CUSTOM_MASK=SKIP_PREFIX_CUSTOM_MASK,
         STORE_TRANSPOSE=_is_hip,
         num_warps=num_warps,
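Inside the kernel, IS_CAUSAL changes two things in the "triangle" stage: the loop bound cur_block_m_end covers the whole extend length when non-causal, and the m >= n constraint disappears from the mask. A dense PyTorch reference of that masking, as a sketch for intuition rather than a replacement for the Triton code:

import torch

def reference_extend_scores(q, k, is_causal: bool) -> torch.Tensor:
    """Dense reference for the qk masking in _fwd_kernel's stage 2.

    q: [M, D] extend-part queries, k: [N, D] extend-part keys (M == N here).
    """
    qk = q @ k.t() / q.shape[-1] ** 0.5
    if is_causal:
        # Row m may only attend to columns n <= m (the triangle part).
        causal_mask = torch.tril(torch.ones(q.shape[0], k.shape[0], dtype=torch.bool))
        qk = qk.masked_fill(~causal_mask, float("-inf"))
    # Non-causal (ENCODER_ONLY): every row sees every column, no extra mask.
    return torch.softmax(qk, dim=-1)

q = torch.randn(4, 8)
k = torch.randn(4, 8)
p_causal = reference_extend_scores(q, k, is_causal=True)
p_encoder = reference_extend_scores(q, k, is_causal=False)
print(p_causal[0])   # first row: all weight on key 0
print(p_encoder[0])  # first row: spread over all keys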
python/sglang/srt/layers/radix_attention.py  (+15 / -0)

@@ -13,6 +13,7 @@
 # ==============================================================================
 """Radix attention."""

+from enum import Enum
 from typing import Optional

 from torch import nn
@@ -22,6 +23,18 @@ from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch


+class AttentionType(Enum):
+    """
+    Attention type.
+    Use string to be compatible with `torch.compile`.
+    """
+
+    # Decoder attention between previous layer Q/K/V
+    DECODER = "decoder"
+    # Encoder attention between previous layer Q/K/V
+    ENCODER_ONLY = "encoder_only"
+
+
 class RadixAttention(nn.Module):
     """
     The attention layer implementation.
@@ -39,6 +52,7 @@ class RadixAttention(nn.Module):
         sliding_window_size: int = -1,
         is_cross_attention: bool = False,
         quant_config: Optional[QuantizationConfig] = None,
+        attn_type=AttentionType.DECODER,
         prefix: str = "",
         use_irope: bool = False,
     ):
@@ -64,6 +78,7 @@ class RadixAttention(nn.Module):
             self.quant_method = quant_config.get_quant_method(self, prefix=prefix)
             if self.quant_method is not None:
                 self.quant_method.create_weights(self)
+        self.attn_type = attn_type

     def forward(
         self,
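With the enum in place, a model opts into bidirectional attention entirely through the RadixAttention constructor; decoder models keep the DECODER default and are untouched. A minimal sketch of the pattern, using made-up head counts (the real construction is in python/sglang/srt/models/bert.py below):

from sglang.srt.layers.radix_attention import AttentionType, RadixAttention

# Hypothetical encoder layer: 12 heads of dim 64, no KV grouping.
attn = RadixAttention(
    num_heads=12,
    head_dim=64,
    scaling=64**-0.5,
    num_kv_heads=12,
    layer_id=0,
    attn_type=AttentionType.ENCODER_ONLY,  # default is AttentionType.DECODER
)
# At runtime the attention backends check layer.attn_type and drop the causal mask.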
python/sglang/srt/models/bert.py  (new file, +398)

# SPDX-License-Identifier: Apache-2.0
from typing import Any, Dict, Iterable, Optional, Set, Tuple

import torch
from torch import nn

from sglang.srt.distributed import get_tensor_model_parallel_world_size
from sglang.srt.layers.activation import get_act_fn
from sglang.srt.layers.linear import (
    ColumnParallelLinear,
    QKVParallelLinear,
    RowParallelLinear,
)
from sglang.srt.layers.pooler import EmbeddingPoolerOutput, Pooler, PoolingType
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import AttentionType, RadixAttention
from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader

BertConfig = None


class BertEmbedding(nn.Module):

    def __init__(self, config: BertConfig):
        super().__init__()
        self.size = config.hidden_size
        self.word_embeddings = VocabParallelEmbedding(
            config.vocab_size, config.hidden_size
        )
        self.position_embeddings = VocabParallelEmbedding(
            config.max_position_embeddings, config.hidden_size
        )
        self.token_type_embeddings = VocabParallelEmbedding(
            config.type_vocab_size, config.hidden_size
        )
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.position_ids = nn.Parameter(
            torch.empty((1, config.max_position_embeddings)),
        )

        self.position_embedding_type = config.position_embedding_type
        if self.position_embedding_type != "absolute":
            raise ValueError(
                "Only 'absolute' position_embedding_type" + " is supported"
            )

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
    ) -> torch.Tensor:
        input_shape = input_ids.size()

        # Input embeddings.
        inputs_embeds = self.word_embeddings(input_ids)

        # Position embeddings.
        position_embeddings = self.position_embeddings(position_ids)

        token_type_ids = torch.zeros(
            input_shape, dtype=torch.long, device=inputs_embeds.device
        )

        token_type_embeddings = self.token_type_embeddings(token_type_ids)

        embeddings = inputs_embeds + token_type_embeddings + position_embeddings
        embeddings = self.LayerNorm(embeddings)
        return embeddings


class BertEncoder(nn.Module):

    def __init__(
        self,
        config: BertConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.config = config
        self.quant_config = quant_config
        self.layer = nn.ModuleList(
            [
                BertLayer(
                    config=config,
                    layer_id=layer_idx,
                    quant_config=quant_config,
                    prefix=f"{prefix}.layer.{layer_idx}",
                )
                for layer_idx in range(config.num_hidden_layers)
            ]
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
    ) -> torch.Tensor:
        for layer in self.layer:
            hidden_states = layer(hidden_states, forward_batch)
        return hidden_states


class BertLayer(nn.Module):

    def __init__(
        self,
        config: BertConfig,
        layer_id: int = 0,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()

        self.attention = BertAttention(
            hidden_size=config.hidden_size,
            num_attention_heads=config.num_attention_heads,
            layer_id=layer_id,
            layer_norm_eps=config.layer_norm_eps,
            quant_config=quant_config,
            prefix=f"{prefix}.attention",
        )

        self.intermediate = BertIntermediate(
            hidden_size=config.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
            prefix=f"{prefix}.intermediate",
        )

        self.output = BertOutput(
            hidden_size=config.hidden_size,
            intermediate_size=config.intermediate_size,
            layer_norm_eps=config.layer_norm_eps,
            quant_config=quant_config,
            prefix=f"{prefix}.output",
        )

    def forward(self, hidden_states: torch.Tensor, forward_batch: ForwardBatch):
        attn_output = self.attention(hidden_states, forward_batch)
        intermediate_output = self.intermediate(attn_output)
        output = self.output(intermediate_output, attn_output)
        return output


class BertAttention(nn.Module):

    def __init__(
        self,
        hidden_size: int,
        num_attention_heads: int,
        layer_norm_eps: float,
        layer_id: int = 0,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()

        self.self_attn = BertSelfAttention(
            hidden_size=hidden_size,
            num_attention_heads=num_attention_heads,
            layer_id=layer_id,
            quant_config=quant_config,
            prefix=f"{prefix}.output",
        )

        self.output = BertSelfOutput(
            hidden_size=hidden_size,
            layer_norm_eps=layer_norm_eps,
            quant_config=quant_config,
            prefix=f"{prefix}.output",
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
    ) -> torch.Tensor:
        self_output = self.self_attn(hidden_states, forward_batch)
        return self.output(self_output, hidden_states)


class BertSelfAttention(nn.Module):

    def __init__(
        self,
        hidden_size: int,
        num_attention_heads: int,
        layer_id: int = 0,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()

        self.total_num_heads = num_attention_heads
        assert self.total_num_heads % tp_size == 0

        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = self.total_num_heads
        self.head_dim = self.hidden_size // self.total_num_heads
        assert self.head_dim * self.total_num_heads == self.hidden_size

        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)

        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.qkv_proj = QKVParallelLinear(
            hidden_size=self.hidden_size,
            head_size=self.head_dim,
            total_num_heads=self.total_num_heads,
            total_num_kv_heads=self.total_num_kv_heads,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )

        self.attn = RadixAttention(
            num_heads=self.num_heads,
            head_dim=self.head_dim,
            scaling=self.scaling,
            num_kv_heads=self.num_kv_heads,
            layer_id=layer_id,
            prefix=f"{prefix}.attn",
            attn_type=AttentionType.ENCODER_ONLY,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
    ) -> torch.Tensor:
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        output = self.attn(q, k, v, forward_batch)
        return output


class BertSelfOutput(nn.Module):

    def __init__(
        self,
        hidden_size: int,
        layer_norm_eps: float,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.dense = RowParallelLinear(
            input_size=hidden_size,
            output_size=hidden_size,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.dense",
        )
        self.LayerNorm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)

    def forward(
        self, hidden_states: torch.Tensor, input_tensor: torch.Tensor
    ) -> torch.Tensor:
        hidden_states, _ = self.dense(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states


class BertIntermediate(nn.Module):

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.dense = ColumnParallelLinear(
            input_size=hidden_size,
            output_size=intermediate_size,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.dense",
        )
        self.intermediate_act_fn = get_act_fn(hidden_act)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states, _ = self.dense(hidden_states)
        hidden_states = self.intermediate_act_fn(hidden_states)
        return hidden_states


class BertOutput(nn.Module):

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        layer_norm_eps: float,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()

        self.dense = RowParallelLinear(
            input_size=intermediate_size,
            output_size=hidden_size,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.dense",
        )
        self.LayerNorm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)

    def forward(
        self, hidden_states: torch.Tensor, input_tensor: torch.Tensor
    ) -> torch.Tensor:
        hidden_states, _ = self.dense(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states


class BertModel(nn.Module):

    def __init__(
        self,
        *,
        config: BertConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()

        self.config = config
        self.embeddings = BertEmbedding(config)
        self.encoder = BertEncoder(
            config=config, quant_config=quant_config, prefix=f"encoder"
        )
        self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
        # self.pooler = BertPooler(config)

    @torch.no_grad()
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        input_embeds: torch.Tensor = None,
        get_embedding: bool = False,
    ) -> torch.Tensor:
        assert get_embedding == True

        # Your tokenized IDs
        hidden_states = self.embeddings(
            input_ids=input_ids,
            position_ids=positions,
        )

        hidden_states = self.encoder(hidden_states, forward_batch=forward_batch)
        return self.pooler(hidden_states, forward_batch)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]:
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "query", "q"),
            ("qkv_proj", "key", "k"),
            ("qkv_proj", "value", "v"),
        ]

        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            name = name.replace("self", "self_attn")
            if "pooler" in name:
                continue

            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)


class Contriever(BertModel):
    pass


EntryClass = [BertModel, Contriever]
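One way to exercise the new model end to end is through the test runner this commit also extends; the sketch below mirrors the usage in test/srt/models/test_encoder_embedding_models.py and assumes a GPU machine with sglang installed:

import torch
from sglang.test.runners import SRTRunner

prompts = [
    "The capital of France is Paris.",
    "BERT embeds whole sentences with bidirectional attention.",
]

# model_type="embedding" returns pooled embeddings instead of generated text.
with SRTRunner(
    "BAAI/bge-small-en",
    tp_size=1,
    torch_dtype=torch.float32,
    model_type="embedding",
    attention_backend="triton",  # or "torch_native"
) as runner:
    out = runner.forward(prompts)

print(len(out.embed_logits), len(out.embed_logits[0]))  # one embedding vector per prompt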
python/sglang/test/runners.py  (+4 / -0)

@@ -51,6 +51,8 @@ NUM_TOP_LOGPROBS = 5
 def get_dtype_str(torch_dtype):
     if torch_dtype is torch.float16:
         return "float16"
+    if torch_dtype is torch.float32:
+        return "float32"
     else:
         raise NotImplementedError()
@@ -447,6 +449,7 @@ class SRTRunner:
         port: int = DEFAULT_PORT_FOR_SRT_TEST_RUNNER,
         lora_paths: List[str] = None,
         max_loras_per_batch: int = 4,
+        attention_backend: Optional[str] = None,
         lora_backend: str = "triton",
         disable_cuda_graph: bool = False,
         disable_radix_cache: bool = False,
@@ -487,6 +490,7 @@ class SRTRunner:
             lora_paths=lora_paths,
             max_loras_per_batch=max_loras_per_batch,
             lora_backend=lora_backend,
+            attention_backend=attention_backend,
             disable_cuda_graph=disable_cuda_graph,
             disable_radix_cache=disable_radix_cache,
             chunked_prefill_size=chunked_prefill_size,
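The runner changes are small, and both are exercised by the new test: get_dtype_str learns to describe float32 (the dtype the embedding comparison runs in), and SRTRunner forwards an attention_backend override to the server arguments. For example:

import torch
from sglang.test.runners import get_dtype_str

print(get_dtype_str(torch.float16))  # "float16"
print(get_dtype_str(torch.float32))  # "float32", newly handled by this commit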
test/srt/models/test_encoder_embedding_models.py  (new file, +149)

# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

# python -m unittest test_encoder_embedding_models.TestEncoderEmbeddingModels.test_prefill_logits

import multiprocessing as mp
import random
import time
import unittest

import torch
from transformers import AutoConfig, AutoTokenizer

from sglang.test.runners import DEFAULT_PROMPTS, HFRunner, SRTRunner
from sglang.test.test_utils import CustomTestCase, get_similarities, is_in_ci

MODELS = [("BAAI/bge-small-en", 1, 1e-5)]
ATTENTION_BACKEND = ["torch_native", "triton"]
BATCH_SIZE = [30]
TORCH_DTYPES = [torch.float32]
sgl_to_st_ratio = []


class TestEncoderEmbeddingModels(CustomTestCase):

    @classmethod
    def setUpClass(cls):
        mp.set_start_method("spawn", force=True)

    def _truncate_prompts(self, prompts, model_path):
        config = AutoConfig.from_pretrained(model_path)
        max_length = getattr(config, "max_position_embeddings", 512) - 20

        tokenizer = AutoTokenizer.from_pretrained(model_path)

        truncated_prompts = []
        for prompt in prompts:
            tokens = tokenizer(prompt, return_tensors="pt", truncation=False)
            if len(tokens.input_ids[0]) > max_length:
                truncated_text = tokenizer.decode(
                    tokens.input_ids[0][: max_length - 1], skip_special_tokens=True
                )
                truncated_prompts.append(truncated_text)
            else:
                truncated_prompts.append(prompt)

        return truncated_prompts

    def assert_close_prefill_logits(
        self,
        prompts,
        model_path,
        tp_size,
        torch_dtype,
        prefill_tolerance,
        attention_backend,
        batch_size,
    ) -> None:
        truncated_prompts = self._truncate_prompts(prompts, model_path)
        truncated_prompts = truncated_prompts * batch_size

        with HFRunner(
            model_path,
            torch_dtype=torch_dtype,
            model_type="embedding",
        ) as hf_runner:
            # warm up
            hf_outputs = hf_runner.forward(truncated_prompts)

            st_start_time = time.time()
            hf_outputs = hf_runner.forward(truncated_prompts)
            st_end_time = time.time()

        with SRTRunner(
            model_path,
            tp_size=tp_size,
            torch_dtype=torch_dtype,
            model_type="embedding",
            attention_backend=attention_backend,
            chunked_prefill_size=-1,
            disable_radix_cache=True,
        ) as srt_runner:
            # warm up
            srt_outputs = srt_runner.forward(truncated_prompts)

            sgl_start_time = time.time()
            srt_outputs = srt_runner.forward(truncated_prompts)
            sgl_end_time = time.time()

        transformer_time = st_end_time - st_start_time
        sgl_time = sgl_end_time - sgl_start_time
        sgl_to_st_ratio.append(sgl_time / transformer_time)

        for i in range(len(truncated_prompts)):
            hf_logits = torch.Tensor(hf_outputs.embed_logits[i])
            srt_logits = torch.Tensor(srt_outputs.embed_logits[i])

            similarity = torch.tensor(get_similarities(hf_logits, srt_logits))

            # If something is wrong, uncomment this to observe similarity.
            # print("similarity diff", abs(similarity - 1))

            if len(truncated_prompts[i]) <= 1000:
                assert torch.all(
                    abs(similarity - 1) < prefill_tolerance
                ), "embeddings are not all close"

    def test_prefill_logits(self):
        models_to_test = MODELS

        if is_in_ci():
            models_to_test = [random.choice(MODELS)]

        for model, tp_size, prefill_tolerance in models_to_test:
            for attention_backend in ATTENTION_BACKEND:
                for batch_size in BATCH_SIZE:
                    for torch_dtype in TORCH_DTYPES:
                        self.assert_close_prefill_logits(
                            DEFAULT_PROMPTS,
                            model,
                            tp_size,
                            torch_dtype,
                            prefill_tolerance,
                            attention_backend,
                            batch_size,
                        )

        for i in range(len(BATCH_SIZE)):
            print(
                "batch size: ",
                BATCH_SIZE[i] * 5,
                "sgl_time/st_time",
                round(sgl_to_st_ratio[i], 3),
            )


if __name__ == "__main__":
    unittest.main()
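The pass criterion compares each HF embedding against the corresponding SRT embedding through get_similarities, which is not part of this diff; assuming that helper computes a cosine similarity, the check is equivalent to the sketch below:

import torch
import torch.nn.functional as F

def close_enough(hf_vec, srt_vec, prefill_tolerance=1e-5) -> bool:
    """Assumed equivalent of the test's check: cosine similarity within
    prefill_tolerance of 1.0 (get_similarities itself is not defined in this diff)."""
    hf_vec = torch.as_tensor(hf_vec, dtype=torch.float32)
    srt_vec = torch.as_tensor(srt_vec, dtype=torch.float32)
    similarity = F.cosine_similarity(hf_vec.unsqueeze(0), srt_vec.unsqueeze(0)).squeeze()
    return bool(torch.all(abs(similarity - 1) < prefill_tolerance))

print(close_enough([0.1, 0.2, 0.3], [0.1, 0.2, 0.3]))  # True
print(close_enough([0.1, 0.2, 0.3], [0.3, 0.2, 0.1]))  # False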
test/srt/test_triton_attention_kernels.py  (+2 / -0)

@@ -116,6 +116,7 @@ class TestTritonAttention(CustomTestCase):
             kv_indptr,
             kv_indices,
             custom_mask,
+            True,
             mask_indptr,
             max_len_extend,
         )
@@ -150,6 +151,7 @@ class TestTritonAttention(CustomTestCase):
             kv_indptr,
             kv_indices,
             custom_mask,
+            True,
             mask_indptr,
             max_len_extend,
         )