sglang / Commits / d8fbc7c0

[feature] support for roberta embedding models (#5730)

Unverified commit d8fbc7c0, authored Apr 27, 2025 by DavidBao, committed by GitHub on Apr 26, 2025.
Parent: c5e1026f

Showing 3 changed files with 186 additions and 2 deletions (+186 / -2):
python/sglang/srt/layers/pooler.py                 +6    -0
python/sglang/srt/models/roberta.py                +178  -0
test/srt/models/test_encoder_embedding_models.py   +2    -2
python/sglang/srt/layers/pooler.py

@@ -12,6 +12,7 @@ from sglang.srt.model_executor.model_runner import ForwardBatch
 class PoolingType(IntEnum):
     LAST = 0
+    CLS = 1


 @dataclass
@@ -41,6 +42,11 @@ class Pooler(nn.Module):
         if self.pooling_type == PoolingType.LAST:
             last_token_indices = torch.cumsum(forward_batch.extend_seq_lens, dim=0) - 1
             pooled_data = hidden_states[last_token_indices]
+        elif self.pooling_type == PoolingType.CLS:
+            prompt_lens = forward_batch.extend_seq_lens
+            first_token_flat_indices = torch.zeros_like(prompt_lens)
+            first_token_flat_indices[1:] += torch.cumsum(prompt_lens, dim=0)[:-1]
+            pooled_data = hidden_states[first_token_flat_indices]
         else:
             raise ValueError(f"Invalid pooling type: {self.pooling_type}")
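For reference, a minimal standalone sketch (not sglang code) of what the new CLS branch computes: sequences are assumed packed back to back in hidden_states, as extend_seq_lens implies, and the branch picks the flat index of each sequence's first (CLS) token. The tensor sizes and lengths below are made up for illustration.

import torch

hidden_size = 4
prompt_lens = torch.tensor([3, 2, 4])            # three sequences, packed back to back
hidden_states = torch.randn(int(prompt_lens.sum()), hidden_size)

# Sequence i starts at sum(prompt_lens[:i]) in the flattened batch.
first_token_flat_indices = torch.zeros_like(prompt_lens)
first_token_flat_indices[1:] += torch.cumsum(prompt_lens, dim=0)[:-1]
print(first_token_flat_indices)                  # tensor([0, 3, 5])

pooled = hidden_states[first_token_flat_indices] # shape: (num_sequences, hidden_size)
assert pooled.shape == (3, hidden_size)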
python/sglang/srt/models/roberta.py (new file, mode 0 → 100644)
# SPDX-License-Identifier: Apache-2.0
import itertools
from typing import Iterable, Optional, Tuple

import torch
from torch import nn

from sglang.srt.layers.pooler import Pooler, PoolingType
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.bert import BertEncoder

RobertaConfig = None
class RobertaEmbedding(nn.Module):

    def __init__(self, config: RobertaConfig):
        super().__init__()
        self.size = config.hidden_size
        self.word_embeddings = VocabParallelEmbedding(
            config.vocab_size, config.hidden_size
        )
        self.padding_idx = config.pad_token_id
        self.position_embeddings = nn.Embedding(
            config.max_position_embeddings,
            config.hidden_size,
            padding_idx=self.padding_idx,
        )

        self.token_type_embeddings = nn.Embedding(
            config.type_vocab_size, config.hidden_size
        )
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.position_ids = nn.Parameter(
            torch.empty((1, config.max_position_embeddings)),
        )

        self.position_embedding_type = config.position_embedding_type
        if self.position_embedding_type != "absolute":
            raise ValueError("Only 'absolute' position_embedding_type is supported")
    def forward(
        self,
        input_ids: torch.Tensor,
        seq_lens: torch.Tensor,
        position_ids: torch.Tensor,
        inputs_embeds=None,
        token_type_ids: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        input_shape = input_ids.size()
        inputs_embeds = self.word_embeddings(input_ids)

        # Adapted from vllm: https://github.com/vllm-project/vllm/commit/4a18fd14ba4a349291c798a16bf62fa8a9af0b6b/vllm/model_executor/models/roberta.py
        pos_list = []
        token_list = []
        offset = 0
        for seq_len in seq_lens:
            pos_list.append(position_ids[offset : offset + seq_len])
            token_list.append(input_ids[offset : offset + seq_len])
            offset += seq_len

        new_pos_list = []
        for positions, tokens in zip(pos_list, token_list):
            # Verify the assumption that incoming positions are
            # always a sequence from 0 to N.
            expected_pos = torch.arange(
                positions.size()[0], dtype=torch.long, device=inputs_embeds.device
            )
            assert torch.equal(positions, expected_pos)
            new_pos_list.append(
                create_position_ids_from_input_ids(tokens, self.padding_idx)
            )
        position_ids = torch.cat(new_pos_list)

        # Position embeddings.
        position_embeddings = self.position_embeddings(position_ids)
        if token_type_ids is None:
            token_type_ids = torch.zeros(
                input_shape, dtype=torch.long, device=inputs_embeds.device
            )

        token_type_embeddings = self.token_type_embeddings(token_type_ids)
        embeddings = inputs_embeds + token_type_embeddings + position_embeddings
        embeddings = self.LayerNorm(embeddings)
        return embeddings
class XLMRobertaModel(nn.Module):
    def __init__(
        self,
        *,
        config: RobertaConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()

        self.config = config
        self.embeddings = RobertaEmbedding(config)
        self.encoder = BertEncoder(config=config, quant_config=quant_config, prefix="")
        self.pooler = Pooler(pooling_type=PoolingType.CLS, normalize=True)
    @torch.no_grad()
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        input_embeds: torch.Tensor = None,
        get_embedding: bool = False,
    ) -> torch.Tensor:
        # Encoder-only model: only embedding requests are supported.
        assert get_embedding == True

        hidden_states = self.embeddings(
            input_ids=input_ids,
            position_ids=positions,
            seq_lens=forward_batch.seq_lens,
        )

        hidden_states = self.encoder(hidden_states, forward_batch=forward_batch)
        pooler_out = self.pooler(hidden_states, forward_batch)
        return pooler_out
    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "query", "q"),
            ("qkv_proj", "key", "k"),
            ("qkv_proj", "value", "v"),
        ]

        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            name = name.replace("self", "self_attn")
            if "pooler" in name:
                continue
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
# Adapted from transformers
def create_position_ids_from_input_ids(
    input_ids, padding_idx, past_key_values_length=0
):
    mask = input_ids.ne(padding_idx).int()
    incremental_indices = (
        torch.cumsum(mask, dim=0).type_as(mask) + past_key_values_length
    ) * mask
    return incremental_indices.long() + padding_idx


EntryClass = [XLMRobertaModel]
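The file exposes XLMRobertaModel through EntryClass, which is how sglang model files register their model classes. To make the position-id convention concrete, here is a small standalone check (not part of the commit) of what create_position_ids_from_input_ids produces: RoBERTa position ids start at padding_idx + 1, and padding positions stay at padding_idx. The token ids and padding_idx=1 below are illustrative (1 is the usual RoBERTa pad_token_id).

import torch

def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0):
    # Same logic as above: non-pad tokens get incremental indices shifted by
    # padding_idx; pad tokens collapse back to padding_idx.
    mask = input_ids.ne(padding_idx).int()
    incremental_indices = (
        torch.cumsum(mask, dim=0).type_as(mask) + past_key_values_length
    ) * mask
    return incremental_indices.long() + padding_idx

tokens = torch.tensor([0, 2057, 16, 127, 2])   # illustrative ids, no padding
print(create_position_ids_from_input_ids(tokens, padding_idx=1))
# tensor([2, 3, 4, 5, 6])  -> positions start at padding_idx + 1 = 2

padded = torch.tensor([0, 2057, 2, 1, 1])      # trailing pad tokens (id == padding_idx)
print(create_position_ids_from_input_ids(padded, padding_idx=1))
# tensor([2, 3, 4, 1, 1])  -> pad positions stay at padding_idx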
test/srt/models/test_encoder_embedding_models.py

@@ -25,10 +25,10 @@ from transformers import AutoConfig, AutoTokenizer
 from sglang.test.runners import DEFAULT_PROMPTS, HFRunner, SRTRunner
 from sglang.test.test_utils import CustomTestCase, get_similarities, is_in_ci

-MODELS = [("BAAI/bge-small-en", 1, 1e-5)]
+MODELS = [("BAAI/bge-small-en", 1, 1e-5), ("BAAI/bge-m3", 1, 1e-5)]
 ATTENTION_BACKEND = ["torch_native", "triton"]
-BATCH_SIZE = [30]
+BATCH_SIZE = [1, 2]
 TORCH_DTYPES = [torch.float32]

 sgl_to_st_ratio = []
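With the roberta-based encoder registered (e.g. BAAI/bge-m3, which the updated test now covers), embeddings can be requested from a running sglang server. A minimal client-side sketch follows; it assumes the server was launched in embedding mode (for example with `python -m sglang.launch_server --model-path BAAI/bge-m3 --is-embedding`) and serves the OpenAI-compatible /v1/embeddings endpoint on the default port 30000 — adjust flags, host, port, and model name to your deployment.

# Hypothetical client sketch: query a locally running sglang server for embeddings.
# The URL, port, and model name are assumptions about the local deployment.
import requests

resp = requests.post(
    "http://127.0.0.1:30000/v1/embeddings",
    json={
        "model": "BAAI/bge-m3",
        "input": ["hello world", "roberta embeddings in sglang"],
    },
    timeout=60,
)
resp.raise_for_status()

# One embedding per input string; the CLS pooler returns L2-normalized vectors.
for item in resp.json()["data"]:
    vec = item["embedding"]
    print(item["index"], len(vec), sum(v * v for v in vec) ** 0.5)  # index, dim, ~1.0 norm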