sglang / Commits / 30b4f771

Unverified commit 30b4f771, authored Aug 26, 2024 by Chayenne, committed by GitHub on Aug 25, 2024
Parent: 66e7dcaf

Support Alibaba-NLP/gte-Qwen2-7B-instruct embedding Model (#1186)

Co-authored-by: Ying Sheng <sqy1415@gmail.com>
Showing 15 changed files with 167 additions and 55 deletions (+167 −55)
- .github/workflows/accuracy-test.yml (+1 −1)
- .github/workflows/unit-test.yml (+1 −1)
- README.md (+15 −0)
- python/sglang/srt/managers/tokenizer_manager.py (+4 −1)
- python/sglang/srt/managers/tp_worker.py (+1 −0)
- python/sglang/srt/model_executor/model_runner.py (+13 −4)
- python/sglang/srt/models/llama_embedding.py (+4 −0)
- python/sglang/srt/models/qwen2.py (+9 −3)
- python/sglang/srt/server.py (+3 −0)
- python/sglang/srt/server_args.py (+11 −0)
- python/sglang/srt/utils.py (+7 −2)
- python/sglang/test/runners.py (+16 −16)
- test/srt/models/test_embedding_models.py (+13 −15)
- test/srt/models/test_generation_models.py (+65 −8)
- test/srt/run_suite.py (+4 −4)
.github/workflows/accuracy-test.yml

@@ -43,4 +43,4 @@ jobs:
       run: |
         cd test/srt
         python3 test_eval_accuracy_large.py
-      timeout-minutes: 10
+      timeout-minutes: 20
.github/workflows/unit-test.yml

@@ -41,7 +41,7 @@ jobs:
       run: |
         cd test/srt
         python3 run_suite.py --suite minimal
-      timeout-minutes: 18
+      timeout-minutes: 20

     - name: Test Frontend Language
       run: |
README.md

@@ -187,6 +187,13 @@ response = client.chat.completions.create(
     max_tokens=64,
 )
 print(response)
+
+# Text embedding
+response = client.embeddings.create(
+    model="default",
+    input="How are you today",
+)
+print(response)
 ```

 It supports streaming, vision, and most features of the Chat/Completions/Models/Batch endpoints specified by the [OpenAI API Reference](https://platform.openai.com/docs/api-reference/).

@@ -223,6 +230,8 @@ python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct
 ### Supported Models

+**Generative Models**
+
 - Llama / Llama 2 / Llama 3 / Llama 3.1
 - Mistral / Mixtral / Mistral NeMo
 - Gemma / Gemma 2

@@ -243,6 +252,12 @@ python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct
 - ChatGLM
 - InternLM 2

+**Embedding Models**
+
+- e5-mistral
+- gte-Qwen2
+  - `python -m sglang.launch_server --model-path Alibaba-NLP/gte-Qwen2-7B-instruct --is-embedding`
+
 Instructions for supporting a new model are [here](https://github.com/sgl-project/sglang/blob/main/docs/en/model_support.md).

 #### Use Models From ModelScope
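To show the new endpoint end to end, here is a minimal client-side sketch of the embedding call added to the README above. It assumes a server has already been launched with the `--is-embedding` command from the README and is listening on sglang's default port 30000; the port and the length check are illustrative, not part of this commit.

```python
import openai

# Assumes the server from the README command above is running locally
# on the default port 30000 (illustrative setup, not part of this commit).
client = openai.Client(base_url="http://127.0.0.1:30000/v1", api_key="EMPTY")

response = client.embeddings.create(
    model="default",
    input="How are you today",
)
# Each item in response.data carries one embedding vector.
print(len(response.data[0].embedding))
```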
python/sglang/srt/managers/tokenizer_manager.py

@@ -94,7 +94,10 @@ class TokenizerManager:
             trust_remote_code=server_args.trust_remote_code,
             model_overide_args=model_overide_args,
         )
-        self.is_generation = is_generation_model(self.hf_config.architectures)
+        self.is_generation = is_generation_model(
+            self.hf_config.architectures, self.server_args.is_embedding
+        )

         if server_args.context_length is not None:
             self.context_len = server_args.context_length
python/sglang/srt/managers/tp_worker.py

@@ -94,6 +94,7 @@ class ModelTpServer:
             context_length=server_args.context_length,
             model_overide_args=model_overide_args,
         )
         self.model_runner = ModelRunner(
             model_config=self.model_config,
             mem_fraction_static=server_args.mem_fraction_static,
python/sglang/srt/model_executor/model_runner.py

@@ -204,7 +204,7 @@ class ModelRunner:
             else None
         )
         self.is_generation = is_generation_model(
-            self.model_config.hf_config.architectures
+            self.model_config.hf_config.architectures, self.server_args.is_embedding
         )
         logger.info(

@@ -522,9 +522,18 @@
             batch, forward_mode=ForwardMode.EXTEND
         )
-        return self.model.forward(
-            batch.input_ids, input_metadata.positions, input_metadata
-        )
+        if self.is_generation:
+            return self.model.forward(
+                batch.input_ids, input_metadata.positions, input_metadata
+            )
+        else:
+            # Only embedding models have get_embedding parameter
+            return self.model.forward(
+                batch.input_ids,
+                input_metadata.positions,
+                input_metadata,
+                get_embedding=True,
+            )

     @torch.inference_mode()
     def forward_extend_multi_modal(self, batch: ScheduleBatch):
python/sglang/srt/models/llama_embedding.py

@@ -29,7 +29,11 @@ class LlamaEmbeddingModel(nn.Module):
         positions: torch.Tensor,
         input_metadata: InputMetadata,
         input_embeds: torch.Tensor = None,
+        get_embedding: bool = True,
     ) -> EmbeddingPoolerOutput:
+        assert (
+            get_embedding
+        ), "LlamaEmbeddingModel / MistralModel is only used for embedding"
         hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
         return self.pooler(hidden_states, input_metadata)
python/sglang/srt/models/qwen2.py

@@ -38,6 +38,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.pooler import EmbeddingPoolerOutput, Pooler, PoolingType
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.model_executor.forward_batch_info import InputMetadata

@@ -275,6 +276,7 @@ class Qwen2ForCausalLM(nn.Module):
         self.model = Qwen2Model(config, quant_config=quant_config)
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
+        self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)

     @torch.no_grad()
     def forward(

@@ -283,11 +285,15 @@ class Qwen2ForCausalLM(nn.Module):
         positions: torch.Tensor,
         input_metadata: InputMetadata,
         input_embeds: torch.Tensor = None,
+        get_embedding: bool = False,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
-        return self.logits_processor(
-            input_ids, hidden_states, self.lm_head.weight, input_metadata
-        )
+        if not get_embedding:
+            return self.logits_processor(
+                input_ids, hidden_states, self.lm_head.weight, input_metadata
+            )
+        else:
+            return self.pooler(hidden_states, input_metadata)

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
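For intuition about the pooler wired into `Qwen2ForCausalLM` above, here is a standalone sketch of what last-token pooling with normalization (`PoolingType.LAST`, `normalize=True`) amounts to. It is an illustration under that assumption, not the actual `sglang.srt.layers.pooler.Pooler` implementation; the function and tensor names are hypothetical.

```python
import torch

def last_token_pool(hidden_states: torch.Tensor, seq_lens: torch.Tensor) -> torch.Tensor:
    """Sketch of LAST pooling: take each sequence's final-token hidden state
    from a packed [total_tokens, hidden_size] tensor and L2-normalize it."""
    last_indices = torch.cumsum(seq_lens, dim=0) - 1  # last token of each sequence
    embeddings = hidden_states[last_indices]
    return torch.nn.functional.normalize(embeddings, p=2, dim=-1)

# Two sequences of lengths 3 and 5 with hidden size 8 (toy numbers).
states = torch.randn(8, 8)
print(last_token_pool(states, torch.tensor([3, 5])).shape)  # torch.Size([2, 8])
```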
python/sglang/srt/server.py

@@ -333,11 +333,13 @@ def launch_server(
         start_process = start_controller_process_single
     else:
         start_process = start_controller_process_multi
     proc_controller = mp.Process(
         target=start_process,
         args=(server_args, port_args, pipe_controller_writer, model_overide_args),
     )
     proc_controller.start()
     proc_detoken = mp.Process(
         target=start_detokenizer_process,
         args=(

@@ -515,6 +517,7 @@ class Runtime:
         self.pid = None
         pipe_reader, pipe_writer = mp.Pipe(duplex=False)
         proc = mp.Process(
             target=launch_server,
             args=(self.server_args, model_overide_args, pipe_writer),
python/sglang/srt/server_args.py

@@ -38,6 +38,7 @@ class ServerArgs:
     quantization: Optional[str] = None
     served_model_name: Optional[str] = None
     chat_template: Optional[str] = None
+    is_embedding: bool = False

     # Port
     host: str = "127.0.0.1"

@@ -200,6 +201,11 @@ class ServerArgs:
             action="store_true",
             help="Whether or not to allow for custom models defined on the Hub in their own modeling files.",
         )
+        parser.add_argument(
+            "--is-embedding",
+            action="store_true",
+            help="Whether to use a CausalLM as an embedding model.",
+        )
         parser.add_argument(
             "--context-length",
             type=int,

@@ -458,6 +464,11 @@ class ServerArgs:
         assert not (
             self.dp_size > 1 and self.node_rank is not None
         ), "multi-node data parallel is not supported"
+        if "Alibaba-NLP/gte-Qwen2-1.5B-instruct" == self.model_path:
+            logger.info(
+                "Not sure why, the tokenizer will add an additional token at the end of the prompt when trust_remote_mode=True"
+            )
+            self.trust_remote_code = False
         if "gemma-2" in self.model_path.lower():
             logger.info("When using sliding window in gemma-2, turn on flashinfer.")
             self.disable_flashinfer = False
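The same switch can also be set programmatically when constructing `ServerArgs`; a minimal sketch, assuming `model_path` is the only required field at this commit:

```python
from sglang.srt.server_args import ServerArgs

# Programmatic equivalent of passing --is-embedding on the command line.
args = ServerArgs(
    model_path="Alibaba-NLP/gte-Qwen2-7B-instruct",
    is_embedding=True,
)
print(args.is_embedding)  # True
```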
python/sglang/srt/utils.py

@@ -224,13 +224,18 @@ def is_multimodal_model(model):
     raise ValueError("unrecognized type")


-def is_generation_model(model_architectures):
+def is_generation_model(model_architectures, is_embedding: bool = False):
+    # We have two ways to determine whether a model is a generative model.
+    # 1. Check the model architecture.
+    # 2. Check the `is_embedding` server arg.
     if (
         "LlamaEmbeddingModel" in model_architectures
         or "MistralModel" in model_architectures
     ):
         return False
-    return True
+    else:
+        return not is_embedding


 def decode_video_base64(video_base64):
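To make the two decision paths concrete, a small usage sketch of the updated helper (the architecture names are examples):

```python
from sglang.srt.utils import is_generation_model

# Architecture-based: these classes are always treated as embedding models.
assert is_generation_model(["LlamaEmbeddingModel"]) is False
assert is_generation_model(["MistralModel"]) is False

# Flag-based: any other architecture is generative unless --is-embedding is set.
assert is_generation_model(["Qwen2ForCausalLM"]) is True
assert is_generation_model(["Qwen2ForCausalLM"], is_embedding=True) is False
```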
python/sglang/test/runners.py

@@ -14,7 +14,7 @@ limitations under the License.
 """

 import json
-import multiprocessing
+import multiprocessing as mp
 import os
 from dataclasses import dataclass
 from typing import List, Union

@@ -63,37 +63,35 @@ class HFRunner:
         self,
         model_path,
         torch_dtype,
-        is_generation_model,
+        is_generation,
     ):
-        self.in_queue = multiprocessing.Queue()
-        self.out_queue = multiprocessing.Queue()
+        self.is_generation = is_generation

-        self.model_proc = multiprocessing.Process(
+        self.in_queue = mp.Queue()
+        self.out_queue = mp.Queue()
+
+        self.model_proc = mp.Process(
             target=self.start_model_process,
             args=(
                 self.in_queue,
                 self.out_queue,
                 model_path,
                 torch_dtype,
-                is_generation_model,
             ),
         )
         self.model_proc.start()

-    def start_model_process(
-        self, in_queue, out_queue, model_path, torch_dtype, is_generation_model
-    ):
+    def start_model_process(self, in_queue, out_queue, model_path, torch_dtype):
         self.tokenizer = AutoTokenizer.from_pretrained(
             model_path,
             torch_dtype=torch_dtype,
         )

-        self.is_generation_model = is_generation_model
-        if self.is_generation_model:
+        if self.is_generation:
             self.model = AutoModelForCausalLM.from_pretrained(
                 model_path,
                 torch_dtype=torch_dtype,
                 trust_remote_code=False,
                 low_cpu_mem_usage=True,
             ).cuda()
         else:

@@ -107,7 +105,7 @@ class HFRunner:
         while True:
             prompts, max_new_tokens = in_queue.get()
             if prompts is not None:
-                if self.is_generation_model:
+                if self.is_generation:
                     output_strs = []
                     prefill_logprobs = []
                     for p in prompts:

@@ -171,17 +169,19 @@ class SRTRunner:
         self,
         model_path,
         torch_dtype,
-        is_generation_model,
+        is_generation,
         tp_size=1,
         port=5157,
     ):
-        self.is_generation_model = is_generation_model
+        self.is_generation = is_generation
         self.runtime = Runtime(
             model_path=model_path,
             tp_size=tp_size,
             dtype=get_dtype_str(torch_dtype),
             port=port,
             mem_fraction_static=0.7,
             trust_remote_code=False,
+            is_embedding=not self.is_generation,
         )

     def forward(

@@ -189,7 +189,7 @@ class SRTRunner:
         prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS,
         max_new_tokens=8,
     ):
-        if self.is_generation_model:
+        if self.is_generation:
             # the return value contains logprobs from prefill
             output_strs = []
             top_input_logprobs = []
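For orientation, a condensed sketch of how the renamed `is_generation` flag is used by callers of these runners, mirroring the tests below. The model name is illustrative, a GPU is assumed, and the `.embed_logits` attribute is taken from the embedding test in this commit.

```python
import torch
from sglang.test.runners import HFRunner, SRTRunner

prompts = ["How are you today"]

# is_generation=False selects the embedding path in both runners.
with HFRunner("intfloat/e5-mistral-7b-instruct",
              torch_dtype=torch.float16, is_generation=False) as hf_runner:
    hf_outputs = hf_runner.forward(prompts)

with SRTRunner("intfloat/e5-mistral-7b-instruct",
               torch_dtype=torch.float16, is_generation=False) as srt_runner:
    srt_outputs = srt_runner.forward(prompts)

# Per-prompt embedding vectors from both backends, ready for comparison.
hf_vec = torch.Tensor(hf_outputs.embed_logits[0])
srt_vec = torch.Tensor(srt_outputs.embed_logits[0])
```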
test/srt/models/test_embedding_models.py

@@ -20,7 +20,10 @@ import torch
 from sglang.test.runners import DEFAULT_PROMPTS, HFRunner, SRTRunner
 from sglang.test.test_utils import get_similarities

-MODELS = [("intfloat/e5-mistral-7b-instruct", 1, 0.2)]
+MODELS = [
+    ("Alibaba-NLP/gte-Qwen2-1.5B-instruct", 1, 1e-5),
+    ("intfloat/e5-mistral-7b-instruct", 1, 1e-5),
+]
 TORCH_DTYPES = [torch.float16]

@@ -32,10 +35,10 @@ class TestEmbeddingModels(unittest.TestCase):
         model_path,
         tp_size,
         torch_dtype,
-        long_context_tolerance,
+        prefill_tolerance,
     ) -> None:
         with HFRunner(
-            model_path, torch_dtype=torch_dtype, is_generation_model=False
+            model_path, torch_dtype=torch_dtype, is_generation=False
         ) as hf_runner:
             hf_outputs = hf_runner.forward(prompts)

@@ -43,11 +46,9 @@ class TestEmbeddingModels(unittest.TestCase):
             model_path,
             tp_size=tp_size,
             torch_dtype=torch_dtype,
-            is_generation_model=False,
+            is_generation=False,
         ) as srt_runner:
-            srt_outputs = srt_runner.forward(
-                prompts,
-            )
+            srt_outputs = srt_runner.forward(prompts)

         for i in range(len(prompts)):
             hf_logits = torch.Tensor(hf_outputs.embed_logits[i])

@@ -57,18 +58,15 @@ class TestEmbeddingModels(unittest.TestCase):
             print("similarity diff", abs(similarity - 1))

-            if len(prompts[i]) <= 1000:
-                tolerance = 1e-5
-            else:
-                tolerance = long_context_tolerance
-            assert torch.all(
-                abs(similarity - 1) < tolerance
-            ), "embeddings are not all close"
+            assert torch.all(
+                abs(similarity - 1) < prefill_tolerance
+            ), "embeddings are not all close"

     def test_prefill_logits(self):
-        for model, tp_size, long_context_tolerance in MODELS:
+        for model, tp_size, prefill_tolerance in MODELS:
             for torch_dtype in TORCH_DTYPES:
                 self.assert_close_prefill_logits(
-                    DEFAULT_PROMPTS, model, tp_size, torch_dtype, long_context_tolerance
+                    DEFAULT_PROMPTS, model, tp_size, torch_dtype, prefill_tolerance
                 )
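The comparison above relies on `get_similarities` from `sglang.test.test_utils`, whose body is not part of this diff. Below is a minimal cosine-similarity sketch of what such a check plausibly computes; the name `get_similarities_sketch` is hypothetical and only illustrates why `abs(similarity - 1)` works as the error metric.

```python
import torch

def get_similarities_sketch(vec1: torch.Tensor, vec2: torch.Tensor) -> torch.Tensor:
    # Cosine similarity between two embedding vectors; values near 1 mean the
    # HF and SRT embeddings agree, so abs(similarity - 1) measures the error.
    return torch.nn.functional.cosine_similarity(
        vec1.unsqueeze(0), vec2.unsqueeze(0)
    )

a = torch.tensor([1.0, 0.0, 1.0])
b = torch.tensor([1.0, 0.1, 0.9])
print(get_similarities_sketch(a, b))  # tensor close to 1
```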
test/srt/models/test_generation_models.py

@@ -20,12 +20,46 @@ import torch
 from sglang.test.runners import DEFAULT_PROMPTS, HFRunner, SRTRunner

 MODELS = [
-    ("meta-llama/Meta-Llama-3.1-8B-Instruct", 1, 1.1),
-    ("google/gemma-2-2b", 1, 3),
+    ("meta-llama/Meta-Llama-3.1-8B-Instruct", 1, 1.1, 3e-2, 1),
+    ("google/gemma-2-2b", 1, 3, 3e-2, 1),
+    ("Alibaba-NLP/gte-Qwen2-1.5B-instruct", 1, None, 6e-2, 1),
 ]
 TORCH_DTYPES = [torch.float16]


+def lcs(X, Y):
+    m = len(X)
+    n = len(Y)
+    L = [[0] * (n + 1) for _ in range(m + 1)]
+
+    for i in range(m + 1):
+        for j in range(n + 1):
+            if i == 0 or j == 0:
+                L[i][j] = 0
+            elif X[i - 1] == Y[j - 1]:
+                L[i][j] = L[i - 1][j - 1] + 1
+            else:
+                L[i][j] = max(L[i - 1][j], L[i][j - 1])
+
+    return L[m][n]
+
+
+def calculate_rouge_l(output_strs_list1, output_strs_list2):
+    rouge_l_scores = []
+
+    for s1, s2 in zip(output_strs_list1, output_strs_list2):
+        lcs_len = lcs(s1, s2)
+        precision = lcs_len / len(s1) if len(s1) > 0 else 0
+        recall = lcs_len / len(s2) if len(s2) > 0 else 0
+        if precision + recall > 0:
+            fmeasure = (2 * precision * recall) / (precision + recall)
+        else:
+            fmeasure = 0.0
+        rouge_l_scores.append(fmeasure)
+
+    return rouge_l_scores
+
+
 class TestGenerationModels(unittest.TestCase):
     def assert_close_prefill_logits_and_output_strs(

@@ -35,10 +69,14 @@ class TestGenerationModels(unittest.TestCase):
         tp_size,
         torch_dtype,
         max_new_tokens,
+        prefill_tolerance,
+        rouge_threshold,
+        long_context_tolerance,
     ) -> None:
+        if model_path == "Alibaba-NLP/gte-Qwen2-1.5B-instruct":
+            prompts = prompts[:-1]
         with HFRunner(
-            model_path, torch_dtype=torch_dtype, is_generation_model=True
+            model_path, torch_dtype=torch_dtype, is_generation=True
         ) as hf_runner:
             hf_outputs = hf_runner.forward(prompts, max_new_tokens=max_new_tokens)

@@ -46,7 +84,7 @@ class TestGenerationModels(unittest.TestCase):
             model_path,
             tp_size=tp_size,
             torch_dtype=torch_dtype,
-            is_generation_model=True,
+            is_generation=True,
         ) as srt_runner:
             srt_outputs = srt_runner.forward(prompts, max_new_tokens=max_new_tokens)

@@ -56,17 +94,34 @@ class TestGenerationModels(unittest.TestCase):
             print("max_diff", torch.max(abs(hf_logprobs - srt_logprobs)))

             if hf_logprobs.shape[0] <= 100:
-                tolerance = 3e-2
                 assert torch.all(
-                    abs(hf_logprobs - srt_logprobs) < tolerance
+                    abs(hf_logprobs - srt_logprobs) < prefill_tolerance
                 ), "prefill logprobs are not all close"

         print(hf_outputs.output_strs)
         print(srt_outputs.output_strs)
-        assert hf_outputs.output_strs == srt_outputs.output_strs
+        rouge_l_scores = calculate_rouge_l(
+            hf_outputs.output_strs, srt_outputs.output_strs
+        )
+        assert all(
+            score >= rouge_threshold for score in rouge_l_scores
+        ), f"Not all ROUGE-L scores are greater than {rouge_threshold}"

     def test_prefill_logits_and_output_strs(self):
-        for model, tp_size, long_context_tolerance in MODELS:
+        import multiprocessing as mp
+
+        try:
+            mp.set_start_method("spawn")
+        except RuntimeError:
+            pass
+
+        for (
+            model,
+            tp_size,
+            long_context_tolerance,
+            prefill_tolerance,
+            rouge_threshold,
+        ) in MODELS:
             for torch_dtype in TORCH_DTYPES:
                 max_new_tokens = 8
                 self.assert_close_prefill_logits_and_output_strs(

@@ -75,6 +130,8 @@ class TestGenerationModels(unittest.TestCase):
                     tp_size,
                     torch_dtype,
                     max_new_tokens,
+                    prefill_tolerance=prefill_tolerance,
+                    rouge_threshold=rouge_threshold,
+                    long_context_tolerance=long_context_tolerance,
                 )
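A quick worked example of the character-level LCS / ROUGE-L check introduced above, assuming `lcs` and `calculate_rouge_l` from `test/srt/models/test_generation_models.py` are in scope; the strings and the 0.75 threshold are illustrative.

```python
hf_strs = ["the quick brown fox", "hello world"]
srt_strs = ["the quick brown fox", "hello there world"]

scores = calculate_rouge_l(hf_strs, srt_strs)
# Pair 1 is identical, so its F-measure is 1.0.
# Pair 2: the 11 characters of "hello world" form an LCS inside
# "hello there world" (17 chars), so precision = 1, recall = 11/17, F ~= 0.79.
print(scores)
assert all(score >= 0.75 for score in scores)
```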
test/srt/run_suite.py

@@ -5,6 +5,9 @@ from sglang.test.test_utils import run_unittest_files
 suites = {
     "minimal": [
+        "models/test_embedding_models.py",
+        "models/test_generation_models.py",
+        "sampling/penaltylib",
         "test_chunked_prefill.py",
         "test_embedding_openai_server.py",
         "test_eval_accuracy_mini.py",

@@ -13,11 +16,8 @@ suites = {
         "test_skip_tokenizer_init.py",
         "test_torch_compile.py",
         "test_triton_attn_backend.py",
-        "test_vision_openai_server.py",
         "test_update_weights.py",
-        "models/test_generation_models.py",
-        "models/test_embedding_models.py",
-        "sampling/penaltylib",
+        "test_vision_openai_server.py",
     ],
     "sampling/penaltylib": glob.glob(
         "sampling/penaltylib/**/test_*.py", recursive=True