change / sglang · Commits · e040a245

Commit e040a245 (unverified), authored Aug 08, 2024 by Ying Sheng and committed via GitHub on Aug 08, 2024.
Parent: 9f662501

    Add e5-mistral embedding model - step 3/3 (#988)

Showing 14 changed files with 474 additions and 241 deletions (+474 −241).
Changed files:

  .github/workflows/unit-test.yml (+1 −0)
  python/sglang/srt/managers/detokenizer_manager.py (+17 −1)
  python/sglang/srt/managers/schedule_batch.py (+1 −0)
  python/sglang/srt/managers/tokenizer_manager.py (+86 −49)
  python/sglang/srt/managers/tp_worker.py (+178 −128)
  python/sglang/srt/model_executor/model_runner.py (+12 −2)
  python/sglang/srt/models/llama_embedding.py (+2 −0)
  python/sglang/srt/server.py (+35 −3)
  python/sglang/srt/utils.py (+9 −0)
  python/sglang/test/runners.py (+52 −55)
  python/sglang/test/test_utils.py (+6 −0)
  test/srt/models/test_embedding_models.py (+69 −0, new file)
  test/srt/models/test_generation_models.py (+4 −2, renamed from test/srt/models/test_causal_models.py)
  test/srt/run_suite.py (+2 −1)
.github/workflows/unit-test.yml

@@ -35,6 +35,7 @@ jobs:
           pip install -e "python[all]"
           pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.4/ --force-reinstall
           pip install accelerate
+          pip install sentence_transformers

       - name: Test Frontend Language
         run: |
python/sglang/srt/managers/detokenizer_manager.py

@@ -25,7 +25,11 @@ import zmq
 import zmq.asyncio

 from sglang.srt.hf_transformers_utils import get_tokenizer
-from sglang.srt.managers.io_struct import BatchStrOut, BatchTokenIDOut
+from sglang.srt.managers.io_struct import (
+    BatchEmbeddingOut,
+    BatchStrOut,
+    BatchTokenIDOut,
+)
 from sglang.srt.managers.schedule_batch import FINISH_MATCHED_STR
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.utils import find_printable_text, get_exception_traceback, graceful_registry

@@ -66,6 +70,18 @@ class DetokenizerManager:
     async def handle_loop(self):
         while True:
             recv_obj: BatchTokenIDOut = await self.recv_from_router.recv_pyobj()
+
+            if isinstance(recv_obj, BatchEmbeddingOut):
+                self.send_to_tokenizer.send_pyobj(
+                    BatchEmbeddingOut(
+                        rids=recv_obj.rids,
+                        embeddings=recv_obj.embeddings,
+                        meta_info=recv_obj.meta_info,
+                        finished_reason=recv_obj.finished_reason,
+                    )
+                )
+                continue
+
             assert isinstance(recv_obj, BatchTokenIDOut)
             bs = len(recv_obj.rids)
python/sglang/srt/managers/schedule_batch.py

@@ -143,6 +143,7 @@ class Req:
         # Logprobs
         self.return_logprob = False
+        self.embedding = None
         self.logprob_start_len = 0
         self.top_logprobs_num = 0
         self.normalized_prompt_logprob = None
python/sglang/srt/managers/tokenizer_manager.py

@@ -21,7 +21,7 @@ import dataclasses
 import logging
 import multiprocessing as mp
 import os
-from typing import Dict, List, Tuple
+from typing import Dict, List, Tuple, Union

 import numpy as np
 import transformers

@@ -38,16 +38,19 @@ from sglang.srt.hf_transformers_utils import (
 )
 from sglang.srt.managers.io_struct import (
     AbortReq,
+    BatchEmbeddingOut,
     BatchStrOut,
     BatchTokenIDOut,
+    EmbeddingReqInput,
     FlushCacheReq,
     GenerateReqInput,
+    TokenizedEmbeddingReqInput,
     TokenizedGenerateReqInput,
 )
 from sglang.srt.mm_utils import expand2square, process_anyres_image
 from sglang.srt.sampling_params import SamplingParams
 from sglang.srt.server_args import PortArgs, ServerArgs
-from sglang.srt.utils import is_multimodal_model, load_image
+from sglang.srt.utils import is_generation_model, is_multimodal_model, load_image
 from sglang.utils import get_exception_traceback

 asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())

@@ -85,6 +88,7 @@ class TokenizerManager:
             trust_remote_code=server_args.trust_remote_code,
             model_overide_args=model_overide_args,
         )
+        self.is_generation = is_generation_model(self.hf_config.architectures)

         if server_args.context_length is not None:
             self.context_len = server_args.context_length

@@ -133,7 +137,9 @@ class TokenizerManager:
             image_data, aspect_ratio, grid_pinpoints, self.processor
         )

-    async def generate_request(self, obj: GenerateReqInput, request=None):
+    async def generate_request(
+        self, obj: Union[GenerateReqInput, EmbeddingReqInput], request=None
+    ):
         if self.to_create_loop:
             self.create_handle_loop()

@@ -144,6 +150,8 @@ class TokenizerManager:
             async for response in self._handle_single_request(obj, request):
                 yield response
         else:
+            if isinstance(obj, EmbeddingReqInput):
+                raise NotImplementedError("Please send only one prompt in each request")
             if obj.stream:
                 raise ValueError("Do not support stream for batch mode.")

@@ -151,39 +159,47 @@ class TokenizerManager:
             yield response

     async def _handle_single_request(
-        self, obj, request, index=None, is_cache_for_prefill=False
+        self,
+        obj: Union[GenerateReqInput, EmbeddingReqInput],
+        request,
+        index=None,
+        is_cache_for_prefill=False,
     ):
         if not is_cache_for_prefill:  # The normal case with a single prompt
             not_use_index = index is None

             rid = obj.rid if not_use_index else obj.rid[index]
             input_text = obj.text if not_use_index else obj.text[index]
-            input_ids = (
-                self.tokenizer.encode(input_text)
-                if obj.input_ids is None
-                else obj.input_ids
-                if not_use_index
-                else obj.input_ids[index]
-            )
+            if obj.input_ids is None:
+                input_ids = self.tokenizer.encode(input_text)
+            else:
+                input_ids = obj.input_ids
+            if not not_use_index and obj.input_ids:
+                input_ids = obj.input_ids[index]

             self._validate_input_length(input_ids)

             sampling_params = self._get_sampling_params(
                 obj.sampling_params if not_use_index else obj.sampling_params[index]
             )

-            pixel_values, image_hash, image_size = await self._get_pixel_values(
-                obj.image_data if not_use_index else obj.image_data[index]
-            )
-            return_logprob = (
-                obj.return_logprob if not_use_index else obj.return_logprob[index]
-            )
-            logprob_start_len = (
-                obj.logprob_start_len
-                if not_use_index
-                else obj.logprob_start_len[index]
-            )
-            top_logprobs_num = (
-                obj.top_logprobs_num
-                if not_use_index
-                else obj.top_logprobs_num[index]
-            )
+            if self.is_generation:
+                pixel_values, image_hash, image_size = await self._get_pixel_values(
+                    obj.image_data if not_use_index else obj.image_data[index]
+                )
+                return_logprob = (
+                    obj.return_logprob if not_use_index else obj.return_logprob[index]
+                )
+                logprob_start_len = (
+                    obj.logprob_start_len
+                    if not_use_index
+                    else obj.logprob_start_len[index]
+                )
+                top_logprobs_num = (
+                    obj.top_logprobs_num
+                    if not_use_index
+                    else obj.top_logprobs_num[index]
+                )
         else:  # A prefill request to cache the common prompt for parallel sampling
+            assert self.is_generation
             if obj.text is not None:
                 if isinstance(obj.text, list):
                     input_text = obj.text[index]

@@ -213,19 +229,28 @@ class TokenizerManager:
             logprob_start_len = obj.logprob_start_len[0]
             top_logprobs_num = obj.top_logprobs_num[0]

-        tokenized_obj = TokenizedGenerateReqInput(
-            rid,
-            input_text,
-            input_ids,
-            pixel_values,
-            image_hash,
-            image_size,
-            sampling_params,
-            return_logprob,
-            logprob_start_len,
-            top_logprobs_num,
-            obj.stream,
-        )
+        if self.is_generation:
+            tokenized_obj = TokenizedGenerateReqInput(
+                rid,
+                input_text,
+                input_ids,
+                pixel_values,
+                image_hash,
+                image_size,
+                sampling_params,
+                return_logprob,
+                logprob_start_len,
+                top_logprobs_num,
+                obj.stream,
+            )
+        else:  # is embedding
+            tokenized_obj = TokenizedEmbeddingReqInput(
+                rid,
+                input_text,
+                input_ids,
+                sampling_params,
+            )
         self.send_to_router.send_pyobj(tokenized_obj)

         event = asyncio.Event()

@@ -368,7 +393,7 @@ class TokenizerManager:
         self,
         event: asyncio.Event,
         state: ReqState,
-        obj: GenerateReqInput,
+        obj: Union[GenerateReqInput, EmbeddingReqInput],
         rid: str,
         request,
     ):

@@ -381,12 +406,15 @@ class TokenizerManager:
                     raise ValueError(f"Abort request {rid}")
                 continue

-            out = self.convert_logprob_style(
-                state.out_list[-1],
-                obj.return_logprob,
-                obj.top_logprobs_num,
-                obj.return_text_in_logprobs,
-            )
+            if self.is_generation:
+                out = self.convert_logprob_style(
+                    state.out_list[-1],
+                    obj.return_logprob,
+                    obj.top_logprobs_num,
+                    obj.return_text_in_logprobs,
+                )
+            else:  # isinstance(obj, EmbeddingReqInput)
+                out = state.out_list[-1]

             # Log requests
             if self.server_args.log_requests and state.finished:

@@ -459,8 +487,10 @@ class TokenizerManager:
     async def handle_loop(self):
         while True:
-            recv_obj: BatchStrOut = await self.recv_from_detokenizer.recv_pyobj()
-            assert isinstance(recv_obj, BatchStrOut)
+            recv_obj: Union[BatchStrOut, BatchEmbeddingOut] = (
+                await self.recv_from_detokenizer.recv_pyobj()
+            )
+            assert isinstance(recv_obj, (BatchStrOut, BatchEmbeddingOut))

             for i, rid in enumerate(recv_obj.rids):
                 state = self.rid_to_state.get(rid, None)

@@ -468,10 +498,17 @@ class TokenizerManager:
                     continue

                 recv_obj.meta_info[i]["id"] = rid
-                out_dict = {
-                    "text": recv_obj.output_strs[i],
-                    "meta_info": recv_obj.meta_info[i],
-                }
+                if isinstance(recv_obj, BatchStrOut):
+                    out_dict = {
+                        "text": recv_obj.output_strs[i],
+                        "meta_info": recv_obj.meta_info[i],
+                    }
+                else:
+                    assert isinstance(recv_obj, BatchEmbeddingOut)
+                    out_dict = {
+                        "embedding": recv_obj.embeddings[i],
+                        "meta_info": recv_obj.meta_info[i],
+                    }
                 state.out_list.append(out_dict)
                 state.finished = recv_obj.finished_reason[i] is not None
                 state.event.set()
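Note: after this change, the per-request payload that handle_loop appends to state.out_list takes one of two shapes, depending on whether a BatchStrOut or a BatchEmbeddingOut arrives from the detokenizer. A rough sketch with stand-in values (illustrative data, not real server output):

    # Generation request (BatchStrOut):
    out_dict = {
        "text": "Paris.",
        "meta_info": {"prompt_tokens": 6, "completion_tokens": 2, "id": "req-0"},
    }

    # Embedding request (BatchEmbeddingOut):
    out_dict = {
        "embedding": [0.0123, -0.0456],  # truncated; one float per hidden dimension
        "meta_info": {"prompt_tokens": 6, "id": "req-1"},
    }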
python/sglang/srt/managers/tp_worker.py

@@ -20,7 +20,7 @@ import multiprocessing
 import pickle
 import time
 import warnings
-from typing import List, Optional
+from typing import List, Optional, Union

 import torch
 import torch.distributed as dist

@@ -31,8 +31,10 @@ from sglang.srt.constrained.jump_forward import JumpForwardCache
 from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
 from sglang.srt.managers.io_struct import (
     AbortReq,
+    BatchEmbeddingOut,
     BatchTokenIDOut,
     FlushCacheReq,
+    TokenizedEmbeddingReqInput,
     TokenizedGenerateReqInput,
 )
 from sglang.srt.managers.policy_scheduler import PolicyScheduler, PrefillAdder

@@ -205,7 +207,9 @@ class ModelTpServer:
         try:
             # Recv requests
             for recv_req in recv_reqs:
-                if isinstance(recv_req, TokenizedGenerateReqInput):
+                if isinstance(
+                    recv_req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput)
+                ):
                     self.handle_generate_request(recv_req)
                 elif isinstance(recv_req, FlushCacheReq):
                     self.flush_cache()

@@ -297,41 +301,42 @@ class ModelTpServer:
     def handle_generate_request(
         self,
-        recv_req: TokenizedGenerateReqInput,
+        recv_req: Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput],
     ):
         req = Req(recv_req.rid, recv_req.input_text, recv_req.input_ids)
-        req.pixel_values = recv_req.pixel_values
-        if req.pixel_values is not None:
-            req.pad_value = [
-                (recv_req.image_hash) % self.model_config.vocab_size,
-                (recv_req.image_hash >> 16) % self.model_config.vocab_size,
-                (recv_req.image_hash >> 32) % self.model_config.vocab_size,
-                (recv_req.image_hash >> 64) % self.model_config.vocab_size,
-            ]
-            req.image_size = recv_req.image_size
-            (
-                req.origin_input_ids,
-                req.image_offset,
-            ) = self.model_runner.model.pad_input_ids(
-                req.origin_input_ids_unpadded,
-                req.pad_value,
-                req.pixel_values.shape,
-                req.image_size,
-            )
-        req.sampling_params = recv_req.sampling_params
-        req.return_logprob = recv_req.return_logprob
-        req.logprob_start_len = recv_req.logprob_start_len
-        req.top_logprobs_num = recv_req.top_logprobs_num
-        req.stream = recv_req.stream
-        req.tokenizer = self.tokenizer
-
-        # Init regex fsm
-        if req.sampling_params.regex is not None:
-            req.regex_fsm = self.regex_fsm_cache.query(req.sampling_params.regex)
-            if not self.disable_regex_jump_forward:
-                req.jump_forward_map = self.jump_forward_cache.query(
-                    req.sampling_params.regex
-                )
+        req.tokenizer = self.tokenizer
+        req.sampling_params = recv_req.sampling_params
+        if self.model_runner.is_generation:
+            req.pixel_values = recv_req.pixel_values
+            if req.pixel_values is not None:
+                req.pad_value = [
+                    (recv_req.image_hash) % self.model_config.vocab_size,
+                    (recv_req.image_hash >> 16) % self.model_config.vocab_size,
+                    (recv_req.image_hash >> 32) % self.model_config.vocab_size,
+                    (recv_req.image_hash >> 64) % self.model_config.vocab_size,
+                ]
+                req.image_size = recv_req.image_size
+                (
+                    req.origin_input_ids,
+                    req.image_offset,
+                ) = self.model_runner.model.pad_input_ids(
+                    req.origin_input_ids_unpadded,
+                    req.pad_value,
+                    req.pixel_values.shape,
+                    req.image_size,
+                )
+            req.return_logprob = recv_req.return_logprob
+            req.logprob_start_len = recv_req.logprob_start_len
+            req.top_logprobs_num = recv_req.top_logprobs_num
+            req.stream = recv_req.stream
+
+            # Init regex fsm
+            if req.sampling_params.regex is not None:
+                req.regex_fsm = self.regex_fsm_cache.query(req.sampling_params.regex)
+                if not self.disable_regex_jump_forward:
+                    req.jump_forward_map = self.jump_forward_cache.query(
+                        req.sampling_params.regex
+                    )

         # Truncate prompts that are too long
         if len(req.origin_input_ids) >= self.max_req_input_len:

@@ -340,14 +345,17 @@ class ModelTpServer:
                 "the max context length. Truncated!!!"
             )
             req.origin_input_ids = req.origin_input_ids[: self.max_req_input_len]

-        req.sampling_params.max_new_tokens = min(
-            (
-                req.sampling_params.max_new_tokens
-                if req.sampling_params.max_new_tokens is not None
-                else 1 << 30
-            ),
-            self.max_req_input_len - 1 - len(req.origin_input_ids),
-        )
+        if self.model_runner.is_generation:
+            req.sampling_params.max_new_tokens = min(
+                (
+                    req.sampling_params.max_new_tokens
+                    if req.sampling_params.max_new_tokens is not None
+                    else 1 << 30
+                ),
+                self.max_req_input_len - 1 - len(req.origin_input_ids),
+            )

         self.waiting_queue.append(req)

     def get_new_prefill_batch(self) -> Optional[ScheduleBatch]:

@@ -439,47 +447,68 @@ class ModelTpServer:
                 self.model_config.vocab_size, self.int_token_logit_bias
             )

-        # Forward and sample the next tokens
-        if batch.extend_num_tokens != 0:
-            output = self.model_runner.forward(batch, ForwardMode.EXTEND)
-            next_token_ids = batch.sample(output.next_token_logits)
-
-            # Move logprobs to cpu
-            if output.next_token_logprobs is not None:
-                output.next_token_logprobs = output.next_token_logprobs[
-                    torch.arange(len(next_token_ids), device=next_token_ids.device),
-                    next_token_ids,
-                ].tolist()
-                output.input_token_logprobs = output.input_token_logprobs.tolist()
-                output.normalized_prompt_logprobs = (
-                    output.normalized_prompt_logprobs.tolist()
-                )
-
-            next_token_ids = next_token_ids.tolist()
-        else:
-            next_token_ids = [self.tokenizer.eos_token_id] * len(batch.reqs)
-
-        # Check finish conditions
-        pt = 0
-        for i, req in enumerate(batch.reqs):
-            if req is not self.current_inflight_req:
-                # Inflight reqs' prefill is not finished
-                req.completion_tokens_wo_jump_forward += 1
-                req.output_ids.append(next_token_ids[i])
-                req.check_finished()
-
-            if req.finished():
-                self.tree_cache.cache_finished_req(req)
-            else:
-                self.tree_cache.cache_unfinished_req(req)
-
-            if req is self.current_inflight_req:
-                # Inflight request would get a new req idx
-                self.req_to_token_pool.free(req.req_pool_idx)
-
-            if req.return_logprob:
-                self.add_logprob_return_values(i, req, pt, next_token_ids, output)
-                pt += req.extend_input_len
+        if self.model_runner.is_generation:
+            # Forward and sample the next tokens
+            if batch.extend_num_tokens != 0:
+                output = self.model_runner.forward(batch, ForwardMode.EXTEND)
+                next_token_ids = batch.sample(output.next_token_logits)
+
+                # Move logprobs to cpu
+                if output.next_token_logprobs is not None:
+                    output.next_token_logprobs = output.next_token_logprobs[
+                        torch.arange(len(next_token_ids), device=next_token_ids.device),
+                        next_token_ids,
+                    ].tolist()
+                    output.input_token_logprobs = output.input_token_logprobs.tolist()
+                    output.normalized_prompt_logprobs = (
+                        output.normalized_prompt_logprobs.tolist()
+                    )
+
+                next_token_ids = next_token_ids.tolist()
+            else:
+                next_token_ids = [self.tokenizer.eos_token_id] * len(batch.reqs)
+
+            # Check finish conditions
+            pt = 0
+            for i, req in enumerate(batch.reqs):
+                if req is not self.current_inflight_req:
+                    # Inflight reqs' prefill is not finished
+                    req.completion_tokens_wo_jump_forward += 1
+                    req.output_ids.append(next_token_ids[i])
+                    req.check_finished()
+
+                if req.finished():
+                    self.tree_cache.cache_finished_req(req)
+                else:
+                    self.tree_cache.cache_unfinished_req(req)
+
+                if req is self.current_inflight_req:
+                    # Inflight request would get a new req idx
+                    self.req_to_token_pool.free(req.req_pool_idx)
+
+                if req.return_logprob:
+                    self.add_logprob_return_values(i, req, pt, next_token_ids, output)
+                    pt += req.extend_input_len
+        else:
+            assert batch.extend_num_tokens != 0
+            output = self.model_runner.forward(batch, ForwardMode.EXTEND)
+            embeddings = output.embeddings.tolist()
+
+            # Check finish conditions
+            for i, req in enumerate(batch.reqs):
+                req.embedding = embeddings[i]
+                if req is not self.current_inflight_req:
+                    # Inflight reqs' prefill is not finished
+                    req.check_finished()
+
+                if req.finished():
+                    self.tree_cache.cache_finished_req(req)
+                else:
+                    self.tree_cache.cache_unfinished_req(req)
+
+                if req is self.current_inflight_req:
+                    # Inflight request would get a new req idx
+                    self.req_to_token_pool.free(req.req_pool_idx)

         self.handle_finished_requests(batch)

@@ -596,15 +625,19 @@ class ModelTpServer:
     def handle_finished_requests(self, batch: ScheduleBatch):
         output_rids = []
-        output_vids = []
-        decoded_texts = []
-        output_read_ids = []
-        output_read_offsets = []
-        output_skip_special_tokens = []
-        output_spaces_between_special_tokens = []
         output_meta_info = []
         output_finished_reason: List[BaseFinishReason] = []
+        if self.model_runner.is_generation:
+            output_vids = []
+            decoded_texts = []
+            output_read_ids = []
+            output_read_offsets = []
+            output_skip_special_tokens = []
+            output_spaces_between_special_tokens = []
+        else:  # for embedding model
+            output_embeddings = []
         unfinished_indices = []
         for i, req in enumerate(batch.reqs):
             if not req.finished() and req is not self.current_inflight_req:
                 unfinished_indices.append(i)

@@ -619,56 +652,73 @@ class ModelTpServer:
                 )
             ):
                 output_rids.append(req.rid)
-                output_vids.append(req.vid)
-                decoded_texts.append(req.decoded_text)
-                read_ids, read_offset = req.init_incremental_detokenize()
-                output_read_ids.append(read_ids)
-                output_read_offsets.append(read_offset)
-                output_skip_special_tokens.append(
-                    req.sampling_params.skip_special_tokens
-                )
-                output_spaces_between_special_tokens.append(
-                    req.sampling_params.spaces_between_special_tokens
-                )
-
-                meta_info = {
-                    "prompt_tokens": len(req.origin_input_ids),
-                    "completion_tokens": len(req.output_ids),
-                    "completion_tokens_wo_jump_forward": req.completion_tokens_wo_jump_forward,
-                    "finish_reason": str(req.finished_reason),
-                }
-                if req.return_logprob:
-                    (
-                        meta_info["input_token_logprobs"],
-                        meta_info["output_token_logprobs"],
-                        meta_info["input_top_logprobs"],
-                        meta_info["output_top_logprobs"],
-                        meta_info["normalized_prompt_logprob"],
-                    ) = (
-                        req.input_token_logprobs,
-                        req.output_token_logprobs,
-                        req.input_top_logprobs,
-                        req.output_top_logprobs,
-                        req.normalized_prompt_logprob,
-                    )
-                output_meta_info.append(meta_info)
                 output_finished_reason.append(req.finished_reason)
+                if self.model_runner.is_generation:
+                    output_vids.append(req.vid)
+                    decoded_texts.append(req.decoded_text)
+                    read_ids, read_offset = req.init_incremental_detokenize()
+                    output_read_ids.append(read_ids)
+                    output_read_offsets.append(read_offset)
+                    output_skip_special_tokens.append(
+                        req.sampling_params.skip_special_tokens
+                    )
+                    output_spaces_between_special_tokens.append(
+                        req.sampling_params.spaces_between_special_tokens
+                    )
+
+                    meta_info = {
+                        "prompt_tokens": len(req.origin_input_ids),
+                        "completion_tokens": len(req.output_ids),
+                        "completion_tokens_wo_jump_forward": req.completion_tokens_wo_jump_forward,
+                        "finish_reason": str(req.finished_reason),
+                    }
+                    if req.return_logprob:
+                        (
+                            meta_info["input_token_logprobs"],
+                            meta_info["output_token_logprobs"],
+                            meta_info["input_top_logprobs"],
+                            meta_info["output_top_logprobs"],
+                            meta_info["normalized_prompt_logprob"],
+                        ) = (
+                            req.input_token_logprobs,
+                            req.output_token_logprobs,
+                            req.input_top_logprobs,
+                            req.output_top_logprobs,
+                            req.normalized_prompt_logprob,
+                        )
+                    output_meta_info.append(meta_info)
+                else:  # for embedding model
+                    output_embeddings.append(req.embedding)
+                    meta_info = {
+                        "prompt_tokens": len(req.origin_input_ids),
+                    }
+                    output_meta_info.append(meta_info)

         # Send to detokenizer
         if output_rids:
-            self.out_pyobjs.append(
-                BatchTokenIDOut(
-                    output_rids,
-                    output_vids,
-                    decoded_texts,
-                    output_read_ids,
-                    output_read_offsets,
-                    output_skip_special_tokens,
-                    output_spaces_between_special_tokens,
-                    output_meta_info,
-                    output_finished_reason,
-                )
-            )
+            if self.model_runner.is_generation:
+                self.out_pyobjs.append(
+                    BatchTokenIDOut(
+                        output_rids,
+                        output_vids,
+                        decoded_texts,
+                        output_read_ids,
+                        output_read_offsets,
+                        output_skip_special_tokens,
+                        output_spaces_between_special_tokens,
+                        output_meta_info,
+                        output_finished_reason,
+                    )
+                )
+            else:  # for embedding model
+                self.out_pyobjs.append(
+                    BatchEmbeddingOut(
+                        output_rids,
+                        output_embeddings,
+                        output_meta_info,
+                        output_finished_reason,
+                    )
+                )

         # Remove finished reqs: update batch tensors
         batch.filter_batch(unfinished_indices)
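Note: the restructured prefill path keeps the old sample-and-decode flow for generation models and adds a prefill-only flow for embedding models. A condensed, hypothetical sketch of what the new branch does (a free function for illustration, not the real ModelTpServer method):

    def run_embedding_prefill(model_runner, batch, forward_mode_extend):
        # Embedding models need only the EXTEND (prefill) forward pass; there is no
        # token-sampling loop, and each request gets its pooled embedding back.
        output = model_runner.forward(batch, forward_mode_extend)
        return output.embeddings.tolist()  # one embedding list per request in the batch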
python/sglang/srt/model_executor/model_runner.py

@@ -52,6 +52,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetad
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
     get_available_gpu_memory,
+    is_generation_model,
     is_llama3_405b_fp8,
     is_multimodal_model,
     monkey_patch_vllm_dummy_weight_loader,

@@ -132,8 +133,10 @@ class ModelRunner:
         self.init_cublas()
         self.init_flashinfer()

-        # Capture cuda graphs
-        self.init_cuda_graphs()
+        if self.is_generation:
+            # FIXME Currently, cuda graph only capture decode steps, which only exists in causal models
+            # Capture cuda graphs
+            self.init_cuda_graphs()

     def load_model(self):
         logger.info(

@@ -184,6 +187,10 @@ class ModelRunner:
             scheduler_config=None,
             cache_config=None,
         )
+        self.is_generation = is_generation_model(
+            self.model_config.hf_config.architectures
+        )
+
         logger.info(
             f"[gpu={self.gpu_id}] Load weight end. "
             f"type={type(self.model).__name__}, "

@@ -406,8 +413,10 @@ def import_model_classes():
         entry, list
     ):  # To support multiple model classes in one module
         for tmp in entry:
+            assert tmp.__name__ not in model_arch_name_to_cls
             model_arch_name_to_cls[tmp.__name__] = tmp
     else:
+        assert entry.__name__ not in model_arch_name_to_cls
         model_arch_name_to_cls[entry.__name__] = entry

     # compat: some models such as chatglm has incorrect class set in config.json

@@ -417,6 +426,7 @@ def import_model_classes():
     ):
         for remap in module.EntryClassRemapping:
             if isinstance(remap, tuple) and len(remap) == 2:
+                assert remap[0] not in model_arch_name_to_cls
                 model_arch_name_to_cls[remap[0]] = remap[1]

     return model_arch_name_to_cls
python/sglang/srt/models/llama_embedding.py

@@ -84,3 +84,5 @@ class LlamaEmbeddingModel(nn.Module):
 EntryClass = LlamaEmbeddingModel
+# compat: e5-mistral model.config class == MistralModel
+EntryClassRemapping = [("MistralModel", LlamaEmbeddingModel)]
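Note: the remapping matters because e5-mistral checkpoints declare "MistralModel" in their config architectures. import_model_classes in model_runner.py (see the diff above) folds each module's EntryClassRemapping into its architecture-to-class table, so that lookup resolves to LlamaEmbeddingModel. A simplified, self-contained sketch of that resolution, using strings as stand-ins for the actual classes:

    # Stand-in for the table built by import_model_classes(); the real one is filled
    # from each model module's EntryClass / EntryClassRemapping attributes.
    model_arch_name_to_cls = {}

    EntryClassRemapping = [("MistralModel", "LlamaEmbeddingModel")]  # class shown as a name here

    for remap in EntryClassRemapping:
        if isinstance(remap, tuple) and len(remap) == 2:
            assert remap[0] not in model_arch_name_to_cls
            model_arch_name_to_cls[remap[0]] = remap[1]

    print(model_arch_name_to_cls["MistralModel"])  # -> LlamaEmbeddingModel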
python/sglang/srt/server.py

@@ -52,7 +52,7 @@ from sglang.srt.managers.controller_single import (
     start_controller_process as start_controller_process_single,
 )
 from sglang.srt.managers.detokenizer_manager import start_detokenizer_process
-from sglang.srt.managers.io_struct import GenerateReqInput
+from sglang.srt.managers.io_struct import EmbeddingReqInput, GenerateReqInput
 from sglang.srt.managers.tokenizer_manager import TokenizerManager
 from sglang.srt.openai_api.adapter import (
     load_chat_template_for_openai_api,

@@ -97,6 +97,7 @@ async def health() -> Response:
 async def get_model_info():
     result = {
         "model_path": tokenizer_manager.model_path,
+        "is_generation": tokenizer_manager.is_generation,
     }
     return result

@@ -148,6 +149,21 @@ app.post("/generate")(generate_request)
 app.put("/generate")(generate_request)


+async def encode_request(obj: EmbeddingReqInput, request: Request):
+    """Handle an embedding request."""
+    try:
+        ret = await tokenizer_manager.generate_request(obj, request).__anext__()
+        return ret
+    except ValueError as e:
+        return JSONResponse(
+            {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
+        )
+
+
+app.post("/encode")(encode_request)
+app.put("/encode")(encode_request)
+
+
 @app.post("/v1/completions")
 async def openai_v1_completions(raw_request: Request):
     return await v1_completions(tokenizer_manager, raw_request)

@@ -380,6 +396,7 @@ def _wait_and_warmup(server_args, pipe_finish_writer):
     except (AssertionError, requests.exceptions.RequestException) as e:
         last_traceback = get_exception_traceback()
         pass
+    model_info = res.json()

     if not success:
         if pipe_finish_writer is not None:

@@ -388,15 +405,17 @@ def _wait_and_warmup(server_args, pipe_finish_writer):
         sys.exit(1)

     # Send a warmup request
+    request_name = "/generate" if model_info["is_generation"] else "/encode"
+    max_new_tokens = 8 if model_info["is_generation"] else 0
     try:
         for _ in range(server_args.dp_size):
             res = requests.post(
-                url + "/generate",
+                url + request_name,
                 json={
                     "text": "The capital city of France is",
                     "sampling_params": {
                         "temperature": 0,
-                        "max_new_tokens": 8,
+                        "max_new_tokens": max_new_tokens,
                     },
                 },
                 headers=headers,

@@ -529,5 +548,18 @@ class Runtime:
         )
         return json.dumps(response.json())

+    def encode(
+        self,
+        prompt: str,
+    ):
+        json_data = {
+            "text": prompt,
+        }
+        response = requests.post(
+            self.url + "/encode",
+            json=json_data,
+        )
+        return json.dumps(response.json())
+
     def __del__(self):
         self.shutdown()
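Note: a hedged usage sketch for the new route. Once a server is running with an embedding model such as intfloat/e5-mistral-7b-instruct, /encode accepts the same JSON shape that Runtime.encode sends; the port below is only an assumed default for illustration, not something fixed by this commit:

    import requests

    # Assumes a locally running sglang server; 30000 is an assumed port.
    resp = requests.post(
        "http://localhost:30000/encode",
        json={"text": "The capital of France is"},
    )
    data = resp.json()
    print(len(data["embedding"]))  # one float per hidden dimension
    print(data["meta_info"])       # e.g. {"prompt_tokens": ..., "id": ...}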
python/sglang/srt/utils.py

@@ -223,6 +223,15 @@ def is_multimodal_model(model):
     raise ValueError("unrecognized type")


+def is_generation_model(model_architectures):
+    if (
+        "LlamaEmbeddingModel" in model_architectures
+        or "MistralModel" in model_architectures
+    ):
+        return False
+    return True
+
+
 def decode_video_base64(video_base64):
     from PIL import Image
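Note: a minimal illustration of how the new helper classifies architectures. It takes the list from a model config's architectures field; the values below are examples, assuming the sglang package from this commit is importable:

    from sglang.srt.utils import is_generation_model

    print(is_generation_model(["MistralModel"]))      # False -> embedding path
    print(is_generation_model(["LlamaForCausalLM"]))  # True  -> generation path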
python/sglang/test/runners.py

@@ -23,6 +23,7 @@ import torch.nn.functional as F
 from transformers import AutoModelForCausalLM, AutoTokenizer

 from sglang.srt.server import Runtime
+from sglang.srt.utils import is_generation_model

 DEFAULT_PROMPTS = [
     "The capital of France is",

@@ -33,13 +34,6 @@ DEFAULT_PROMPTS = [
 NUM_TOP_LOGPROBS = 5


-def is_embedding_model(model_path):
-    # FIXME incomplete list
-    if "e5-mistral-7b-instruct" in model_path.lower():
-        return True
-    return False
-
-
 def get_dtype_str(torch_dtype):
     if torch_dtype is torch.float16:
         return "float16"

@@ -60,7 +54,7 @@ class HFRunner:
         self,
         model_path,
         torch_dtype=torch.float16,
-        is_embedding_model=None,
+        is_generation_model=None,
     ):
         self.in_queue = multiprocessing.Queue()
         self.out_queue = multiprocessing.Queue()

@@ -72,13 +66,13 @@ class HFRunner:
                 self.out_queue,
                 model_path,
                 torch_dtype,
-                is_embedding_model,
+                is_generation_model,
             ),
         )
         self.model_proc.start()

     def start_model_process(
-        self, in_queue, out_queue, model_path, torch_dtype, is_embedding_model
+        self, in_queue, out_queue, model_path, torch_dtype, is_generation_model
     ):
         self.tokenizer = AutoTokenizer.from_pretrained(
             model_path,

@@ -86,12 +80,12 @@ class HFRunner:
             trust_remote_code=True,
         )

-        self.is_embedding_model = (
-            is_embedding_model(model_path)
-            if is_embedding_model is None
-            else is_embedding_model
+        self.is_generation_model = (
+            is_generation_model(model_path)
+            if is_generation_model is None
+            else is_generation_model
         )
-        if not self.is_embedding_model:
+        if self.is_generation_model:
             self.model = AutoModelForCausalLM.from_pretrained(
                 model_path,
                 torch_dtype=torch_dtype,

@@ -103,13 +97,13 @@ class HFRunner:
             self.model = SentenceTransformer(
                 model_path,
-                device="cpu",
-            ).to(dtype=torch_dtype)
+                model_kwargs={"torch_dtype": torch_dtype},
+            )

         while True:
             prompts, max_new_tokens = in_queue.get()
             if prompts is not None:
-                if not self.is_embedding_model:
+                if self.is_generation_model:
                     output_strs = []
                     prefill_logprobs = []
                     for p in prompts:

@@ -144,7 +138,6 @@ class HFRunner:
                     )
                 else:
-                    assert isinstance(prompts, List[str])
                     logits = self.model.encode(prompts).tolist()
                     out_queue.put(ModelOutput(embed_logits=logits))

@@ -175,16 +168,13 @@ class SRTRunner:
         model_path,
         tp_size=1,
         torch_dtype=torch.float16,
-        is_embedding_model=None,
+        is_generation_model=None,
     ):
-        self.is_embedding_model = (
-            is_embedding_model(model_path)
-            if is_embedding_model is None
-            else is_embedding_model
+        self.is_generation_model = (
+            is_generation_model(model_path)
+            if is_generation_model is None
+            else is_generation_model
         )
-        if self.is_embedding_model:
-            raise NotImplementedError()
         self.runtime = Runtime(
             model_path=model_path,
             tp_size=tp_size,

@@ -196,38 +186,45 @@ class SRTRunner:
         prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS,
         max_new_tokens=64,
     ):
-        # the return value contains logprobs from prefill
-        output_strs = []
-        top_input_logprobs = []
-        sampling_params = {"max_new_tokens": max_new_tokens, "temperature": 0}
-        for prompt in prompts:
-            response = self.runtime.generate(
-                prompt,
-                sampling_params=sampling_params,
-                return_logprob=True,
-                top_logprobs_num=NUM_TOP_LOGPROBS,
-            )
-            response = json.loads(response)
-            output_strs.append(response["text"])
-            top_input_logprobs.append(
-                [
-                    [tup[0] for tup in x[:NUM_TOP_LOGPROBS]]
-                    for x in response["meta_info"]["input_top_logprobs"][1:]
-                ]
-                + [
-                    [
-                        tup[0]
-                        for tup in response["meta_info"]["output_top_logprobs"][0][
-                            :NUM_TOP_LOGPROBS
-                        ]
-                    ]
-                ]
-            )
-            # print(response["meta_info"]["output_top_logprobs"][0])
-
-        return ModelOutput(
-            output_strs=output_strs, top_input_logprobs=top_input_logprobs
-        )
+        if self.is_generation_model:
+            # the return value contains logprobs from prefill
+            output_strs = []
+            top_input_logprobs = []
+            sampling_params = {"max_new_tokens": max_new_tokens, "temperature": 0}
+            for prompt in prompts:
+                response = self.runtime.generate(
+                    prompt,
+                    sampling_params=sampling_params,
+                    return_logprob=True,
+                    top_logprobs_num=NUM_TOP_LOGPROBS,
+                )
+                response = json.loads(response)
+                output_strs.append(response["text"])
+                top_input_logprobs.append(
+                    [
+                        [tup[0] for tup in x[:NUM_TOP_LOGPROBS]]
+                        for x in response["meta_info"]["input_top_logprobs"][1:]
+                    ]
+                    + [
+                        [
+                            tup[0]
+                            for tup in response["meta_info"]["output_top_logprobs"][0][
+                                :NUM_TOP_LOGPROBS
+                            ]
+                        ]
+                    ]
+                )
+
+            return ModelOutput(
+                output_strs=output_strs, top_input_logprobs=top_input_logprobs
+            )
+        else:
+            logits = []
+            for prompt in prompts:
+                response = self.runtime.encode(prompt)
+                response = json.loads(response)
+                logits.append(response["embedding"])
+            return ModelOutput(embed_logits=logits)

     def __enter__(self):
         return self
python/sglang/test/test_utils.py

@@ -12,6 +12,8 @@ from typing import Callable, List, Optional
 import numpy as np
 import requests
+import torch
+import torch.nn.functional as F

 from sglang.global_config import global_config
 from sglang.lang.backend.openai import OpenAI

@@ -492,3 +494,7 @@ def run_unittest_files(files: List[str], timeout_per_file: float):
         print(f"Fail. Time elapsed: {time.time() - tic:.2f}s")

     return 0 if success else -1
+
+
+def get_similarities(vec1, vec2):
+    return F.cosine_similarity(torch.tensor(vec1), torch.tensor(vec2), dim=0)
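Note: a quick sketch of how the new helper is meant to be used in the embedding test; the vectors here are toy values, not real model output:

    from sglang.test.test_utils import get_similarities

    a = [0.1, 0.2, 0.3]
    b = [0.1, 0.2, 0.31]
    # Cosine similarity of two 1-D vectors; near-identical vectors give a value close to 1.0.
    print(get_similarities(a, b))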
test/srt/models/test_embedding_models.py (new file, 0 → 100644)

"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import unittest

import torch

from sglang.test.runners import DEFAULT_PROMPTS, HFRunner, SRTRunner
from sglang.test.test_utils import get_similarities

MODELS = [("intfloat/e5-mistral-7b-instruct", 1)]
TORCH_DTYPES = [torch.float16]


class TestEmbeddingModels(unittest.TestCase):

    def assert_close_prefill_logits(
        self,
        prompts,
        model_path,
        tp_size,
        torch_dtype,
    ) -> None:
        with HFRunner(
            model_path, torch_dtype=torch_dtype, is_generation_model=False
        ) as hf_runner:
            hf_outputs = hf_runner.forward(prompts)

        with SRTRunner(
            model_path,
            tp_size=tp_size,
            torch_dtype=torch_dtype,
            is_generation_model=False,
        ) as srt_runner:
            srt_outputs = srt_runner.forward(prompts)

        for i in range(len(prompts)):
            hf_logits = torch.Tensor(hf_outputs.embed_logits[i])
            srt_logits = torch.Tensor(srt_outputs.embed_logits[i])

            similarities = torch.tensor(get_similarities(hf_logits, srt_logits))

            tolerance = 1e-2
            assert torch.all(
                abs(similarities - 1) < tolerance
            ), "embeddings are not all close"

    def test_prefill_logits(self):
        for model, tp_size in MODELS:
            for torch_dtype in TORCH_DTYPES:
                self.assert_close_prefill_logits(
                    DEFAULT_PROMPTS, model, tp_size, torch_dtype
                )


if __name__ == "__main__":
    unittest.main(warnings="ignore")
test/srt/models/test_causal_models.py → test/srt/models/test_generation_models.py (renamed)

@@ -3,7 +3,9 @@ Copyright 2023-2024 SGLang Team
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at
+
     http://www.apache.org/licenses/LICENSE-2.0
+
 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.

@@ -33,7 +35,7 @@ class TestCausalModels(unittest.TestCase):
         torch_dtype,
     ) -> None:
         with HFRunner(
-            model_path, torch_dtype=torch_dtype, is_embedding_model=False
+            model_path, torch_dtype=torch_dtype, is_generation_model=True
         ) as hf_runner:
             hf_outputs = hf_runner.forward(prompts)

@@ -41,7 +43,7 @@ class TestCausalModels(unittest.TestCase):
             model_path,
             tp_size=tp_size,
             torch_dtype=torch_dtype,
-            is_embedding_model=False,
+            is_generation_model=True,
         ) as srt_runner:
             srt_outputs = srt_runner.forward(prompts)
test/srt/run_suite.py

@@ -10,7 +10,8 @@ suites = {
         "test_vision_openai_server.py",
         "test_chunked_prefill.py",
         "test_torch_compile.py",
-        "models/test_causal_models.py",
+        "models/test_generation_models.py",
+        "models/test_embedding_models.py",
         "sampling/penaltylib",
     ],
     "sampling/penaltylib": glob.glob(