Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
3bface15
Unverified
Commit
3bface15
authored
Apr 17, 2025
by
woodx
Committed by
GitHub
Apr 17, 2025
Browse files
Feat/support encoder model (like bert) (#4887)
parent
6fb29ffd
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
593 additions
and
3 deletions
+593
-3
python/sglang/srt/layers/attention/torch_native_backend.py
python/sglang/srt/layers/attention/torch_native_backend.py
+6
-1
python/sglang/srt/layers/attention/triton_backend.py
python/sglang/srt/layers/attention/triton_backend.py
+6
-0
python/sglang/srt/layers/attention/triton_ops/extend_attention.py
...glang/srt/layers/attention/triton_ops/extend_attention.py
+13
-2
python/sglang/srt/layers/radix_attention.py
python/sglang/srt/layers/radix_attention.py
+15
-0
python/sglang/srt/models/bert.py
python/sglang/srt/models/bert.py
+398
-0
python/sglang/test/runners.py
python/sglang/test/runners.py
+4
-0
test/srt/models/test_encoder_embedding_models.py
test/srt/models/test_encoder_embedding_models.py
+149
-0
test/srt/test_triton_attention_kernels.py
test/srt/test_triton_attention_kernels.py
+2
-0
No files found.
python/sglang/srt/layers/attention/torch_native_backend.py
View file @
3bface15
...
@@ -6,6 +6,7 @@ import torch
...
@@ -6,6 +6,7 @@ import torch
from
torch.nn.functional
import
scaled_dot_product_attention
from
torch.nn.functional
import
scaled_dot_product_attention
from
sglang.srt.layers.attention.base_attn_backend
import
AttentionBackend
from
sglang.srt.layers.attention.base_attn_backend
import
AttentionBackend
from
sglang.srt.layers.radix_attention
import
AttentionType
from
sglang.srt.model_executor.forward_batch_info
import
ForwardBatch
from
sglang.srt.model_executor.forward_batch_info
import
ForwardBatch
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
...
@@ -202,6 +203,10 @@ class TorchNativeAttnBackend(AttentionBackend):
...
@@ -202,6 +203,10 @@ class TorchNativeAttnBackend(AttentionBackend):
q_
=
q
.
view
(
-
1
,
layer
.
tp_q_head_num
,
layer
.
qk_head_dim
)
q_
=
q
.
view
(
-
1
,
layer
.
tp_q_head_num
,
layer
.
qk_head_dim
)
o_
=
o
.
view
(
-
1
,
layer
.
tp_q_head_num
,
layer
.
v_head_dim
)
o_
=
o
.
view
(
-
1
,
layer
.
tp_q_head_num
,
layer
.
v_head_dim
)
causal
=
True
if
layer
.
is_cross_attention
or
layer
.
attn_type
==
AttentionType
.
ENCODER_ONLY
:
causal
=
False
self
.
_run_sdpa_forward_extend
(
self
.
_run_sdpa_forward_extend
(
q_
,
q_
,
o_
,
o_
,
...
@@ -214,7 +219,7 @@ class TorchNativeAttnBackend(AttentionBackend):
...
@@ -214,7 +219,7 @@ class TorchNativeAttnBackend(AttentionBackend):
forward_batch
.
extend_seq_lens
,
forward_batch
.
extend_seq_lens
,
scaling
=
layer
.
scaling
,
scaling
=
layer
.
scaling
,
enable_gqa
=
use_gqa
,
enable_gqa
=
use_gqa
,
causal
=
not
layer
.
is_cross_attention
,
causal
=
causal
,
)
)
return
o
return
o
...
...
python/sglang/srt/layers/attention/triton_backend.py
View file @
3bface15
...
@@ -10,6 +10,7 @@ import triton.language as tl
...
@@ -10,6 +10,7 @@ import triton.language as tl
from
sglang.srt.layers.attention.base_attn_backend
import
AttentionBackend
from
sglang.srt.layers.attention.base_attn_backend
import
AttentionBackend
from
sglang.srt.layers.attention.utils
import
create_flashinfer_kv_indices_triton
from
sglang.srt.layers.attention.utils
import
create_flashinfer_kv_indices_triton
from
sglang.srt.layers.dp_attention
import
get_attention_tp_size
from
sglang.srt.layers.dp_attention
import
get_attention_tp_size
from
sglang.srt.layers.radix_attention
import
AttentionType
from
sglang.srt.model_executor.forward_batch_info
import
ForwardBatch
,
ForwardMode
from
sglang.srt.model_executor.forward_batch_info
import
ForwardBatch
,
ForwardMode
from
sglang.srt.utils
import
get_bool_env_var
,
get_device_core_count
from
sglang.srt.utils
import
get_bool_env_var
,
get_device_core_count
...
@@ -528,6 +529,10 @@ class TritonAttnBackend(AttentionBackend):
...
@@ -528,6 +529,10 @@ class TritonAttnBackend(AttentionBackend):
layer
,
forward_batch
.
out_cache_loc
,
k
,
v
layer
,
forward_batch
.
out_cache_loc
,
k
,
v
)
)
causal
=
True
if
layer
.
attn_type
==
AttentionType
.
ENCODER_ONLY
:
causal
=
False
self
.
extend_attention_fwd
(
self
.
extend_attention_fwd
(
q
.
view
(
-
1
,
layer
.
tp_q_head_num
,
layer
.
qk_head_dim
),
q
.
view
(
-
1
,
layer
.
tp_q_head_num
,
layer
.
qk_head_dim
),
k
.
contiguous
(),
k
.
contiguous
(),
...
@@ -539,6 +544,7 @@ class TritonAttnBackend(AttentionBackend):
...
@@ -539,6 +544,7 @@ class TritonAttnBackend(AttentionBackend):
self
.
forward_metadata
.
kv_indptr
,
self
.
forward_metadata
.
kv_indptr
,
self
.
forward_metadata
.
kv_indices
,
self
.
forward_metadata
.
kv_indices
,
self
.
forward_metadata
.
custom_mask
,
self
.
forward_metadata
.
custom_mask
,
causal
,
self
.
forward_metadata
.
mask_indptr
,
self
.
forward_metadata
.
mask_indptr
,
self
.
forward_metadata
.
max_extend_len
,
self
.
forward_metadata
.
max_extend_len
,
layer
.
scaling
,
layer
.
scaling
,
...
...
python/sglang/srt/layers/attention/triton_ops/extend_attention.py
View file @
3bface15
...
@@ -74,6 +74,7 @@ def _fwd_kernel(
...
@@ -74,6 +74,7 @@ def _fwd_kernel(
BLOCK_M
:
tl
.
constexpr
,
BLOCK_M
:
tl
.
constexpr
,
BLOCK_N
:
tl
.
constexpr
,
BLOCK_N
:
tl
.
constexpr
,
USE_CUSTOM_MASK
:
tl
.
constexpr
,
USE_CUSTOM_MASK
:
tl
.
constexpr
,
IS_CAUSAL
:
tl
.
constexpr
,
SKIP_PREFIX_CUSTOM_MASK
:
tl
.
constexpr
,
SKIP_PREFIX_CUSTOM_MASK
:
tl
.
constexpr
,
STORE_TRANSPOSE
:
tl
.
constexpr
,
STORE_TRANSPOSE
:
tl
.
constexpr
,
):
):
...
@@ -129,6 +130,7 @@ def _fwd_kernel(
...
@@ -129,6 +130,7 @@ def _fwd_kernel(
for
start_n
in
range
(
0
,
cur_seq_len_prefix
,
BLOCK_N
):
for
start_n
in
range
(
0
,
cur_seq_len_prefix
,
BLOCK_N
):
start_n
=
tl
.
multiple_of
(
start_n
,
BLOCK_N
)
start_n
=
tl
.
multiple_of
(
start_n
,
BLOCK_N
)
mask_n
=
(
start_n
+
offs_n
)
<
cur_seq_len_prefix
mask_n
=
(
start_n
+
offs_n
)
<
cur_seq_len_prefix
offs_kv_loc
=
tl
.
load
(
offs_kv_loc
=
tl
.
load
(
kv_indices
+
cur_seq_kv_start_idx
+
start_n
+
offs_n
,
mask
=
mask_n
,
other
=
0
kv_indices
+
cur_seq_kv_start_idx
+
start_n
+
offs_n
,
mask
=
mask_n
,
other
=
0
)
)
...
@@ -196,7 +198,11 @@ def _fwd_kernel(
...
@@ -196,7 +198,11 @@ def _fwd_kernel(
# stage 2: compute the triangle part
# stage 2: compute the triangle part
cur_block_m_end
=
tl
.
minimum
(
cur_seq_len_extend
,
(
cur_block_m
+
1
)
*
BLOCK_M
)
cur_block_m_end
=
(
cur_seq_len_extend
if
not
IS_CAUSAL
else
tl
.
minimum
(
cur_seq_len_extend
,
(
cur_block_m
+
1
)
*
BLOCK_M
)
)
for
start_n
in
range
(
0
,
cur_block_m_end
,
BLOCK_N
):
for
start_n
in
range
(
0
,
cur_block_m_end
,
BLOCK_N
):
start_n
=
tl
.
multiple_of
(
start_n
,
BLOCK_N
)
start_n
=
tl
.
multiple_of
(
start_n
,
BLOCK_N
)
mask_n
=
(
start_n
+
offs_n
)
<
cur_block_m_end
mask_n
=
(
start_n
+
offs_n
)
<
cur_block_m_end
...
@@ -243,12 +249,15 @@ def _fwd_kernel(
...
@@ -243,12 +249,15 @@ def _fwd_kernel(
)
)
custom_mask
&=
mask_m
[:,
None
]
&
mask_n
[
None
,
:]
custom_mask
&=
mask_m
[:,
None
]
&
mask_n
[
None
,
:]
qk
=
tl
.
where
(
custom_mask
,
qk
,
float
(
"-inf"
))
qk
=
tl
.
where
(
custom_mask
,
qk
,
float
(
"-inf"
))
el
se
:
el
if
IS_CAUSAL
:
mask_causual
=
(
cur_block_m
*
BLOCK_M
+
offs_m
[:,
None
])
>=
(
mask_causual
=
(
cur_block_m
*
BLOCK_M
+
offs_m
[:,
None
])
>=
(
start_n
+
offs_n
[
None
,
:]
start_n
+
offs_n
[
None
,
:]
)
)
mask_causual
&=
mask_m
[:,
None
]
&
mask_n
[
None
,
:]
mask_causual
&=
mask_m
[:,
None
]
&
mask_n
[
None
,
:]
qk
=
tl
.
where
(
mask_causual
,
qk
,
float
(
"-inf"
))
qk
=
tl
.
where
(
mask_causual
,
qk
,
float
(
"-inf"
))
else
:
mask_non_causal
=
mask_m
[:,
None
]
&
mask_n
[
None
,
:]
qk
=
tl
.
where
(
mask_non_causal
,
qk
,
float
(
"-inf"
))
n_e_max
=
tl
.
maximum
(
tl
.
max
(
qk
,
1
),
e_max
)
n_e_max
=
tl
.
maximum
(
tl
.
max
(
qk
,
1
),
e_max
)
re_scale
=
tl
.
exp
(
e_max
-
n_e_max
)
re_scale
=
tl
.
exp
(
e_max
-
n_e_max
)
...
@@ -299,6 +308,7 @@ def extend_attention_fwd(
...
@@ -299,6 +308,7 @@ def extend_attention_fwd(
kv_indptr
,
kv_indptr
,
kv_indices
,
kv_indices
,
custom_mask
,
custom_mask
,
is_causal
,
mask_indptr
,
mask_indptr
,
max_len_extend
,
max_len_extend
,
sm_scale
=
None
,
sm_scale
=
None
,
...
@@ -411,6 +421,7 @@ def extend_attention_fwd(
...
@@ -411,6 +421,7 @@ def extend_attention_fwd(
Lq
=
Lq
,
Lq
=
Lq
,
Lv
=
Lv
,
Lv
=
Lv
,
USE_CUSTOM_MASK
=
USE_CUSTOM_MASK
,
USE_CUSTOM_MASK
=
USE_CUSTOM_MASK
,
IS_CAUSAL
=
is_causal
,
SKIP_PREFIX_CUSTOM_MASK
=
SKIP_PREFIX_CUSTOM_MASK
,
SKIP_PREFIX_CUSTOM_MASK
=
SKIP_PREFIX_CUSTOM_MASK
,
STORE_TRANSPOSE
=
_is_hip
,
STORE_TRANSPOSE
=
_is_hip
,
num_warps
=
num_warps
,
num_warps
=
num_warps
,
...
...
python/sglang/srt/layers/radix_attention.py
View file @
3bface15
...
@@ -13,6 +13,7 @@
...
@@ -13,6 +13,7 @@
# ==============================================================================
# ==============================================================================
"""Radix attention."""
"""Radix attention."""
from
enum
import
Enum
from
typing
import
Optional
from
typing
import
Optional
from
torch
import
nn
from
torch
import
nn
...
@@ -22,6 +23,18 @@ from sglang.srt.layers.quantization.base_config import QuantizationConfig
...
@@ -22,6 +23,18 @@ from sglang.srt.layers.quantization.base_config import QuantizationConfig
from
sglang.srt.model_executor.forward_batch_info
import
ForwardBatch
from
sglang.srt.model_executor.forward_batch_info
import
ForwardBatch
class AttentionType(Enum):
    """Kind of attention a layer performs.

    Members carry string values so they stay compatible with
    `torch.compile`.
    """

    # Decoder-style attention over the previous layer's Q/K/V.
    DECODER = "decoder"

    # Encoder-only attention over the previous layer's Q/K/V.
    ENCODER_ONLY = "encoder_only"
class
RadixAttention
(
nn
.
Module
):
class
RadixAttention
(
nn
.
Module
):
"""
"""
The attention layer implementation.
The attention layer implementation.
...
@@ -39,6 +52,7 @@ class RadixAttention(nn.Module):
...
@@ -39,6 +52,7 @@ class RadixAttention(nn.Module):
sliding_window_size
:
int
=
-
1
,
sliding_window_size
:
int
=
-
1
,
is_cross_attention
:
bool
=
False
,
is_cross_attention
:
bool
=
False
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
attn_type
=
AttentionType
.
DECODER
,
prefix
:
str
=
""
,
prefix
:
str
=
""
,
use_irope
:
bool
=
False
,
use_irope
:
bool
=
False
,
):
):
...
@@ -64,6 +78,7 @@ class RadixAttention(nn.Module):
...
@@ -64,6 +78,7 @@ class RadixAttention(nn.Module):
self
.
quant_method
=
quant_config
.
get_quant_method
(
self
,
prefix
=
prefix
)
self
.
quant_method
=
quant_config
.
get_quant_method
(
self
,
prefix
=
prefix
)
if
self
.
quant_method
is
not
None
:
if
self
.
quant_method
is
not
None
:
self
.
quant_method
.
create_weights
(
self
)
self
.
quant_method
.
create_weights
(
self
)
self
.
attn_type
=
attn_type
def
forward
(
def
forward
(
self
,
self
,
...
...
python/sglang/srt/models/bert.py
0 → 100644
View file @
3bface15
# SPDX-License-Identifier: Apache-2.0
from
typing
import
Any
,
Dict
,
Iterable
,
Optional
,
Set
,
Tuple
import
torch
from
torch
import
nn
from
sglang.srt.distributed
import
get_tensor_model_parallel_world_size
from
sglang.srt.layers.activation
import
get_act_fn
from
sglang.srt.layers.linear
import
(
ColumnParallelLinear
,
QKVParallelLinear
,
RowParallelLinear
,
)
from
sglang.srt.layers.pooler
import
EmbeddingPoolerOutput
,
Pooler
,
PoolingType
from
sglang.srt.layers.quantization.base_config
import
QuantizationConfig
from
sglang.srt.layers.radix_attention
import
AttentionType
,
RadixAttention
from
sglang.srt.layers.vocab_parallel_embedding
import
VocabParallelEmbedding
from
sglang.srt.model_executor.forward_batch_info
import
ForwardBatch
from
sglang.srt.model_loader.weight_utils
import
default_weight_loader
# Placeholder so `BertConfig` can be used as an annotation in this module;
# no config class is imported here, so the name is bound to None.
BertConfig = None
class BertEmbedding(nn.Module):
    """Sum of word, token-type and position embeddings followed by LayerNorm."""

    def __init__(self, config: BertConfig):
        super().__init__()
        hidden = config.hidden_size
        self.size = hidden

        # Three lookup tables that all project into the hidden dimension.
        self.word_embeddings = VocabParallelEmbedding(config.vocab_size, hidden)
        self.position_embeddings = VocabParallelEmbedding(
            config.max_position_embeddings, hidden
        )
        self.token_type_embeddings = VocabParallelEmbedding(
            config.type_vocab_size, hidden
        )
        self.LayerNorm = nn.LayerNorm(hidden, eps=config.layer_norm_eps)

        # Uninitialized (1, max_position_embeddings) parameter; it is not read
        # in forward(). Presumably present so a checkpoint entry with the same
        # name can be loaded — TODO confirm.
        self.position_ids = nn.Parameter(
            torch.empty((1, config.max_position_embeddings)),
        )

        self.position_embedding_type = config.position_embedding_type
        if self.position_embedding_type != "absolute":
            raise ValueError("Only 'absolute' position_embedding_type is supported")

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
    ) -> torch.Tensor:
        """Embed `input_ids` at `position_ids` and normalize the result."""
        word_part = self.word_embeddings(input_ids)
        position_part = self.position_embeddings(position_ids)

        # Every token is assigned token type 0.
        type_ids = torch.zeros(
            input_ids.size(), dtype=torch.long, device=word_part.device
        )
        type_part = self.token_type_embeddings(type_ids)

        return self.LayerNorm(word_part + type_part + position_part)
class BertEncoder(nn.Module):
    """Stack of `config.num_hidden_layers` BertLayer modules applied in order."""

    def __init__(
        self,
        config: BertConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.config = config
        self.quant_config = quant_config

        # Layer i is registered as `layer.i` and receives the matching prefix.
        blocks = nn.ModuleList()
        for idx in range(config.num_hidden_layers):
            blocks.append(
                BertLayer(
                    config=config,
                    layer_id=idx,
                    quant_config=quant_config,
                    prefix=f"{prefix}.layer.{idx}",
                )
            )
        self.layer = blocks

    def forward(
        self, hidden_states: torch.Tensor, forward_batch: ForwardBatch
    ) -> torch.Tensor:
        """Feed `hidden_states` through every layer and return the final output."""
        current = hidden_states
        for block in self.layer:
            current = block(current, forward_batch)
        return current
class BertLayer(nn.Module):
    """One encoder block: attention, intermediate projection, output projection."""

    def __init__(
        self,
        config: BertConfig,
        layer_id: int = 0,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        hidden = config.hidden_size
        inner = config.intermediate_size
        eps = config.layer_norm_eps

        self.attention = BertAttention(
            hidden_size=hidden,
            num_attention_heads=config.num_attention_heads,
            layer_id=layer_id,
            layer_norm_eps=eps,
            quant_config=quant_config,
            prefix=f"{prefix}.attention",
        )
        self.intermediate = BertIntermediate(
            hidden_size=hidden,
            intermediate_size=inner,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
            prefix=f"{prefix}.intermediate",
        )
        self.output = BertOutput(
            hidden_size=hidden,
            intermediate_size=inner,
            layer_norm_eps=eps,
            quant_config=quant_config,
            prefix=f"{prefix}.output",
        )

    def forward(self, hidden_states: torch.Tensor, forward_batch: ForwardBatch):
        """Run the block; the attention output is also passed to `output`."""
        attended = self.attention(hidden_states, forward_batch)
        expanded = self.intermediate(attended)
        return self.output(expanded, attended)
class BertAttention(nn.Module):
    """Self-attention followed by the attention output projection.

    Args:
        hidden_size: Model hidden dimension.
        num_attention_heads: Total number of attention heads.
        layer_norm_eps: Epsilon for the LayerNorm inside the output module.
        layer_id: Index of the enclosing layer, forwarded to the attention.
        quant_config: Optional quantization configuration for the linears.
        prefix: Dotted module path of this block, used to name submodules.
    """

    def __init__(
        self,
        hidden_size: int,
        num_attention_heads: int,
        layer_norm_eps: float,
        layer_id: int = 0,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()

        # The prefix mirrors the attribute name. It was previously
        # f"{prefix}.output", which collided with the prefix handed to
        # `self.output` below.
        self.self_attn = BertSelfAttention(
            hidden_size=hidden_size,
            num_attention_heads=num_attention_heads,
            layer_id=layer_id,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
        )

        self.output = BertSelfOutput(
            hidden_size=hidden_size,
            layer_norm_eps=layer_norm_eps,
            quant_config=quant_config,
            prefix=f"{prefix}.output",
        )

    def forward(
        self, hidden_states: torch.Tensor, forward_batch: ForwardBatch
    ) -> torch.Tensor:
        """Attend over `hidden_states`, then project with the input as residual."""
        self_output = self.self_attn(hidden_states, forward_batch)
        # `hidden_states` is the second argument of BertSelfOutput.forward.
        return self.output(self_output, hidden_states)
class BertSelfAttention(nn.Module):
    """Fused QKV projection feeding an encoder-only RadixAttention layer."""

    def __init__(
        self,
        hidden_size: int,
        num_attention_heads: int,
        layer_id: int = 0,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.hidden_size = hidden_size
        world_size = get_tensor_model_parallel_world_size()

        # Heads are split evenly across tensor-parallel ranks.
        self.total_num_heads = num_attention_heads
        assert self.total_num_heads % world_size == 0
        self.num_heads = self.total_num_heads // world_size

        # As many KV heads as query heads.
        self.total_num_kv_heads = self.total_num_heads
        self.head_dim = self.hidden_size // self.total_num_heads
        assert self.head_dim * self.total_num_heads == self.hidden_size

        self.num_kv_heads = max(1, self.total_num_kv_heads // world_size)

        # Per-rank widths of the q and k/v slices of the fused projection.
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5

        self.qkv_proj = QKVParallelLinear(
            hidden_size=self.hidden_size,
            head_size=self.head_dim,
            total_num_heads=self.total_num_heads,
            total_num_kv_heads=self.total_num_kv_heads,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )

        self.attn = RadixAttention(
            num_heads=self.num_heads,
            head_dim=self.head_dim,
            scaling=self.scaling,
            num_kv_heads=self.num_kv_heads,
            layer_id=layer_id,
            prefix=f"{prefix}.attn",
            attn_type=AttentionType.ENCODER_ONLY,
        )

    def forward(
        self, hidden_states: torch.Tensor, forward_batch: ForwardBatch
    ) -> torch.Tensor:
        """Project to q/k/v, split along the last dim and run attention."""
        fused, _ = self.qkv_proj(hidden_states)
        sections = [self.q_size, self.kv_size, self.kv_size]
        query, key, value = fused.split(sections, dim=-1)
        return self.attn(query, key, value, forward_batch)
class BertSelfOutput(nn.Module):
    """Hidden-to-hidden projection with a residual add and LayerNorm."""

    def __init__(
        self,
        hidden_size: int,
        layer_norm_eps: float,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.dense = RowParallelLinear(
            input_size=hidden_size,
            output_size=hidden_size,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.dense",
        )
        self.LayerNorm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)

    def forward(
        self, hidden_states: torch.Tensor, input_tensor: torch.Tensor
    ) -> torch.Tensor:
        """Return LayerNorm(dense(hidden_states) + input_tensor)."""
        # The linear returns a pair; only the first element is used.
        projected, _ = self.dense(hidden_states)
        return self.LayerNorm(projected + input_tensor)
class BertIntermediate(nn.Module):
    """Projection from hidden to intermediate size followed by an activation."""

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.dense = ColumnParallelLinear(
            input_size=hidden_size,
            output_size=intermediate_size,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.dense",
        )
        # Activation is looked up by the name given in `hidden_act`.
        self.intermediate_act_fn = get_act_fn(hidden_act)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return act(dense(hidden_states))."""
        # The linear returns a pair; only the first element is used.
        projected, _ = self.dense(hidden_states)
        return self.intermediate_act_fn(projected)
class BertOutput(nn.Module):
    """Intermediate-to-hidden projection with a residual add and LayerNorm."""

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        layer_norm_eps: float,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.dense = RowParallelLinear(
            input_size=intermediate_size,
            output_size=hidden_size,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.dense",
        )
        self.LayerNorm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)

    def forward(
        self, hidden_states: torch.Tensor, input_tensor: torch.Tensor
    ) -> torch.Tensor:
        """Return LayerNorm(dense(hidden_states) + input_tensor)."""
        # The linear returns a pair; only the first element is used.
        projected, _ = self.dense(hidden_states)
        return self.LayerNorm(projected + input_tensor)
class BertModel(nn.Module):
    """BERT encoder that produces pooled, normalized embeddings.

    The model is made of a BertEmbedding, a BertEncoder and a Pooler
    configured with `PoolingType.LAST` and `normalize=True`.

    Args:
        config: Model configuration read by the embedding and encoder modules.
        quant_config: Optional quantization configuration for the encoder.
        prefix: Accepted for interface compatibility; the encoder is always
            built with the fixed prefix "encoder".
    """

    def __init__(
        self,
        *,
        config: BertConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.config = config
        self.embeddings = BertEmbedding(config)
        self.encoder = BertEncoder(
            config=config, quant_config=quant_config, prefix="encoder"
        )
        self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)

    @torch.no_grad()
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        input_embeds: torch.Tensor = None,
        get_embedding: bool = False,
    ) -> torch.Tensor:
        """Embed, encode and pool the batch.

        Args:
            input_ids: Token ids passed to the embedding module.
            positions: Position ids passed to the embedding module.
            forward_batch: Batch metadata passed to the encoder and pooler.
            input_embeds: Accepted but not used by this model.
            get_embedding: Must be truthy; this model only returns embeddings.

        Returns:
            The result of calling the pooler on the encoder output.
        """
        assert (
            get_embedding
        ), "BertModel only supports embedding requests (get_embedding=True)"

        hidden_states = self.embeddings(
            input_ids=input_ids,
            position_ids=positions,
        )
        hidden_states = self.encoder(hidden_states, forward_batch=forward_batch)
        return self.pooler(hidden_states, forward_batch)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]:
        """Load checkpoint tensors into this model's parameters.

        Args:
            weights: Iterable of (checkpoint name, tensor) pairs.

        Returns:
            The names of the parameters that received a weight. Each of the
            query/key/value shards contributes the same fused `qkv_proj` name.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "query", "q"),
            ("qkv_proj", "key", "k"),
            ("qkv_proj", "value", "v"),
        ]

        params_dict = dict(self.named_parameters())
        loaded_params: Set[str] = set()
        for name, loaded_weight in weights:
            # The attention submodule is registered as `self_attn` here.
            name = name.replace("self", "self_attn")
            # Weights whose name mentions "pooler" are not loaded.
            if "pooler" in name:
                continue
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
            # Reached only when one of the two branches above loaded a weight;
            # skipped names exit through `continue` before this point.
            loaded_params.add(name)
        return loaded_params
class Contriever(BertModel):
    """BertModel under another class name; adds no behavior of its own."""
# Model classes exported by this module (presumably picked up by the model
# loader/registry — confirm against the loader).
EntryClass = [BertModel, Contriever]
python/sglang/test/runners.py
View file @
3bface15
...
@@ -51,6 +51,8 @@ NUM_TOP_LOGPROBS = 5
...
@@ -51,6 +51,8 @@ NUM_TOP_LOGPROBS = 5
def
get_dtype_str
(
torch_dtype
):
def
get_dtype_str
(
torch_dtype
):
if
torch_dtype
is
torch
.
float16
:
if
torch_dtype
is
torch
.
float16
:
return
"float16"
return
"float16"
if
torch_dtype
is
torch
.
float32
:
return
"float32"
else
:
else
:
raise
NotImplementedError
()
raise
NotImplementedError
()
...
@@ -447,6 +449,7 @@ class SRTRunner:
...
@@ -447,6 +449,7 @@ class SRTRunner:
port
:
int
=
DEFAULT_PORT_FOR_SRT_TEST_RUNNER
,
port
:
int
=
DEFAULT_PORT_FOR_SRT_TEST_RUNNER
,
lora_paths
:
List
[
str
]
=
None
,
lora_paths
:
List
[
str
]
=
None
,
max_loras_per_batch
:
int
=
4
,
max_loras_per_batch
:
int
=
4
,
attention_backend
:
Optional
[
str
]
=
None
,
lora_backend
:
str
=
"triton"
,
lora_backend
:
str
=
"triton"
,
disable_cuda_graph
:
bool
=
False
,
disable_cuda_graph
:
bool
=
False
,
disable_radix_cache
:
bool
=
False
,
disable_radix_cache
:
bool
=
False
,
...
@@ -487,6 +490,7 @@ class SRTRunner:
...
@@ -487,6 +490,7 @@ class SRTRunner:
lora_paths
=
lora_paths
,
lora_paths
=
lora_paths
,
max_loras_per_batch
=
max_loras_per_batch
,
max_loras_per_batch
=
max_loras_per_batch
,
lora_backend
=
lora_backend
,
lora_backend
=
lora_backend
,
attention_backend
=
attention_backend
,
disable_cuda_graph
=
disable_cuda_graph
,
disable_cuda_graph
=
disable_cuda_graph
,
disable_radix_cache
=
disable_radix_cache
,
disable_radix_cache
=
disable_radix_cache
,
chunked_prefill_size
=
chunked_prefill_size
,
chunked_prefill_size
=
chunked_prefill_size
,
...
...
test/srt/models/test_encoder_embedding_models.py
0 → 100644
View file @
3bface15
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# python -m unittest test_encoder_embedding_models.TestEncoderEmbeddingModels.test_prefill_logits
import
multiprocessing
as
mp
import
random
import
time
import
unittest
import
torch
from
transformers
import
AutoConfig
,
AutoTokenizer
from
sglang.test.runners
import
DEFAULT_PROMPTS
,
HFRunner
,
SRTRunner
from
sglang.test.test_utils
import
CustomTestCase
,
get_similarities
,
is_in_ci
# Each entry is (model path, tp_size, prefill_tolerance), unpacked in that
# order by test_prefill_logits.
MODELS = [("BAAI/bge-small-en", 1, 1e-5)]

# Attention backends passed to SRTRunner, one run per backend.
ATTENTION_BACKEND = ["torch_native", "triton"]
# Number of times the truncated prompt list is repeated for a run.
BATCH_SIZE = [30]
TORCH_DTYPES = [torch.float32]
# One sgl_time / transformer_time ratio is appended per
# assert_close_prefill_logits call.
sgl_to_st_ratio = []
class TestEncoderEmbeddingModels(CustomTestCase):
    """Compare SRT encoder embeddings against the HF runner's embeddings."""

    @classmethod
    def setUpClass(cls):
        # Force the "spawn" start method for every process this test creates.
        mp.set_start_method("spawn", force=True)

    def _truncate_prompts(self, prompts, model_path):
        """Return `prompts` with over-long entries cut to fit the model.

        The token limit is the config's `max_position_embeddings` (512 when
        the attribute is missing) minus a margin of 20. A prompt above the
        limit is replaced by the decoded text of its first `limit - 1` tokens.
        """
        config = AutoConfig.from_pretrained(model_path)
        max_length = getattr(config, "max_position_embeddings", 512) - 20

        tokenizer = AutoTokenizer.from_pretrained(model_path)

        truncated_prompts = []
        for prompt in prompts:
            tokens = tokenizer(prompt, return_tensors="pt", truncation=False)
            if len(tokens.input_ids[0]) > max_length:
                truncated_text = tokenizer.decode(
                    tokens.input_ids[0][: max_length - 1], skip_special_tokens=True
                )
                truncated_prompts.append(truncated_text)
            else:
                truncated_prompts.append(prompt)
        return truncated_prompts

    def assert_close_prefill_logits(
        self,
        prompts,
        model_path,
        tp_size,
        torch_dtype,
        prefill_tolerance,
        attention_backend,
        batch_size,
    ) -> None:
        """Run both runners on the same prompts and compare their embeddings.

        Appends one `sgl_time / transformer_time` entry to the module-level
        `sgl_to_st_ratio` list, then asserts that the similarity of each pair
        of embeddings is within `prefill_tolerance` of 1.
        """
        truncated_prompts = self._truncate_prompts(prompts, model_path)
        truncated_prompts = truncated_prompts * batch_size

        with HFRunner(
            model_path,
            torch_dtype=torch_dtype,
            model_type="embedding",
        ) as hf_runner:
            # warm up
            hf_outputs = hf_runner.forward(truncated_prompts)

            # perf_counter is monotonic, so the interval cannot be skewed by
            # wall-clock adjustments.
            st_start_time = time.perf_counter()
            hf_outputs = hf_runner.forward(truncated_prompts)
            st_end_time = time.perf_counter()

        with SRTRunner(
            model_path,
            tp_size=tp_size,
            torch_dtype=torch_dtype,
            model_type="embedding",
            attention_backend=attention_backend,
            chunked_prefill_size=-1,
            disable_radix_cache=True,
        ) as srt_runner:
            # warm up
            srt_outputs = srt_runner.forward(truncated_prompts)

            sgl_start_time = time.perf_counter()
            srt_outputs = srt_runner.forward(truncated_prompts)
            sgl_end_time = time.perf_counter()

        transformer_time = st_end_time - st_start_time
        sgl_time = sgl_end_time - sgl_start_time
        sgl_to_st_ratio.append(sgl_time / transformer_time)

        for i, prompt in enumerate(truncated_prompts):
            hf_logits = torch.Tensor(hf_outputs.embed_logits[i])
            srt_logits = torch.Tensor(srt_outputs.embed_logits[i])

            similarity = torch.tensor(get_similarities(hf_logits, srt_logits))
            # If something is wrong, uncomment this to observe similarity.
            # print("similarity diff", abs(similarity - 1))

            # Only prompts of at most 1000 characters are checked.
            if len(prompt) <= 1000:
                assert torch.all(
                    abs(similarity - 1) < prefill_tolerance
                ), "embeddings are not all close"

    def test_prefill_logits(self):
        """Check every configured model/backend/batch-size/dtype combination."""
        models_to_test = MODELS

        if is_in_ci():
            models_to_test = [random.choice(MODELS)]

        # Ratios recorded before this point belong to earlier calls; remember
        # where this test's entries start so labels and ratios line up.
        first_ratio = len(sgl_to_st_ratio)
        runs = []

        for model, tp_size, prefill_tolerance in models_to_test:
            for attention_backend in ATTENTION_BACKEND:
                for batch_size in BATCH_SIZE:
                    for torch_dtype in TORCH_DTYPES:
                        self.assert_close_prefill_logits(
                            DEFAULT_PROMPTS,
                            model,
                            tp_size,
                            torch_dtype,
                            prefill_tolerance,
                            attention_backend,
                            batch_size,
                        )
                        runs.append((model, attention_backend, batch_size))

        # One ratio is appended per run above, so report one line per run
        # instead of indexing the ratios by position in BATCH_SIZE.
        recorded = sgl_to_st_ratio[first_ratio:]
        for (model, attention_backend, batch_size), ratio in zip(runs, recorded):
            print(
                "model: ",
                model,
                "attention backend: ",
                attention_backend,
                "batch size: ",
                batch_size * len(DEFAULT_PROMPTS),
                "sgl_time/st_time",
                round(ratio, 3),
            )
# Run the tests when this file is executed directly.
if __name__ == "__main__":
    unittest.main()
test/srt/test_triton_attention_kernels.py
View file @
3bface15
...
@@ -116,6 +116,7 @@ class TestTritonAttention(CustomTestCase):
...
@@ -116,6 +116,7 @@ class TestTritonAttention(CustomTestCase):
kv_indptr
,
kv_indptr
,
kv_indices
,
kv_indices
,
custom_mask
,
custom_mask
,
True
,
mask_indptr
,
mask_indptr
,
max_len_extend
,
max_len_extend
,
)
)
...
@@ -150,6 +151,7 @@ class TestTritonAttention(CustomTestCase):
...
@@ -150,6 +151,7 @@ class TestTritonAttention(CustomTestCase):
kv_indptr
,
kv_indptr
,
kv_indices
,
kv_indices
,
custom_mask
,
custom_mask
,
True
,
mask_indptr
,
mask_indptr
,
max_len_extend
,
max_len_extend
,
)
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment