Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
e040a245
Unverified
Commit
e040a245
authored
Aug 08, 2024
by
Ying Sheng
Committed by
GitHub
Aug 08, 2024
Browse files
Add e5-mistral embedding model - step 3/3 (#988)
parent
9f662501
Changes
14
Hide whitespace changes
Inline
Side-by-side
Showing
14 changed files
with
474 additions
and
241 deletions
+474
-241
.github/workflows/unit-test.yml
.github/workflows/unit-test.yml
+1
-0
python/sglang/srt/managers/detokenizer_manager.py
python/sglang/srt/managers/detokenizer_manager.py
+17
-1
python/sglang/srt/managers/schedule_batch.py
python/sglang/srt/managers/schedule_batch.py
+1
-0
python/sglang/srt/managers/tokenizer_manager.py
python/sglang/srt/managers/tokenizer_manager.py
+86
-49
python/sglang/srt/managers/tp_worker.py
python/sglang/srt/managers/tp_worker.py
+178
-128
python/sglang/srt/model_executor/model_runner.py
python/sglang/srt/model_executor/model_runner.py
+12
-2
python/sglang/srt/models/llama_embedding.py
python/sglang/srt/models/llama_embedding.py
+2
-0
python/sglang/srt/server.py
python/sglang/srt/server.py
+35
-3
python/sglang/srt/utils.py
python/sglang/srt/utils.py
+9
-0
python/sglang/test/runners.py
python/sglang/test/runners.py
+52
-55
python/sglang/test/test_utils.py
python/sglang/test/test_utils.py
+6
-0
test/srt/models/test_embedding_models.py
test/srt/models/test_embedding_models.py
+69
-0
test/srt/models/test_generation_models.py
test/srt/models/test_generation_models.py
+4
-2
test/srt/run_suite.py
test/srt/run_suite.py
+2
-1
No files found.
.github/workflows/unit-test.yml
View file @
e040a245
...
...
@@ -35,6 +35,7 @@ jobs:
pip install -e "python[all]"
pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.4/ --force-reinstall
pip install accelerate
pip install sentence_transformers
-
name
:
Test Frontend Language
run
:
|
...
...
python/sglang/srt/managers/detokenizer_manager.py
View file @
e040a245
...
...
@@ -25,7 +25,11 @@ import zmq
import
zmq.asyncio
from
sglang.srt.hf_transformers_utils
import
get_tokenizer
from
sglang.srt.managers.io_struct
import
BatchStrOut
,
BatchTokenIDOut
from
sglang.srt.managers.io_struct
import
(
BatchEmbeddingOut
,
BatchStrOut
,
BatchTokenIDOut
,
)
from
sglang.srt.managers.schedule_batch
import
FINISH_MATCHED_STR
from
sglang.srt.server_args
import
PortArgs
,
ServerArgs
from
sglang.utils
import
find_printable_text
,
get_exception_traceback
,
graceful_registry
...
...
@@ -66,6 +70,18 @@ class DetokenizerManager:
async
def
handle_loop
(
self
):
while
True
:
recv_obj
:
BatchTokenIDOut
=
await
self
.
recv_from_router
.
recv_pyobj
()
if
isinstance
(
recv_obj
,
BatchEmbeddingOut
):
self
.
send_to_tokenizer
.
send_pyobj
(
BatchEmbeddingOut
(
rids
=
recv_obj
.
rids
,
embeddings
=
recv_obj
.
embeddings
,
meta_info
=
recv_obj
.
meta_info
,
finished_reason
=
recv_obj
.
finished_reason
,
)
)
continue
assert
isinstance
(
recv_obj
,
BatchTokenIDOut
)
bs
=
len
(
recv_obj
.
rids
)
...
...
python/sglang/srt/managers/schedule_batch.py
View file @
e040a245
...
...
@@ -143,6 +143,7 @@ class Req:
# Logprobs
self
.
return_logprob
=
False
self
.
embedding
=
None
self
.
logprob_start_len
=
0
self
.
top_logprobs_num
=
0
self
.
normalized_prompt_logprob
=
None
...
...
python/sglang/srt/managers/tokenizer_manager.py
View file @
e040a245
...
...
@@ -21,7 +21,7 @@ import dataclasses
import
logging
import
multiprocessing
as
mp
import
os
from
typing
import
Dict
,
List
,
Tuple
from
typing
import
Dict
,
List
,
Tuple
,
Union
import
numpy
as
np
import
transformers
...
...
@@ -38,16 +38,19 @@ from sglang.srt.hf_transformers_utils import (
)
from
sglang.srt.managers.io_struct
import
(
AbortReq
,
BatchEmbeddingOut
,
BatchStrOut
,
BatchTokenIDOut
,
EmbeddingReqInput
,
FlushCacheReq
,
GenerateReqInput
,
TokenizedEmbeddingReqInput
,
TokenizedGenerateReqInput
,
)
from
sglang.srt.mm_utils
import
expand2square
,
process_anyres_image
from
sglang.srt.sampling_params
import
SamplingParams
from
sglang.srt.server_args
import
PortArgs
,
ServerArgs
from
sglang.srt.utils
import
is_multimodal_model
,
load_image
from
sglang.srt.utils
import
is_generation_model
,
is_multimodal_model
,
load_image
from
sglang.utils
import
get_exception_traceback
asyncio
.
set_event_loop_policy
(
uvloop
.
EventLoopPolicy
())
...
...
@@ -85,6 +88,7 @@ class TokenizerManager:
trust_remote_code
=
server_args
.
trust_remote_code
,
model_overide_args
=
model_overide_args
,
)
self
.
is_generation
=
is_generation_model
(
self
.
hf_config
.
architectures
)
if
server_args
.
context_length
is
not
None
:
self
.
context_len
=
server_args
.
context_length
...
...
@@ -133,7 +137,9 @@ class TokenizerManager:
image_data
,
aspect_ratio
,
grid_pinpoints
,
self
.
processor
)
async
def
generate_request
(
self
,
obj
:
GenerateReqInput
,
request
=
None
):
async
def
generate_request
(
self
,
obj
:
Union
[
GenerateReqInput
,
EmbeddingReqInput
],
request
=
None
):
if
self
.
to_create_loop
:
self
.
create_handle_loop
()
...
...
@@ -144,6 +150,8 @@ class TokenizerManager:
async
for
response
in
self
.
_handle_single_request
(
obj
,
request
):
yield
response
else
:
if
isinstance
(
obj
,
EmbeddingReqInput
):
raise
NotImplementedError
(
"Please send only one prompt in each request"
)
if
obj
.
stream
:
raise
ValueError
(
"Do not support stream for batch mode."
)
...
...
@@ -151,39 +159,47 @@ class TokenizerManager:
yield
response
async
def
_handle_single_request
(
self
,
obj
,
request
,
index
=
None
,
is_cache_for_prefill
=
False
self
,
obj
:
Union
[
GenerateReqInput
,
EmbeddingReqInput
],
request
,
index
=
None
,
is_cache_for_prefill
=
False
,
):
if
not
is_cache_for_prefill
:
# The normal case with a single prompt
not_use_index
=
index
is
None
rid
=
obj
.
rid
if
not_use_index
else
obj
.
rid
[
index
]
input_text
=
obj
.
text
if
not_use_index
else
obj
.
text
[
index
]
input_ids
=
(
self
.
tokenizer
.
encode
(
input_text
)
if
obj
.
input_ids
is
None
else
obj
.
input_ids
)
if
not
not_use_index
and
obj
.
input_ids
:
input_ids
=
obj
.
input_ids
[
index
]
if
obj
.
input_ids
is
None
:
input_ids
=
self
.
tokenizer
.
encode
(
input_text
)
else
:
input_ids
=
obj
.
input_ids
if
not_use_index
else
obj
.
input_ids
[
index
]
self
.
_validate_input_length
(
input_ids
)
sampling_params
=
self
.
_get_sampling_params
(
obj
.
sampling_params
if
not_use_index
else
obj
.
sampling_params
[
index
]
)
pixel_values
,
image_hash
,
image_size
=
await
self
.
_get_pixel_values
(
obj
.
image_data
if
not_use_index
else
obj
.
image_data
[
index
]
)
return_logprob
=
(
obj
.
return_logprob
if
not_use_index
else
obj
.
return_logprob
[
index
]
)
logprob_start_len
=
(
obj
.
logprob_start_len
if
not_use_index
else
obj
.
logprob_start_len
[
index
]
)
top_logprobs_num
=
(
obj
.
top_logprobs_num
if
not_use_index
else
obj
.
top_logprobs_num
[
index
]
)
if
self
.
is_generation
:
pixel_values
,
image_hash
,
image_size
=
await
self
.
_get_pixel_values
(
obj
.
image_data
if
not_use_index
else
obj
.
image_data
[
index
]
)
return_logprob
=
(
obj
.
return_logprob
if
not_use_index
else
obj
.
return_logprob
[
index
]
)
logprob_start_len
=
(
obj
.
logprob_start_len
if
not_use_index
else
obj
.
logprob_start_len
[
index
]
)
top_logprobs_num
=
(
obj
.
top_logprobs_num
if
not_use_index
else
obj
.
top_logprobs_num
[
index
]
)
else
:
# A prefill request to cache the common prompt for parallel sampling
assert
self
.
is_generation
if
obj
.
text
is
not
None
:
if
isinstance
(
obj
.
text
,
list
):
input_text
=
obj
.
text
[
index
]
...
...
@@ -213,19 +229,28 @@ class TokenizerManager:
logprob_start_len
=
obj
.
logprob_start_len
[
0
]
top_logprobs_num
=
obj
.
top_logprobs_num
[
0
]
tokenized_obj
=
TokenizedGenerateReqInput
(
rid
,
input_text
,
input_ids
,
pixel_values
,
image_hash
,
image_size
,
sampling_params
,
return_logprob
,
logprob_start_len
,
top_logprobs_num
,
obj
.
stream
,
)
if
self
.
is_generation
:
tokenized_obj
=
TokenizedGenerateReqInput
(
rid
,
input_text
,
input_ids
,
pixel_values
,
image_hash
,
image_size
,
sampling_params
,
return_logprob
,
logprob_start_len
,
top_logprobs_num
,
obj
.
stream
,
)
else
:
# is embedding
tokenized_obj
=
TokenizedEmbeddingReqInput
(
rid
,
input_text
,
input_ids
,
sampling_params
,
)
self
.
send_to_router
.
send_pyobj
(
tokenized_obj
)
event
=
asyncio
.
Event
()
...
...
@@ -368,7 +393,7 @@ class TokenizerManager:
self
,
event
:
asyncio
.
Event
,
state
:
ReqState
,
obj
:
GenerateReqInput
,
obj
:
Union
[
GenerateReqInput
,
EmbeddingReqInput
],
rid
:
str
,
request
,
):
...
...
@@ -381,12 +406,15 @@ class TokenizerManager:
raise
ValueError
(
f
"Abort request
{
rid
}
"
)
continue
out
=
self
.
convert_logprob_style
(
state
.
out_list
[
-
1
],
obj
.
return_logprob
,
obj
.
top_logprobs_num
,
obj
.
return_text_in_logprobs
,
)
if
self
.
is_generation
:
out
=
self
.
convert_logprob_style
(
state
.
out_list
[
-
1
],
obj
.
return_logprob
,
obj
.
top_logprobs_num
,
obj
.
return_text_in_logprobs
,
)
else
:
# isinstance(obj, EmbeddingReqInput)
out
=
state
.
out_list
[
-
1
]
# Log requests
if
self
.
server_args
.
log_requests
and
state
.
finished
:
...
...
@@ -459,8 +487,10 @@ class TokenizerManager:
async
def
handle_loop
(
self
):
while
True
:
recv_obj
:
BatchStrOut
=
await
self
.
recv_from_detokenizer
.
recv_pyobj
()
assert
isinstance
(
recv_obj
,
BatchStrOut
)
recv_obj
:
Union
[
BatchStrOut
,
BatchEmbeddingOut
]
=
(
await
self
.
recv_from_detokenizer
.
recv_pyobj
()
)
assert
isinstance
(
recv_obj
,
(
BatchStrOut
,
BatchEmbeddingOut
))
for
i
,
rid
in
enumerate
(
recv_obj
.
rids
):
state
=
self
.
rid_to_state
.
get
(
rid
,
None
)
...
...
@@ -468,10 +498,17 @@ class TokenizerManager:
continue
recv_obj
.
meta_info
[
i
][
"id"
]
=
rid
out_dict
=
{
"text"
:
recv_obj
.
output_strs
[
i
],
"meta_info"
:
recv_obj
.
meta_info
[
i
],
}
if
isinstance
(
recv_obj
,
BatchStrOut
):
out_dict
=
{
"text"
:
recv_obj
.
output_strs
[
i
],
"meta_info"
:
recv_obj
.
meta_info
[
i
],
}
else
:
assert
isinstance
(
recv_obj
,
BatchEmbeddingOut
)
out_dict
=
{
"embedding"
:
recv_obj
.
embeddings
[
i
],
"meta_info"
:
recv_obj
.
meta_info
[
i
],
}
state
.
out_list
.
append
(
out_dict
)
state
.
finished
=
recv_obj
.
finished_reason
[
i
]
is
not
None
state
.
event
.
set
()
...
...
python/sglang/srt/managers/tp_worker.py
View file @
e040a245
...
...
@@ -20,7 +20,7 @@ import multiprocessing
import
pickle
import
time
import
warnings
from
typing
import
List
,
Optional
from
typing
import
List
,
Optional
,
Union
import
torch
import
torch.distributed
as
dist
...
...
@@ -31,8 +31,10 @@ from sglang.srt.constrained.jump_forward import JumpForwardCache
from
sglang.srt.hf_transformers_utils
import
get_processor
,
get_tokenizer
from
sglang.srt.managers.io_struct
import
(
AbortReq
,
BatchEmbeddingOut
,
BatchTokenIDOut
,
FlushCacheReq
,
TokenizedEmbeddingReqInput
,
TokenizedGenerateReqInput
,
)
from
sglang.srt.managers.policy_scheduler
import
PolicyScheduler
,
PrefillAdder
...
...
@@ -205,7 +207,9 @@ class ModelTpServer:
try
:
# Recv requests
for
recv_req
in
recv_reqs
:
if
isinstance
(
recv_req
,
TokenizedGenerateReqInput
):
if
isinstance
(
recv_req
,
(
TokenizedGenerateReqInput
,
TokenizedEmbeddingReqInput
)
):
self
.
handle_generate_request
(
recv_req
)
elif
isinstance
(
recv_req
,
FlushCacheReq
):
self
.
flush_cache
()
...
...
@@ -297,41 +301,42 @@ class ModelTpServer:
def
handle_generate_request
(
self
,
recv_req
:
TokenizedGenerateReqInput
,
recv_req
:
Union
[
TokenizedGenerateReqInput
,
TokenizedEmbeddingReqInput
],
):
req
=
Req
(
recv_req
.
rid
,
recv_req
.
input_text
,
recv_req
.
input_ids
)
req
.
pixel_values
=
recv_req
.
pixel_values
if
req
.
pixel_values
is
not
None
:
req
.
pad_value
=
[
(
recv_req
.
image_hash
)
%
self
.
model_config
.
vocab_size
,
(
recv_req
.
image_hash
>>
16
)
%
self
.
model_config
.
vocab_size
,
(
recv_req
.
image_hash
>>
32
)
%
self
.
model_config
.
vocab_size
,
(
recv_req
.
image_hash
>>
64
)
%
self
.
model_config
.
vocab_size
,
]
req
.
image_size
=
recv_req
.
image_size
(
req
.
origin_input_ids
,
req
.
image_offset
,
)
=
self
.
model_runner
.
model
.
pad_input_ids
(
req
.
origin_input_ids_unpadded
,
req
.
pad_value
,
req
.
pixel_values
.
shape
,
req
.
image_size
,
)
req
.
sampling_params
=
recv_req
.
sampling_params
req
.
return_logprob
=
recv_req
.
return_logprob
req
.
logprob_start_len
=
recv_req
.
logprob_start_len
req
.
top_logprobs_num
=
recv_req
.
top_logprobs_num
req
.
stream
=
recv_req
.
stream
req
.
tokenizer
=
self
.
tokenizer
# Init regex fsm
if
req
.
sampling_params
.
regex
is
not
None
:
req
.
regex_fsm
=
self
.
regex_fsm_cache
.
query
(
req
.
sampling_params
.
regex
)
if
not
self
.
disable_regex_jump_forward
:
req
.
jump_forward_map
=
self
.
jump_forward_cache
.
query
(
req
.
sampling_params
.
regex
req
.
sampling_params
=
recv_req
.
sampling_params
if
self
.
model_runner
.
is_generation
:
req
.
pixel_values
=
recv_req
.
pixel_values
if
req
.
pixel_values
is
not
None
:
req
.
pad_value
=
[
(
recv_req
.
image_hash
)
%
self
.
model_config
.
vocab_size
,
(
recv_req
.
image_hash
>>
16
)
%
self
.
model_config
.
vocab_size
,
(
recv_req
.
image_hash
>>
32
)
%
self
.
model_config
.
vocab_size
,
(
recv_req
.
image_hash
>>
64
)
%
self
.
model_config
.
vocab_size
,
]
req
.
image_size
=
recv_req
.
image_size
(
req
.
origin_input_ids
,
req
.
image_offset
,
)
=
self
.
model_runner
.
model
.
pad_input_ids
(
req
.
origin_input_ids_unpadded
,
req
.
pad_value
,
req
.
pixel_values
.
shape
,
req
.
image_size
,
)
req
.
return_logprob
=
recv_req
.
return_logprob
req
.
logprob_start_len
=
recv_req
.
logprob_start_len
req
.
top_logprobs_num
=
recv_req
.
top_logprobs_num
req
.
stream
=
recv_req
.
stream
# Init regex fsm
if
req
.
sampling_params
.
regex
is
not
None
:
req
.
regex_fsm
=
self
.
regex_fsm_cache
.
query
(
req
.
sampling_params
.
regex
)
if
not
self
.
disable_regex_jump_forward
:
req
.
jump_forward_map
=
self
.
jump_forward_cache
.
query
(
req
.
sampling_params
.
regex
)
# Truncate prompts that are too long
if
len
(
req
.
origin_input_ids
)
>=
self
.
max_req_input_len
:
...
...
@@ -340,14 +345,17 @@ class ModelTpServer:
"the max context length. Truncated!!!"
)
req
.
origin_input_ids
=
req
.
origin_input_ids
[:
self
.
max_req_input_len
]
req
.
sampling_params
.
max_new_tokens
=
min
(
(
req
.
sampling_params
.
max_new_tokens
if
req
.
sampling_params
.
max_new_tokens
is
not
None
else
1
<<
30
),
self
.
max_req_input_len
-
1
-
len
(
req
.
origin_input_ids
),
)
if
self
.
model_runner
.
is_generation
:
req
.
sampling_params
.
max_new_tokens
=
min
(
(
req
.
sampling_params
.
max_new_tokens
if
req
.
sampling_params
.
max_new_tokens
is
not
None
else
1
<<
30
),
self
.
max_req_input_len
-
1
-
len
(
req
.
origin_input_ids
),
)
self
.
waiting_queue
.
append
(
req
)
def
get_new_prefill_batch
(
self
)
->
Optional
[
ScheduleBatch
]:
...
...
@@ -439,47 +447,68 @@ class ModelTpServer:
self
.
model_config
.
vocab_size
,
self
.
int_token_logit_bias
)
# Forward and sample the next tokens
if
batch
.
extend_num_tokens
!=
0
:
output
=
self
.
model_runner
.
forward
(
batch
,
ForwardMode
.
EXTEND
)
next_token_ids
=
batch
.
sample
(
output
.
next_token_logits
)
# Move logprobs to cpu
if
output
.
next_token_logprobs
is
not
None
:
output
.
next_token_logprobs
=
output
.
next_token_logprobs
[
torch
.
arange
(
len
(
next_token_ids
),
device
=
next_token_ids
.
device
),
next_token_ids
,
].
tolist
()
output
.
input_token_logprobs
=
output
.
input_token_logprobs
.
tolist
()
output
.
normalized_prompt_logprobs
=
(
output
.
normalized_prompt_logprobs
.
tolist
()
)
if
self
.
model_runner
.
is_generation
:
# Forward and sample the next tokens
if
batch
.
extend_num_tokens
!=
0
:
output
=
self
.
model_runner
.
forward
(
batch
,
ForwardMode
.
EXTEND
)
next_token_ids
=
batch
.
sample
(
output
.
next_token_logits
)
# Move logprobs to cpu
if
output
.
next_token_logprobs
is
not
None
:
output
.
next_token_logprobs
=
output
.
next_token_logprobs
[
torch
.
arange
(
len
(
next_token_ids
),
device
=
next_token_ids
.
device
),
next_token_ids
,
].
tolist
()
output
.
input_token_logprobs
=
output
.
input_token_logprobs
.
tolist
()
output
.
normalized_prompt_logprobs
=
(
output
.
normalized_prompt_logprobs
.
tolist
()
)
next_token_ids
=
next_token_ids
.
tolist
()
else
:
next_token_ids
=
[
self
.
tokenizer
.
eos_token_id
]
*
len
(
batch
.
reqs
)
next_token_ids
=
next_token_ids
.
tolist
()
else
:
next_token_ids
=
[
self
.
tokenizer
.
eos_token_id
]
*
len
(
batch
.
reqs
)
# Check finish conditions
pt
=
0
for
i
,
req
in
enumerate
(
batch
.
reqs
):
if
req
is
not
self
.
current_inflight_req
:
# Inflight reqs' prefill is not finished
req
.
completion_tokens_wo_jump_forward
+=
1
req
.
output_ids
.
append
(
next_token_ids
[
i
])
req
.
check_finished
()
if
req
.
finished
():
self
.
tree_cache
.
cache_finished_req
(
req
)
else
:
self
.
tree_cache
.
cache_unfinished_req
(
req
)
# Check finish conditions
pt
=
0
for
i
,
req
in
enumerate
(
batch
.
reqs
):
if
req
is
not
self
.
current_inflight_req
:
# Inflight reqs' prefill is not finished
req
.
completion_tokens_wo_jump_forward
+=
1
req
.
output_ids
.
append
(
next_token_ids
[
i
])
req
.
check_finished
()
if
req
is
self
.
current_inflight_req
:
# Inflight request would get a new req idx
self
.
req_to_token_pool
.
free
(
req
.
req_pool_idx
)
if
req
.
finished
():
self
.
tree_cache
.
cache_finished_req
(
req
)
else
:
self
.
tree_cache
.
cache_unfinished_req
(
req
)
if
req
.
return_logprob
:
self
.
add_logprob_return_values
(
i
,
req
,
pt
,
next_token_ids
,
output
)
pt
+=
req
.
extend_input_len
else
:
assert
batch
.
extend_num_tokens
!=
0
output
=
self
.
model_runner
.
forward
(
batch
,
ForwardMode
.
EXTEND
)
embeddings
=
output
.
embeddings
.
tolist
()
if
req
is
self
.
current_inflight_req
:
# Inflight request would get a new req idx
self
.
req_to_token_pool
.
free
(
req
.
req_pool_idx
)
# Check finish conditions
for
i
,
req
in
enumerate
(
batch
.
reqs
):
req
.
embedding
=
embeddings
[
i
]
if
req
is
not
self
.
current_inflight_req
:
# Inflight reqs' prefill is not finished
req
.
check_finished
()
if
req
.
return_logprob
:
self
.
add_logprob_return_values
(
i
,
req
,
pt
,
next_token_ids
,
output
)
pt
+=
req
.
extend_input_len
if
req
.
finished
():
self
.
tree_cache
.
cache_finished_req
(
req
)
else
:
self
.
tree_cache
.
cache_unfinished_req
(
req
)
if
req
is
self
.
current_inflight_req
:
# Inflight request would get a new req idx
self
.
req_to_token_pool
.
free
(
req
.
req_pool_idx
)
self
.
handle_finished_requests
(
batch
)
...
...
@@ -596,15 +625,19 @@ class ModelTpServer:
def
handle_finished_requests
(
self
,
batch
:
ScheduleBatch
):
output_rids
=
[]
output_vids
=
[]
decoded_texts
=
[]
output_read_ids
=
[]
output_read_offsets
=
[]
output_skip_special_tokens
=
[]
output_spaces_between_special_tokens
=
[]
output_meta_info
=
[]
output_finished_reason
:
List
[
BaseFinishReason
]
=
[]
if
self
.
model_runner
.
is_generation
:
output_vids
=
[]
decoded_texts
=
[]
output_read_ids
=
[]
output_read_offsets
=
[]
output_skip_special_tokens
=
[]
output_spaces_between_special_tokens
=
[]
else
:
# for embedding model
output_embeddings
=
[]
unfinished_indices
=
[]
for
i
,
req
in
enumerate
(
batch
.
reqs
):
if
not
req
.
finished
()
and
req
is
not
self
.
current_inflight_req
:
unfinished_indices
.
append
(
i
)
...
...
@@ -619,56 +652,73 @@ class ModelTpServer:
)
):
output_rids
.
append
(
req
.
rid
)
output_vids
.
append
(
req
.
vid
)
decoded_texts
.
append
(
req
.
decoded_text
)
read_ids
,
read_offset
=
req
.
init_incremental_detokenize
()
output_read_ids
.
append
(
read_ids
)
output_read_offsets
.
append
(
read_offset
)
output_skip_special_tokens
.
append
(
req
.
sampling_params
.
skip_special_tokens
)
output_spaces_between_special_tokens
.
append
(
req
.
sampling_params
.
spaces_between_special_tokens
)
meta_info
=
{
"prompt_tokens"
:
len
(
req
.
origin_input_ids
),
"completion_tokens"
:
len
(
req
.
output_ids
),
"completion_tokens_wo_jump_forward"
:
req
.
completion_tokens_wo_jump_forward
,
"finish_reason"
:
str
(
req
.
finished_reason
),
}
if
req
.
return_logprob
:
(
meta_info
[
"input_token_logprobs"
],
meta_info
[
"output_token_logprobs"
],
meta_info
[
"input_top_logprobs"
],
meta_info
[
"output_top_logprobs"
],
meta_info
[
"normalized_prompt_logprob"
],
)
=
(
req
.
input_token_logprobs
,
req
.
output_token_logprobs
,
req
.
input_top_logprobs
,
req
.
output_top_logprobs
,
req
.
normalized_prompt_logprob
,
)
output_meta_info
.
append
(
meta_info
)
output_finished_reason
.
append
(
req
.
finished_reason
)
if
self
.
model_runner
.
is_generation
:
output_vids
.
append
(
req
.
vid
)
decoded_texts
.
append
(
req
.
decoded_text
)
read_ids
,
read_offset
=
req
.
init_incremental_detokenize
()
output_read_ids
.
append
(
read_ids
)
output_read_offsets
.
append
(
read_offset
)
output_skip_special_tokens
.
append
(
req
.
sampling_params
.
skip_special_tokens
)
output_spaces_between_special_tokens
.
append
(
req
.
sampling_params
.
spaces_between_special_tokens
)
meta_info
=
{
"prompt_tokens"
:
len
(
req
.
origin_input_ids
),
"completion_tokens"
:
len
(
req
.
output_ids
),
"completion_tokens_wo_jump_forward"
:
req
.
completion_tokens_wo_jump_forward
,
"finish_reason"
:
str
(
req
.
finished_reason
),
}
if
req
.
return_logprob
:
(
meta_info
[
"input_token_logprobs"
],
meta_info
[
"output_token_logprobs"
],
meta_info
[
"input_top_logprobs"
],
meta_info
[
"output_top_logprobs"
],
meta_info
[
"normalized_prompt_logprob"
],
)
=
(
req
.
input_token_logprobs
,
req
.
output_token_logprobs
,
req
.
input_top_logprobs
,
req
.
output_top_logprobs
,
req
.
normalized_prompt_logprob
,
)
output_meta_info
.
append
(
meta_info
)
else
:
# for embedding model
output_embeddings
.
append
(
req
.
embedding
)
meta_info
=
{
"prompt_tokens"
:
len
(
req
.
origin_input_ids
),
}
output_meta_info
.
append
(
meta_info
)
# Send to detokenizer
if
output_rids
:
self
.
out_pyobjs
.
append
(
BatchTokenIDOut
(
output_rids
,
output_vids
,
decoded_texts
,
output_read_ids
,
output_read_offsets
,
output_skip_special_tokens
,
output_spaces_between_special_tokens
,
output_meta_info
,
output_finished_reason
,
if
self
.
model_runner
.
is_generation
:
self
.
out_pyobjs
.
append
(
BatchTokenIDOut
(
output_rids
,
output_vids
,
decoded_texts
,
output_read_ids
,
output_read_offsets
,
output_skip_special_tokens
,
output_spaces_between_special_tokens
,
output_meta_info
,
output_finished_reason
,
)
)
else
:
# for embedding model
self
.
out_pyobjs
.
append
(
BatchEmbeddingOut
(
output_rids
,
output_embeddings
,
output_meta_info
,
output_finished_reason
,
)
)
)
# Remove finished reqs: update batch tensors
batch
.
filter_batch
(
unfinished_indices
)
...
...
python/sglang/srt/model_executor/model_runner.py
View file @
e040a245
...
...
@@ -52,6 +52,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetad
from
sglang.srt.server_args
import
ServerArgs
from
sglang.srt.utils
import
(
get_available_gpu_memory
,
is_generation_model
,
is_llama3_405b_fp8
,
is_multimodal_model
,
monkey_patch_vllm_dummy_weight_loader
,
...
...
@@ -132,8 +133,10 @@ class ModelRunner:
self
.
init_cublas
()
self
.
init_flashinfer
()
# Capture cuda graphs
self
.
init_cuda_graphs
()
if
self
.
is_generation
:
# FIXME Currently, cuda graph only capture decode steps, which only exists in causal models
# Capture cuda graphs
self
.
init_cuda_graphs
()
def
load_model
(
self
):
logger
.
info
(
...
...
@@ -184,6 +187,10 @@ class ModelRunner:
scheduler_config
=
None
,
cache_config
=
None
,
)
self
.
is_generation
=
is_generation_model
(
self
.
model_config
.
hf_config
.
architectures
)
logger
.
info
(
f
"[gpu=
{
self
.
gpu_id
}
] Load weight end. "
f
"type=
{
type
(
self
.
model
).
__name__
}
, "
...
...
@@ -406,8 +413,10 @@ def import_model_classes():
entry
,
list
):
# To support multiple model classes in one module
for
tmp
in
entry
:
assert
tmp
.
__name__
not
in
model_arch_name_to_cls
model_arch_name_to_cls
[
tmp
.
__name__
]
=
tmp
else
:
assert
entry
.
__name__
not
in
model_arch_name_to_cls
model_arch_name_to_cls
[
entry
.
__name__
]
=
entry
# compat: some models such as chatglm has incorrect class set in config.json
...
...
@@ -417,6 +426,7 @@ def import_model_classes():
):
for
remap
in
module
.
EntryClassRemapping
:
if
isinstance
(
remap
,
tuple
)
and
len
(
remap
)
==
2
:
assert
remap
[
0
]
not
in
model_arch_name_to_cls
model_arch_name_to_cls
[
remap
[
0
]]
=
remap
[
1
]
return
model_arch_name_to_cls
...
...
python/sglang/srt/models/llama_embedding.py
View file @
e040a245
...
...
@@ -84,3 +84,5 @@ class LlamaEmbeddingModel(nn.Module):
EntryClass
=
LlamaEmbeddingModel
# compat: e5-mistral model.config class == MistralModel
EntryClassRemapping
=
[(
"MistralModel"
,
LlamaEmbeddingModel
)]
python/sglang/srt/server.py
View file @
e040a245
...
...
@@ -52,7 +52,7 @@ from sglang.srt.managers.controller_single import (
start_controller_process
as
start_controller_process_single
,
)
from
sglang.srt.managers.detokenizer_manager
import
start_detokenizer_process
from
sglang.srt.managers.io_struct
import
GenerateReqInput
from
sglang.srt.managers.io_struct
import
EmbeddingReqInput
,
GenerateReqInput
from
sglang.srt.managers.tokenizer_manager
import
TokenizerManager
from
sglang.srt.openai_api.adapter
import
(
load_chat_template_for_openai_api
,
...
...
@@ -97,6 +97,7 @@ async def health() -> Response:
async
def
get_model_info
():
result
=
{
"model_path"
:
tokenizer_manager
.
model_path
,
"is_generation"
:
tokenizer_manager
.
is_generation
,
}
return
result
...
...
@@ -148,6 +149,21 @@ app.post("/generate")(generate_request)
app
.
put
(
"/generate"
)(
generate_request
)
async
def
encode_request
(
obj
:
EmbeddingReqInput
,
request
:
Request
):
"""Handle an embedding request."""
try
:
ret
=
await
tokenizer_manager
.
generate_request
(
obj
,
request
).
__anext__
()
return
ret
except
ValueError
as
e
:
return
JSONResponse
(
{
"error"
:
{
"message"
:
str
(
e
)}},
status_code
=
HTTPStatus
.
BAD_REQUEST
)
app
.
post
(
"/encode"
)(
encode_request
)
app
.
put
(
"/encode"
)(
encode_request
)
@
app
.
post
(
"/v1/completions"
)
async
def
openai_v1_completions
(
raw_request
:
Request
):
return
await
v1_completions
(
tokenizer_manager
,
raw_request
)
...
...
@@ -380,6 +396,7 @@ def _wait_and_warmup(server_args, pipe_finish_writer):
except
(
AssertionError
,
requests
.
exceptions
.
RequestException
)
as
e
:
last_traceback
=
get_exception_traceback
()
pass
model_info
=
res
.
json
()
if
not
success
:
if
pipe_finish_writer
is
not
None
:
...
...
@@ -388,15 +405,17 @@ def _wait_and_warmup(server_args, pipe_finish_writer):
sys
.
exit
(
1
)
# Send a warmup request
request_name
=
"/generate"
if
model_info
[
"is_generation"
]
else
"/encode"
max_new_tokens
=
8
if
model_info
[
"is_generation"
]
else
0
try
:
for
_
in
range
(
server_args
.
dp_size
):
res
=
requests
.
post
(
url
+
"/generate"
,
url
+
request_name
,
json
=
{
"text"
:
"The capital city of France is"
,
"sampling_params"
:
{
"temperature"
:
0
,
"max_new_tokens"
:
8
,
"max_new_tokens"
:
max_new_tokens
,
},
},
headers
=
headers
,
...
...
@@ -529,5 +548,18 @@ class Runtime:
)
return
json
.
dumps
(
response
.
json
())
def
encode
(
self
,
prompt
:
str
,
):
json_data
=
{
"text"
:
prompt
,
}
response
=
requests
.
post
(
self
.
url
+
"/encode"
,
json
=
json_data
,
)
return
json
.
dumps
(
response
.
json
())
def
__del__
(
self
):
self
.
shutdown
()
python/sglang/srt/utils.py
View file @
e040a245
...
...
@@ -223,6 +223,15 @@ def is_multimodal_model(model):
raise
ValueError
(
"unrecognized type"
)
def
is_generation_model
(
model_architectures
):
if
(
"LlamaEmbeddingModel"
in
model_architectures
or
"MistralModel"
in
model_architectures
):
return
False
return
True
def
decode_video_base64
(
video_base64
):
from
PIL
import
Image
...
...
python/sglang/test/runners.py
View file @
e040a245
...
...
@@ -23,6 +23,7 @@ import torch.nn.functional as F
from
transformers
import
AutoModelForCausalLM
,
AutoTokenizer
from
sglang.srt.server
import
Runtime
from
sglang.srt.utils
import
is_generation_model
DEFAULT_PROMPTS
=
[
"The capital of France is"
,
...
...
@@ -33,13 +34,6 @@ DEFAULT_PROMPTS = [
NUM_TOP_LOGPROBS
=
5
def
is_embedding_model
(
model_path
):
# FIXME incomplete list
if
"e5-mistral-7b-instruct"
in
model_path
.
lower
():
return
True
return
False
def
get_dtype_str
(
torch_dtype
):
if
torch_dtype
is
torch
.
float16
:
return
"float16"
...
...
@@ -60,7 +54,7 @@ class HFRunner:
self
,
model_path
,
torch_dtype
=
torch
.
float16
,
is_
embedding
_model
=
None
,
is_
generation
_model
=
None
,
):
self
.
in_queue
=
multiprocessing
.
Queue
()
self
.
out_queue
=
multiprocessing
.
Queue
()
...
...
@@ -72,13 +66,13 @@ class HFRunner:
self
.
out_queue
,
model_path
,
torch_dtype
,
is_
embedding
_model
,
is_
generation
_model
,
),
)
self
.
model_proc
.
start
()
def
start_model_process
(
self
,
in_queue
,
out_queue
,
model_path
,
torch_dtype
,
is_
embedding
_model
self
,
in_queue
,
out_queue
,
model_path
,
torch_dtype
,
is_
generation
_model
):
self
.
tokenizer
=
AutoTokenizer
.
from_pretrained
(
model_path
,
...
...
@@ -86,12 +80,12 @@ class HFRunner:
trust_remote_code
=
True
,
)
self
.
is_
embedding
_model
=
(
is_
embedding
_model
(
model_path
)
if
is_
embedding
_model
is
None
else
is_
embedding
_model
self
.
is_
generation
_model
=
(
is_
generation
_model
(
model_path
)
if
is_
generation
_model
is
None
else
is_
generation
_model
)
if
not
self
.
is_
embedding
_model
:
if
self
.
is_
generation
_model
:
self
.
model
=
AutoModelForCausalLM
.
from_pretrained
(
model_path
,
torch_dtype
=
torch_dtype
,
...
...
@@ -103,13 +97,13 @@ class HFRunner:
self
.
model
=
SentenceTransformer
(
model_path
,
device
=
"cpu"
,
)
.
to
(
dtype
=
torch_dtype
)
model_kwargs
=
{
"torch_dtype"
:
torch_dtype
}
,
)
while
True
:
prompts
,
max_new_tokens
=
in_queue
.
get
()
if
prompts
is
not
None
:
if
not
self
.
is_
embedding
_model
:
if
self
.
is_
generation
_model
:
output_strs
=
[]
prefill_logprobs
=
[]
for
p
in
prompts
:
...
...
@@ -144,7 +138,6 @@ class HFRunner:
)
else
:
assert
isinstance
(
prompts
,
List
[
str
])
logits
=
self
.
model
.
encode
(
prompts
).
tolist
()
out_queue
.
put
(
ModelOutput
(
embed_logits
=
logits
))
...
...
@@ -175,16 +168,13 @@ class SRTRunner:
model_path
,
tp_size
=
1
,
torch_dtype
=
torch
.
float16
,
is_
embedding
_model
=
None
,
is_
generation
_model
=
None
,
):
self
.
is_
embedding
_model
=
(
is_
embedding
_model
(
model_path
)
if
is_
embedding
_model
is
None
else
is_
embedding
_model
self
.
is_
generation
_model
=
(
is_
generation
_model
(
model_path
)
if
is_
generation
_model
is
None
else
is_
generation
_model
)
if
self
.
is_embedding_model
:
raise
NotImplementedError
()
self
.
runtime
=
Runtime
(
model_path
=
model_path
,
tp_size
=
tp_size
,
...
...
@@ -196,38 +186,45 @@ class SRTRunner:
prompts
:
Union
[
List
[
str
],
List
[
torch
.
Tensor
]]
=
DEFAULT_PROMPTS
,
max_new_tokens
=
64
,
):
# the return value contains logprobs from prefill
output_strs
=
[]
top_input_logprobs
=
[]
sampling_params
=
{
"max_new_tokens"
:
max_new_tokens
,
"temperature"
:
0
}
for
prompt
in
prompts
:
response
=
self
.
runtime
.
generate
(
prompt
,
sampling_params
=
sampling_params
,
return_logprob
=
True
,
top_logprobs_num
=
NUM_TOP_LOGPROBS
,
)
response
=
json
.
loads
(
response
)
output_strs
.
append
(
response
[
"text"
])
top_input_logprobs
.
append
(
[
[
tup
[
0
]
for
tup
in
x
[:
NUM_TOP_LOGPROBS
]]
for
x
in
response
[
"meta_info"
][
"input_top_logprobs"
][
1
:]
]
+
[
if
self
.
is_generation_model
:
# the return value contains logprobs from prefill
output_strs
=
[]
top_input_logprobs
=
[]
sampling_params
=
{
"max_new_tokens"
:
max_new_tokens
,
"temperature"
:
0
}
for
prompt
in
prompts
:
response
=
self
.
runtime
.
generate
(
prompt
,
sampling_params
=
sampling_params
,
return_logprob
=
True
,
top_logprobs_num
=
NUM_TOP_LOGPROBS
,
)
response
=
json
.
loads
(
response
)
output_strs
.
append
(
response
[
"text"
])
top_input_logprobs
.
append
(
[
tup
[
0
]
for
tup
in
response
[
"meta_info"
][
"output_top_logprobs"
][
0
][
:
NUM_TOP_LOGPROBS
[
tup
[
0
]
for
tup
in
x
[:
NUM_TOP_LOGPROBS
]]
for
x
in
response
[
"meta_info"
][
"input_top_logprobs"
][
1
:]
]
+
[
[
tup
[
0
]
for
tup
in
response
[
"meta_info"
][
"output_top_logprobs"
][
0
][
:
NUM_TOP_LOGPROBS
]
]
]
]
)
# print(response["meta_info"]["output_top_logprobs"][0])
)
return
ModelOutput
(
output_strs
=
output_strs
,
top_input_logprobs
=
top_input_logprobs
)
return
ModelOutput
(
output_strs
=
output_strs
,
top_input_logprobs
=
top_input_logprobs
)
else
:
logits
=
[]
for
prompt
in
prompts
:
response
=
self
.
runtime
.
encode
(
prompt
)
response
=
json
.
loads
(
response
)
logits
.
append
(
response
[
"embedding"
])
return
ModelOutput
(
embed_logits
=
logits
)
def
__enter__
(
self
):
return
self
...
...
python/sglang/test/test_utils.py
View file @
e040a245
...
...
@@ -12,6 +12,8 @@ from typing import Callable, List, Optional
import
numpy
as
np
import
requests
import
torch
import
torch.nn.functional
as
F
from
sglang.global_config
import
global_config
from
sglang.lang.backend.openai
import
OpenAI
...
...
@@ -492,3 +494,7 @@ def run_unittest_files(files: List[str], timeout_per_file: float):
print
(
f
"Fail. Time elapsed:
{
time
.
time
()
-
tic
:.
2
f
}
s"
)
return
0
if
success
else
-
1
def get_similarities(vec1, vec2):
    """Return the cosine similarity between two 1-D vectors.

    Args:
        vec1: First vector — any sequence of numbers or a ``torch.Tensor``.
        vec2: Second vector of the same length.

    Returns:
        A 0-dim ``torch.Tensor`` holding the cosine similarity in [-1, 1].
    """
    # as_tensor avoids the copy (and the "copy construct" UserWarning) that
    # torch.tensor emits when the argument is already a Tensor — callers in
    # test_embedding_models.py pass tensors here.
    return F.cosine_similarity(torch.as_tensor(vec1), torch.as_tensor(vec2), dim=0)
test/srt/models/test_embedding_models.py
0 → 100644
View file @
e040a245
"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import
unittest
import
torch
from
sglang.test.runners
import
DEFAULT_PROMPTS
,
HFRunner
,
SRTRunner
from
sglang.test.test_utils
import
get_similarities
# (model_path, tp_size) pairs exercised by the embedding tests below.
MODELS = [("intfloat/e5-mistral-7b-instruct", 1)]
# Dtypes to run each model under.
TORCH_DTYPES = [torch.float16]
class TestEmbeddingModels(unittest.TestCase):
    """Compare SRT embedding outputs against HuggingFace reference outputs."""

    def assert_close_prefill_logits(
        self,
        prompts,
        model_path,
        tp_size,
        torch_dtype,
    ) -> None:
        """Assert SRT and HF embeddings agree (cosine similarity close to 1).

        Args:
            prompts: List of prompt strings to embed.
            model_path: HF hub path of the embedding model under test.
            tp_size: Tensor-parallel size for the SRT runner.
            torch_dtype: Dtype to load the model in.
        """
        with HFRunner(
            model_path, torch_dtype=torch_dtype, is_generation_model=False
        ) as hf_runner:
            hf_outputs = hf_runner.forward(prompts)

        with SRTRunner(
            model_path,
            tp_size=tp_size,
            torch_dtype=torch_dtype,
            is_generation_model=False,
        ) as srt_runner:
            srt_outputs = srt_runner.forward(prompts)

        for i in range(len(prompts)):
            hf_logits = torch.Tensor(hf_outputs.embed_logits[i])
            srt_logits = torch.Tensor(srt_outputs.embed_logits[i])

            # get_similarities already returns a tensor; wrapping it in
            # torch.tensor again would copy it and emit a UserWarning.
            similarity = get_similarities(hf_logits, srt_logits)

            tolerance = 1e-2
            # Embeddings match when the cosine similarity is within
            # `tolerance` of 1; assertTrue gives a proper unittest failure
            # message (the original used an f-string with no placeholders).
            self.assertTrue(
                torch.all(torch.abs(similarity - 1) < tolerance),
                "embeddings not all close",
            )

    def test_prefill_logits(self):
        # Run the comparison for every (model, tp_size) and dtype combination.
        for model, tp_size in MODELS:
            for torch_dtype in TORCH_DTYPES:
                self.assert_close_prefill_logits(
                    DEFAULT_PROMPTS, model, tp_size, torch_dtype
                )
# Allow running this test file directly; warnings="ignore" suppresses
# warning output (e.g. dtype/deprecation noise from the models under test).
if __name__ == "__main__":
    unittest.main(warnings="ignore")
test/srt/models/test_
causal
_models.py
→
test/srt/models/test_
generation
_models.py
View file @
e040a245
...
...
@@ -3,7 +3,9 @@ Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...
...
@@ -33,7 +35,7 @@ class TestCausalModels(unittest.TestCase):
torch_dtype
,
)
->
None
:
with
HFRunner
(
model_path
,
torch_dtype
=
torch_dtype
,
is_
embedding
_model
=
Fals
e
model_path
,
torch_dtype
=
torch_dtype
,
is_
generation
_model
=
Tru
e
)
as
hf_runner
:
hf_outputs
=
hf_runner
.
forward
(
prompts
)
...
...
@@ -41,7 +43,7 @@ class TestCausalModels(unittest.TestCase):
model_path
,
tp_size
=
tp_size
,
torch_dtype
=
torch_dtype
,
is_
embedding
_model
=
Fals
e
,
is_
generation
_model
=
Tru
e
,
)
as
srt_runner
:
srt_outputs
=
srt_runner
.
forward
(
prompts
)
...
...
test/srt/run_suite.py
View file @
e040a245
...
...
@@ -10,7 +10,8 @@ suites = {
"test_vision_openai_server.py"
,
"test_chunked_prefill.py"
,
"test_torch_compile.py"
,
"models/test_causal_models.py"
,
"models/test_generation_models.py"
,
"models/test_embedding_models.py"
,
"sampling/penaltylib"
,
],
"sampling/penaltylib"
:
glob
.
glob
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment